MLIR  17.0.0git
AffineExpr.cpp
Go to the documentation of this file.
1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "AffineExprDetail.h"
12 #include "mlir/IR/AffineExpr.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/IntegerSet.h"
17 #include "mlir/Support/TypeID.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include <numeric>
20 #include <optional>
21 
22 using namespace mlir;
23 using namespace mlir::detail;
24 
25 MLIRContext *AffineExpr::getContext() const { return expr->context; }
26 
27 AffineExprKind AffineExpr::getKind() const { return expr->kind; }
28 
29 /// Walk all of the AffineExprs in this subgraph in postorder.
30 void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
31  struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
32  std::function<void(AffineExpr)> callback;
33 
34  AffineExprWalker(std::function<void(AffineExpr)> callback)
35  : callback(std::move(callback)) {}
36 
37  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
38  void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
39  void visitDimExpr(AffineDimExpr expr) { callback(expr); }
40  void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
41  };
42 
43  AffineExprWalker(std::move(callback)).walkPostOrder(*this);
44 }
45 
46 // Dispatch affine expression construction based on kind.
48  AffineExpr rhs) {
49  if (kind == AffineExprKind::Add)
50  return lhs + rhs;
51  if (kind == AffineExprKind::Mul)
52  return lhs * rhs;
53  if (kind == AffineExprKind::FloorDiv)
54  return lhs.floorDiv(rhs);
55  if (kind == AffineExprKind::CeilDiv)
56  return lhs.ceilDiv(rhs);
57  if (kind == AffineExprKind::Mod)
58  return lhs % rhs;
59 
60  llvm_unreachable("unknown binary operation on affine expressions");
61 }
62 
63 /// This method substitutes any uses of dimensions and symbols (e.g.
64 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
67  ArrayRef<AffineExpr> symReplacements) const {
68  switch (getKind()) {
70  return *this;
71  case AffineExprKind::DimId: {
72  unsigned dimId = cast<AffineDimExpr>().getPosition();
73  if (dimId >= dimReplacements.size())
74  return *this;
75  return dimReplacements[dimId];
76  }
78  unsigned symId = cast<AffineSymbolExpr>().getPosition();
79  if (symId >= symReplacements.size())
80  return *this;
81  return symReplacements[symId];
82  }
88  auto binOp = cast<AffineBinaryOpExpr>();
89  auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
90  auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
91  auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
92  if (newLHS == lhs && newRHS == rhs)
93  return *this;
94  return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
95  }
96  llvm_unreachable("Unknown AffineExpr");
97 }
98 
100  return replaceDimsAndSymbols(dimReplacements, {});
101 }
102 
105  return replaceDimsAndSymbols({}, symReplacements);
106 }
107 
108 /// Replace dims[offset ... numDims)
109 /// by dims[offset + shift ... shift + numDims).
110 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
111  unsigned offset) const {
113  for (unsigned idx = 0; idx < offset; ++idx)
114  dims.push_back(getAffineDimExpr(idx, getContext()));
115  for (unsigned idx = offset; idx < numDims; ++idx)
116  dims.push_back(getAffineDimExpr(idx + shift, getContext()));
117  return replaceDimsAndSymbols(dims, {});
118 }
119 
120 /// Replace symbols[offset ... numSymbols)
121 /// by symbols[offset + shift ... shift + numSymbols).
122 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
123  unsigned offset) const {
125  for (unsigned idx = 0; idx < offset; ++idx)
126  symbols.push_back(getAffineSymbolExpr(idx, getContext()));
127  for (unsigned idx = offset; idx < numSymbols; ++idx)
128  symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
129  return replaceDimsAndSymbols({}, symbols);
130 }
131 
132 /// Sparse replace method. Return the modified expression tree.
135  auto it = map.find(*this);
136  if (it != map.end())
137  return it->second;
138  switch (getKind()) {
139  default:
140  return *this;
141  case AffineExprKind::Add:
142  case AffineExprKind::Mul:
145  case AffineExprKind::Mod:
146  auto binOp = cast<AffineBinaryOpExpr>();
147  auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
148  auto newLHS = lhs.replace(map);
149  auto newRHS = rhs.replace(map);
150  if (newLHS == lhs && newRHS == rhs)
151  return *this;
152  return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
153  }
154  llvm_unreachable("Unknown AffineExpr");
155 }
156 
157 /// Sparse replace method. Return the modified expression tree.
160  map.insert(std::make_pair(expr, replacement));
161  return replace(map);
162 }
163 /// Returns true if this expression is made out of only symbols and
164 /// constants (no dimensional identifiers).
166  switch (getKind()) {
168  return true;
170  return false;
172  return true;
173 
174  case AffineExprKind::Add:
175  case AffineExprKind::Mul:
178  case AffineExprKind::Mod: {
179  auto expr = this->cast<AffineBinaryOpExpr>();
180  return expr.getLHS().isSymbolicOrConstant() &&
181  expr.getRHS().isSymbolicOrConstant();
182  }
183  }
184  llvm_unreachable("Unknown AffineExpr");
185 }
186 
187 /// Returns true if this is a pure affine expression, i.e., multiplication,
188 /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
190  switch (getKind()) {
194  return true;
195  case AffineExprKind::Add: {
196  auto op = cast<AffineBinaryOpExpr>();
197  return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
198  }
199 
200  case AffineExprKind::Mul: {
201  // TODO: Canonicalize the constants in binary operators to the RHS when
202  // possible, allowing this to merge into the next case.
203  auto op = cast<AffineBinaryOpExpr>();
204  return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
205  (op.getLHS().template isa<AffineConstantExpr>() ||
206  op.getRHS().template isa<AffineConstantExpr>());
207  }
210  case AffineExprKind::Mod: {
211  auto op = cast<AffineBinaryOpExpr>();
212  return op.getLHS().isPureAffine() &&
213  op.getRHS().template isa<AffineConstantExpr>();
214  }
215  }
216  llvm_unreachable("Unknown AffineExpr");
217 }
218 
219 // Returns the greatest known integral divisor of this affine expression.
221  AffineBinaryOpExpr binExpr(nullptr);
222  switch (getKind()) {
224  [[fallthrough]];
226  return 1;
228  [[fallthrough]];
230  // If the RHS is a constant and divides the known divisor on the LHS, the
231  // quotient is a known divisor of the expression.
232  binExpr = this->cast<AffineBinaryOpExpr>();
233  auto rhs = binExpr.getRHS().dyn_cast<AffineConstantExpr>();
234  // Leave alone undefined expressions.
235  if (rhs && rhs.getValue() != 0) {
236  int64_t lhsDiv = binExpr.getLHS().getLargestKnownDivisor();
237  if (lhsDiv % rhs.getValue() == 0)
238  return lhsDiv / rhs.getValue();
239  }
240  return 1;
241  }
243  return std::abs(this->cast<AffineConstantExpr>().getValue());
244  case AffineExprKind::Mul: {
245  binExpr = this->cast<AffineBinaryOpExpr>();
246  return binExpr.getLHS().getLargestKnownDivisor() *
247  binExpr.getRHS().getLargestKnownDivisor();
248  }
249  case AffineExprKind::Add:
250  [[fallthrough]];
251  case AffineExprKind::Mod: {
252  binExpr = cast<AffineBinaryOpExpr>();
253  return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
254  (uint64_t)binExpr.getRHS().getLargestKnownDivisor());
255  }
256  }
257  llvm_unreachable("Unknown AffineExpr");
258 }
259 
260 bool AffineExpr::isMultipleOf(int64_t factor) const {
261  AffineBinaryOpExpr binExpr(nullptr);
262  uint64_t l, u;
263  switch (getKind()) {
265  [[fallthrough]];
267  return factor * factor == 1;
269  return cast<AffineConstantExpr>().getValue() % factor == 0;
270  case AffineExprKind::Mul: {
271  binExpr = cast<AffineBinaryOpExpr>();
272  // It's probably not worth optimizing this further (to not traverse the
273  // whole sub-tree under - it that would require a version of isMultipleOf
274  // that on a 'false' return also returns the largest known divisor).
275  return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
276  (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
277  (l * u) % factor == 0;
278  }
279  case AffineExprKind::Add:
282  case AffineExprKind::Mod: {
283  binExpr = cast<AffineBinaryOpExpr>();
284  return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
285  (uint64_t)binExpr.getRHS().getLargestKnownDivisor()) %
286  factor ==
287  0;
288  }
289  }
290  llvm_unreachable("Unknown AffineExpr");
291 }
292 
293 bool AffineExpr::isFunctionOfDim(unsigned position) const {
294  if (getKind() == AffineExprKind::DimId) {
295  return *this == mlir::getAffineDimExpr(position, getContext());
296  }
297  if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
298  return expr.getLHS().isFunctionOfDim(position) ||
299  expr.getRHS().isFunctionOfDim(position);
300  }
301  return false;
302 }
303 
304 bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
305  if (getKind() == AffineExprKind::SymbolId) {
306  return *this == mlir::getAffineSymbolExpr(position, getContext());
307  }
308  if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
309  return expr.getLHS().isFunctionOfSymbol(position) ||
310  expr.getRHS().isFunctionOfSymbol(position);
311  }
312  return false;
313 }
314 
316  : AffineExpr(ptr) {}
318  return static_cast<ImplType *>(expr)->lhs;
319 }
321  return static_cast<ImplType *>(expr)->rhs;
322 }
323 
325 unsigned AffineDimExpr::getPosition() const {
326  return static_cast<ImplType *>(expr)->position;
327 }
328 
329 /// Returns true if the expression is divisible by the given symbol with
330 /// position `symbolPos`. The argument `opKind` specifies here what kind of
331 /// division or mod operation called this division. It helps in implementing the
332 /// commutative property of the floordiv and ceildiv operations. If the argument
333 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
334 /// operation, then the commutative property can be used otherwise, the floordiv
335 /// operation is not divisible. The same argument holds for ceildiv operation.
336 static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
337  AffineExprKind opKind) {
338  // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
339  assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
340  opKind == AffineExprKind::CeilDiv) &&
341  "unexpected opKind");
342  switch (expr.getKind()) {
344  return expr.cast<AffineConstantExpr>().getValue() == 0;
346  return false;
348  return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos);
349  // Checks divisibility by the given symbol for both operands.
350  case AffineExprKind::Add: {
351  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
352  return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
353  isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
354  }
355  // Checks divisibility by the given symbol for both operands. Consider the
356  // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
357  // this is a division by s1 and both the operands of modulo are divisible by
358  // s1 but it is not divisible by s1 always. The third argument is
359  // `AffineExprKind::Mod` for this reason.
360  case AffineExprKind::Mod: {
361  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
362  return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
364  isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
366  }
367  // Checks if any of the operand divisible by the given symbol.
368  case AffineExprKind::Mul: {
369  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
370  return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
371  isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
372  }
373  // Floordiv and ceildiv are divisible by the given symbol when the first
374  // operand is divisible, and the affine expression kind of the argument expr
375  // is same as the argument `opKind`. This can be inferred from commutative
376  // property of floordiv and ceildiv operations and are as follow:
377  // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
378  // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
379  // It will fail if operations are not same. For example:
380  // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
383  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
384  if (opKind != expr.getKind())
385  return false;
386  return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
387  }
388  }
389  llvm_unreachable("Unknown AffineExpr");
390 }
391 
392 /// Divides the given expression by the given symbol at position `symbolPos`. It
393 /// considers the divisibility condition is checked before calling itself. A
394 /// null expression is returned whenever the divisibility condition fails.
395 static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
396  AffineExprKind opKind) {
397  // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
398  assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
399  opKind == AffineExprKind::CeilDiv) &&
400  "unexpected opKind");
401  switch (expr.getKind()) {
403  if (expr.cast<AffineConstantExpr>().getValue() != 0)
404  return nullptr;
405  return getAffineConstantExpr(0, expr.getContext());
407  return nullptr;
409  return getAffineConstantExpr(1, expr.getContext());
410  // Dividing both operands by the given symbol.
411  case AffineExprKind::Add: {
412  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
413  return getAffineBinaryOpExpr(
414  expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
415  symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
416  }
417  // Dividing both operands by the given symbol.
418  case AffineExprKind::Mod: {
419  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
420  return getAffineBinaryOpExpr(
421  expr.getKind(),
422  symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
423  symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
424  }
425  // Dividing any of the operand by the given symbol.
426  case AffineExprKind::Mul: {
427  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
428  if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
429  return binaryExpr.getLHS() *
430  symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
431  return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
432  binaryExpr.getRHS();
433  }
434  // Dividing first operand only by the given symbol.
437  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
438  return getAffineBinaryOpExpr(
439  expr.getKind(),
440  symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
441  binaryExpr.getRHS());
442  }
443  }
444  llvm_unreachable("Unknown AffineExpr");
445 }
446 
447 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
448 /// operations when the second operand simplifies to a symbol and the first
449 /// operand is divisible by that symbol. It can be applied to any semi-affine
450 /// expression. Returned expression can either be a semi-affine or pure affine
451 /// expression.
453  switch (expr.getKind()) {
457  return expr;
458  case AffineExprKind::Add:
459  case AffineExprKind::Mul: {
460  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
461  return getAffineBinaryOpExpr(expr.getKind(),
462  simplifySemiAffine(binaryExpr.getLHS()),
463  simplifySemiAffine(binaryExpr.getRHS()));
464  }
465  // Check if the simplification of the second operand is a symbol, and the
466  // first operand is divisible by it. If the operation is a modulo, a constant
467  // zero expression is returned. In the case of floordiv and ceildiv, the
468  // symbol from the simplification of the second operand divides the first
469  // operand. Otherwise, simplification is not possible.
472  case AffineExprKind::Mod: {
473  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
474  AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS());
475  AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS());
476  AffineSymbolExpr symbolExpr =
478  if (!symbolExpr)
479  return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
480  unsigned symbolPos = symbolExpr.getPosition();
481  if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
482  return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
483  if (expr.getKind() == AffineExprKind::Mod)
484  return getAffineConstantExpr(0, expr.getContext());
485  return symbolicDivide(sLHS, symbolPos, expr.getKind());
486  }
487  }
488  llvm_unreachable("Unknown AffineExpr");
489 }
490 
491 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
492  MLIRContext *context) {
493  auto assignCtx = [context](AffineDimExprStorage *storage) {
494  storage->context = context;
495  };
496 
497  StorageUniquer &uniquer = context->getAffineUniquer();
498  return uniquer.get<AffineDimExprStorage>(
499  assignCtx, static_cast<unsigned>(kind), position);
500 }
501 
502 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
503  return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
504 }
505 
507  : AffineExpr(ptr) {}
509  return static_cast<ImplType *>(expr)->position;
510 }
511 
512 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
513  return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
514  ;
515 }
516 
518  : AffineExpr(ptr) {}
520  return static_cast<ImplType *>(expr)->constant;
521 }
522 
523 bool AffineExpr::operator==(int64_t v) const {
524  return *this == getAffineConstantExpr(v, getContext());
525 }
526 
528  auto assignCtx = [context](AffineConstantExprStorage *storage) {
529  storage->context = context;
530  };
531 
532  StorageUniquer &uniquer = context->getAffineUniquer();
533  return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
534 }
535 
536 /// Simplify add expression. Return nullptr if it can't be simplified.
538  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
539  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
540  // Fold if both LHS, RHS are a constant.
541  if (lhsConst && rhsConst)
542  return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
543  lhs.getContext());
544 
545  // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
546  // If only one of them is a symbolic expressions, make it the RHS.
547  if (lhs.isa<AffineConstantExpr>() ||
548  (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
549  return rhs + lhs;
550  }
551 
552  // At this point, if there was a constant, it would be on the right.
553 
554  // Addition with a zero is a noop, return the other input.
555  if (rhsConst) {
556  if (rhsConst.getValue() == 0)
557  return lhs;
558  }
559  // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
560  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
561  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
562  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
563  return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
564  }
565 
566  // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
567  // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
568  // respective multiplicands.
569  std::optional<int64_t> rLhsConst, rRhsConst;
570  AffineExpr firstExpr, secondExpr;
571  AffineConstantExpr rLhsConstExpr;
572  auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>();
573  if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
574  (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
575  rLhsConst = rLhsConstExpr.getValue();
576  firstExpr = lBinOpExpr.getLHS();
577  } else {
578  rLhsConst = 1;
579  firstExpr = lhs;
580  }
581 
582  auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
583  AffineConstantExpr rRhsConstExpr;
584  if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
585  (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
586  rRhsConst = rRhsConstExpr.getValue();
587  secondExpr = rBinOpExpr.getLHS();
588  } else {
589  rRhsConst = 1;
590  secondExpr = rhs;
591  }
592 
593  if (rLhsConst && rRhsConst && firstExpr == secondExpr)
594  return getAffineBinaryOpExpr(
595  AffineExprKind::Mul, firstExpr,
596  getAffineConstantExpr(*rLhsConst + *rRhsConst, lhs.getContext()));
597 
598  // When doing successive additions, bring constant to the right: turn (d0 + 2)
599  // + d1 into (d0 + d1) + 2.
600  if (lBin && lBin.getKind() == AffineExprKind::Add) {
601  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
602  return lBin.getLHS() + rhs + lrhs;
603  }
604  }
605 
606  // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
607  // q may be a constant or symbolic expression. This leads to a much more
608  // efficient form when 'c' is a power of two, and in general a more compact
609  // and readable form.
610 
611  // Process '(expr floordiv c) * (-c)'.
612  if (!rBinOpExpr)
613  return nullptr;
614 
615  auto lrhs = rBinOpExpr.getLHS();
616  auto rrhs = rBinOpExpr.getRHS();
617 
618  AffineExpr llrhs, rlrhs;
619 
620  // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
621  // symbolic expression.
622  auto lrhsBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
623  // Check rrhsConstOpExpr = -1.
624  auto rrhsConstOpExpr = rrhs.dyn_cast<AffineConstantExpr>();
625  if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
626  lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
627  // Check llrhs = expr floordiv q.
628  llrhs = lrhsBinOpExpr.getLHS();
629  // Check rlrhs = q.
630  rlrhs = lrhsBinOpExpr.getRHS();
631  auto llrhsBinOpExpr = llrhs.dyn_cast<AffineBinaryOpExpr>();
632  if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
633  return nullptr;
634  if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
635  return lhs % rlrhs;
636  }
637 
638  // Process lrhs, which is 'expr floordiv c'.
639  AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
640  if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
641  return nullptr;
642 
643  llrhs = lrBinOpExpr.getLHS();
644  rlrhs = lrBinOpExpr.getRHS();
645 
646  if (lhs == llrhs && rlrhs == -rrhs) {
647  return lhs % rlrhs;
648  }
649  return nullptr;
650 }
651 
653  return *this + getAffineConstantExpr(v, getContext());
654 }
656  if (auto simplified = simplifyAdd(*this, other))
657  return simplified;
658 
660  return uniquer.get<AffineBinaryOpExprStorage>(
661  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
662 }
663 
664 /// Simplify a multiply expression. Return nullptr if it can't be simplified.
666  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
667  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
668 
669  if (lhsConst && rhsConst)
670  return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
671  lhs.getContext());
672 
673  assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());
674 
675  // Canonicalize the mul expression so that the constant/symbolic term is the
676  // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
677  // constant. (Note that a constant is trivially symbolic).
678  if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) {
679  // At least one of them has to be symbolic.
680  return rhs * lhs;
681  }
682 
683  // At this point, if there was a constant, it would be on the right.
684 
685  // Multiplication with a one is a noop, return the other input.
686  if (rhsConst) {
687  if (rhsConst.getValue() == 1)
688  return lhs;
689  // Multiplication with zero.
690  if (rhsConst.getValue() == 0)
691  return rhsConst;
692  }
693 
694  // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
695  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
696  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
697  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
698  return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
699  }
700 
701  // When doing successive multiplication, bring constant to the right: turn (d0
702  // * 2) * d1 into (d0 * d1) * 2.
703  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
704  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
705  return (lBin.getLHS() * rhs) * lrhs;
706  }
707  }
708 
709  return nullptr;
710 }
711 
713  return *this * getAffineConstantExpr(v, getContext());
714 }
716  if (auto simplified = simplifyMul(*this, other))
717  return simplified;
718 
720  return uniquer.get<AffineBinaryOpExprStorage>(
721  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
722 }
723 
724 // Unary minus, delegate to operator*.
726  return *this * getAffineConstantExpr(-1, getContext());
727 }
728 
729 // Delegate to operator+.
730 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
732  return *this + (-other);
733 }
734 
736  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
737  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
738 
739  // mlir floordiv by zero or negative numbers is undefined and preserved as is.
740  if (!rhsConst || rhsConst.getValue() < 1)
741  return nullptr;
742 
743  if (lhsConst)
744  return getAffineConstantExpr(
745  floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
746 
747  // Fold floordiv of a multiply with a constant that is a multiple of the
748  // divisor. Eg: (i * 128) floordiv 64 = i * 2.
749  if (rhsConst == 1)
750  return lhs;
751 
752  // Simplify (expr * const) floordiv divConst when expr is known to be a
753  // multiple of divConst.
754  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
755  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
756  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
757  // rhsConst is known to be a positive constant.
758  if (lrhs.getValue() % rhsConst.getValue() == 0)
759  return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
760  }
761  }
762 
763  // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
764  // known to be a multiple of divConst.
765  if (lBin && lBin.getKind() == AffineExprKind::Add) {
766  int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
767  int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
768  // rhsConst is known to be a positive constant.
769  if (llhsDiv % rhsConst.getValue() == 0 ||
770  lrhsDiv % rhsConst.getValue() == 0)
771  return lBin.getLHS().floorDiv(rhsConst.getValue()) +
772  lBin.getRHS().floorDiv(rhsConst.getValue());
773  }
774 
775  return nullptr;
776 }
777 
778 AffineExpr AffineExpr::floorDiv(uint64_t v) const {
780 }
782  if (auto simplified = simplifyFloorDiv(*this, other))
783  return simplified;
784 
786  return uniquer.get<AffineBinaryOpExprStorage>(
787  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
788  other);
789 }
790 
792  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
793  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
794 
795  if (!rhsConst || rhsConst.getValue() < 1)
796  return nullptr;
797 
798  if (lhsConst)
799  return getAffineConstantExpr(
800  ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
801 
802  // Fold ceildiv of a multiply with a constant that is a multiple of the
803  // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
804  if (rhsConst.getValue() == 1)
805  return lhs;
806 
807  // Simplify (expr * const) ceildiv divConst when const is known to be a
808  // multiple of divConst.
809  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
810  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
811  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
812  // rhsConst is known to be a positive constant.
813  if (lrhs.getValue() % rhsConst.getValue() == 0)
814  return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
815  }
816  }
817 
818  return nullptr;
819 }
820 
821 AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
823 }
825  if (auto simplified = simplifyCeilDiv(*this, other))
826  return simplified;
827 
829  return uniquer.get<AffineBinaryOpExprStorage>(
830  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
831  other);
832 }
833 
835  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
836  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
837 
838  // mod w.r.t zero or negative numbers is undefined and preserved as is.
839  if (!rhsConst || rhsConst.getValue() < 1)
840  return nullptr;
841 
842  if (lhsConst)
843  return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
844  lhs.getContext());
845 
846  // Fold modulo of an expression that is known to be a multiple of a constant
847  // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
848  // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
849  if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
850  return getAffineConstantExpr(0, lhs.getContext());
851 
852  // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
853  // known to be a multiple of divConst.
854  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
855  if (lBin && lBin.getKind() == AffineExprKind::Add) {
856  int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
857  int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
858  // rhsConst is known to be a positive constant.
859  if (llhsDiv % rhsConst.getValue() == 0)
860  return lBin.getRHS() % rhsConst.getValue();
861  if (lrhsDiv % rhsConst.getValue() == 0)
862  return lBin.getLHS() % rhsConst.getValue();
863  }
864 
865  // Simplify (e % a) % b to e % b when b evenly divides a
866  if (lBin && lBin.getKind() == AffineExprKind::Mod) {
867  auto intermediate = lBin.getRHS().dyn_cast<AffineConstantExpr>();
868  if (intermediate && intermediate.getValue() >= 1 &&
869  mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
870  return lBin.getLHS() % rhsConst.getValue();
871  }
872  }
873 
874  return nullptr;
875 }
876 
878  return *this % getAffineConstantExpr(v, getContext());
879 }
881  if (auto simplified = simplifyMod(*this, other))
882  return simplified;
883 
885  return uniquer.get<AffineBinaryOpExprStorage>(
886  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
887 }
888 
890  SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
891  map.getResults().end());
892  return replaceDimsAndSymbols(dimReplacements, {});
893 }
894 raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
895  expr.print(os);
896  return os;
897 }
898 
899 /// Constructs an affine expression from a flat ArrayRef. If there are local
900 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
901 /// products expression, `localExprs` is expected to have the AffineExpr
902 /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
903 /// in the format [dims, symbols, locals, constant term].
905  unsigned numDims,
906  unsigned numSymbols,
907  ArrayRef<AffineExpr> localExprs,
908  MLIRContext *context) {
909  // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
910  assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
911  "unexpected number of local expressions");
912 
913  auto expr = getAffineConstantExpr(0, context);
914  // Dimensions and symbols.
915  for (unsigned j = 0; j < numDims + numSymbols; j++) {
916  if (flatExprs[j] == 0)
917  continue;
918  auto id = j < numDims ? getAffineDimExpr(j, context)
919  : getAffineSymbolExpr(j - numDims, context);
920  expr = expr + id * flatExprs[j];
921  }
922 
923  // Local identifiers.
924  for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
925  j++) {
926  if (flatExprs[j] == 0)
927  continue;
928  auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
929  expr = expr + term;
930  }
931 
932  // Constant term.
933  int64_t constTerm = flatExprs[flatExprs.size() - 1];
934  if (constTerm != 0)
935  expr = expr + constTerm;
936  return expr;
937 }
938 
939 /// Constructs a semi-affine expression from a flat ArrayRef. If there are
940 /// local identifiers (neither dimensional nor symbolic) that appear in the sum
941 /// of products expression, `localExprs` is expected to have the AffineExprs for
942 /// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in
943 /// the format [dims, symbols, locals, constant term]. The semi-affine
944 /// expression is constructed in the sorted order of dimension and symbol
945 /// position numbers. Note: local expressions/ids are used for mod, div as well
946 /// as symbolic RHS terms for terms that are not pure affine.
948  unsigned numDims,
949  unsigned numSymbols,
950  ArrayRef<AffineExpr> localExprs,
951  MLIRContext *context) {
952  assert(!flatExprs.empty() && "flatExprs cannot be empty");
953 
954  // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
955  assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
956  "unexpected number of local expressions");
957 
958  AffineExpr expr = getAffineConstantExpr(0, context);
959 
960  // We design indices as a pair which help us present the semi-affine map as
961  // sum of product where terms are sorted based on dimension or symbol
962  // position: <keyA, keyB> for expressions of the form dimension * symbol,
963  // where keyA is the position number of the dimension and keyB is the
964  // position number of the symbol. For dimensional expressions we set the index
965  // as (position number of the dimension, -1), as we want dimensional
966  // expressions to appear before symbolic and product of dimensional and
967  // symbolic expressions having the dimension with the same position number.
968  // For symbolic expression set the index as (position number of the symbol,
969  // maximum of last dimension and symbol position) number. For example, we want
970  // the expression we are constructing to look something like: d0 + d0 * s0 +
971  // s0 + d1*s1 + s1.
972 
973  // Stores the affine expression corresponding to a given index.
975  // Stores the constant coefficient value corresponding to a given
976  // dimension, symbol or a non-pure affine expression stored in `localExprs`.
977  DenseMap<std::pair<unsigned, signed>, int64_t> coefficients;
978  // Stores the indices as defined above, and later sorted to produce
979  // the semi-affine expression in the desired form.
981 
982  // Example: expression = d0 + d0 * s0 + 2 * s0.
983  // indices = [{0,-1}, {0, 0}, {0, 1}]
984  // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
985  // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
986 
987  // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
988  auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
989  AffineExpr expr) {
990  assert(!llvm::is_contained(indices, index) &&
991  "Key is already present in indices vector and overwriting will "
992  "happen in `indexToExprMap` and `coefficients`!");
993 
994  indices.push_back(index);
995  coefficients.insert({index, coefficient});
996  indexToExprMap.insert({index, expr});
997  };
998 
999  // Design indices for dimensional or symbolic terms, and store the indices,
1000  // constant coefficient corresponding to the indices in `coefficients` map,
1001  // and affine expression corresponding to indices in `indexToExprMap` map.
1002 
1003  // Ensure we do not have duplicate keys in `indexToExpr` map.
1004  unsigned offsetSym = 0;
1005  signed offsetDim = -1;
1006  for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
1007  if (flatExprs[j] == 0)
1008  continue;
1009  // For symbolic expression set the index as <position number
1010  // of the symbol, max(dimCount, symCount)> number,
1011  // as we want symbolic expressions with the same positional number to
1012  // appear after dimensional expressions having the same positional number.
1013  std::pair<unsigned, signed> indexEntry(
1014  j - numDims, std::max(numDims, numSymbols) + offsetSym++);
1015  addEntry(indexEntry, flatExprs[j],
1016  getAffineSymbolExpr(j - numDims, context));
1017  }
1018 
1019  // Denotes semi-affine product, modulo or division terms, which has been added
1020  // to the `indexToExpr` map.
1021  SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1,
1022  false);
1023  unsigned lhsPos, rhsPos;
1024  // Construct indices for product terms involving dimension, symbol or constant
1025  // as lhs/rhs, and store the indices, constant coefficient corresponding to
1026  // the indices in `coefficients` map, and affine expression corresponding to
1027  // in indices in `indexToExprMap` map.
1028  for (const auto &it : llvm::enumerate(localExprs)) {
1029  AffineExpr expr = it.value();
1030  if (flatExprs[numDims + numSymbols + it.index()] == 0)
1031  continue;
1032  AffineExpr lhs = expr.cast<AffineBinaryOpExpr>().getLHS();
1033  AffineExpr rhs = expr.cast<AffineBinaryOpExpr>().getRHS();
1034  if (!((lhs.isa<AffineDimExpr>() || lhs.isa<AffineSymbolExpr>()) &&
1035  (rhs.isa<AffineDimExpr>() || rhs.isa<AffineSymbolExpr>() ||
1036  rhs.isa<AffineConstantExpr>()))) {
1037  continue;
1038  }
1039  if (rhs.isa<AffineConstantExpr>()) {
1040  // For product/modulo/division expressions, when rhs of modulo/division
1041  // expression is constant, we put 0 in place of keyB, because we want
1042  // them to appear earlier in the semi-affine expression we are
1043  // constructing. When rhs is constant, we place 0 in place of keyB.
1044  if (lhs.isa<AffineDimExpr>()) {
1045  lhsPos = lhs.cast<AffineDimExpr>().getPosition();
1046  std::pair<unsigned, signed> indexEntry(lhsPos, offsetDim--);
1047  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1048  expr);
1049  } else {
1050  lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
1051  std::pair<unsigned, signed> indexEntry(
1052  lhsPos, std::max(numDims, numSymbols) + offsetSym++);
1053  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1054  expr);
1055  }
1056  } else if (lhs.isa<AffineDimExpr>()) {
1057  // For product/modulo/division expressions having lhs as dimension and rhs
1058  // as symbol, we order the terms in the semi-affine expression based on
1059  // the pair: <keyA, keyB> for expressions of the form dimension * symbol,
1060  // where keyA is the position number of the dimension and keyB is the
1061  // position number of the symbol.
1062  lhsPos = lhs.cast<AffineDimExpr>().getPosition();
1063  rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
1064  std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1065  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1066  } else {
1067  // For product/modulo/division expressions having both lhs and rhs as
1068  // symbol, we design indices as a pair: <keyA, keyB> for expressions
1069  // of the form dimension * symbol, where keyA is the position number of
1070  // the dimension and keyB is the position number of the symbol.
1071  lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
1072  rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
1073  std::pair<unsigned, signed> indexEntry(
1074  lhsPos, std::max(numDims, numSymbols) + offsetSym++);
1075  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1076  }
1077  addedToMap[it.index()] = true;
1078  }
1079 
1080  for (unsigned j = 0; j < numDims; ++j) {
1081  if (flatExprs[j] == 0)
1082  continue;
1083  // For dimensional expressions we set the index as <position number of the
1084  // dimension, 0>, as we want dimensional expressions to appear before
1085  // symbolic ones and products of dimensional and symbolic expressions
1086  // having the dimension with the same position number.
1087  std::pair<unsigned, signed> indexEntry(j, offsetDim--);
1088  addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context));
1089  }
1090 
1091  // Constructing the simplified semi-affine sum of product/division/mod
1092  // expression from the flattened form in the desired sorted order of indices
1093  // of the various individual product/division/mod expressions.
1094  llvm::sort(indices);
1095  for (const std::pair<unsigned, unsigned> index : indices) {
1096  assert(indexToExprMap.lookup(index) &&
1097  "cannot find key in `indexToExprMap` map");
1098  expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index);
1099  }
1100 
1101  // Local identifiers.
1102  for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1103  j++) {
1104  // If the coefficient of the local expression is 0, continue as we need not
1105  // add it in out final expression.
1106  if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols])
1107  continue;
1108  auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1109  expr = expr + term;
1110  }
1111 
1112  // Constant term.
1113  int64_t constTerm = flatExprs.back();
1114  if (constTerm != 0)
1115  expr = expr + constTerm;
1116  return expr;
1117 }
1118 
1120  unsigned numSymbols)
1121  : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
1122  operandExprStack.reserve(8);
1123 }
1124 
1125 // In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1126 //
1127 // In case of semi affine multiplication expressions, t = expr * symbolic_expr,
1128 // introduce a local variable p (= expr * symbolic_expr), and the affine
1129 // expression expr * symbolic_expr is added to `localExprs`.
1131  assert(operandExprStack.size() >= 2);
1133  operandExprStack.pop_back();
1135 
1136  // Flatten semi-affine multiplication expressions by introducing a local
1137  // variable in place of the product; the affine expression
1138  // corresponding to the quantifier is added to `localExprs`.
1139  if (!expr.getRHS().isa<AffineConstantExpr>()) {
1140  MLIRContext *context = expr.getContext();
1142  localExprs, context);
1144  localExprs, context);
1145  addLocalVariableSemiAffine(a * b, lhs, lhs.size());
1146  return;
1147  }
1148 
1149  // Get the RHS constant.
1150  auto rhsConst = rhs[getConstantIndex()];
1151  for (unsigned i = 0, e = lhs.size(); i < e; i++) {
1152  lhs[i] *= rhsConst;
1153  }
1154 }
1155 
1157  assert(operandExprStack.size() >= 2);
1158  const auto &rhs = operandExprStack.back();
1159  auto &lhs = operandExprStack[operandExprStack.size() - 2];
1160  assert(lhs.size() == rhs.size());
1161  // Update the LHS in place.
1162  for (unsigned i = 0, e = rhs.size(); i < e; i++) {
1163  lhs[i] += rhs[i];
1164  }
1165  // Pop off the RHS.
1166  operandExprStack.pop_back();
1167 }
1168 
1169 //
1170 // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
1171 //
1172 // A mod expression "expr mod c" is thus flattened by introducing a new local
1173 // variable q (= expr floordiv c), such that expr mod c is replaced with
1174 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
1175 //
1176 // In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
1177 // introduce a local variable m (= expr mod symbolic_expr), and the affine
1178 // expression expr mod symbolic_expr is added to `localExprs`.
1180  assert(operandExprStack.size() >= 2);
1181 
1183  operandExprStack.pop_back();
1185  MLIRContext *context = expr.getContext();
1186 
1187  // Flatten semi affine modulo expressions by introducing a local
1188  // variable in place of the modulo value, and the affine expression
1189  // corresponding to the quantifier is added to `localExprs`.
1190  if (!expr.getRHS().isa<AffineConstantExpr>()) {
1191  AffineExpr dividendExpr = getAffineExprFromFlatForm(
1192  lhs, numDims, numSymbols, localExprs, context);
1194  localExprs, context);
1195  AffineExpr modExpr = dividendExpr % divisorExpr;
1196  addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
1197  return;
1198  }
1199 
1200  int64_t rhsConst = rhs[getConstantIndex()];
1201  // TODO: handle modulo by zero case when this issue is fixed
1202  // at the other places in the IR.
1203  assert(rhsConst > 0 && "RHS constant has to be positive");
1204 
1205  // Check if the LHS expression is a multiple of modulo factor.
1206  unsigned i, e;
1207  for (i = 0, e = lhs.size(); i < e; i++)
1208  if (lhs[i] % rhsConst != 0)
1209  break;
1210  // If yes, modulo expression here simplifies to zero.
1211  if (i == lhs.size()) {
1212  std::fill(lhs.begin(), lhs.end(), 0);
1213  return;
1214  }
1215 
1216  // Add a local variable for the quotient, i.e., expr % c is replaced by
1217  // (expr - q * c) where q = expr floordiv c. Do this while canceling out
1218  // the GCD of expr and c.
1219  SmallVector<int64_t, 8> floorDividend(lhs);
1220  uint64_t gcd = rhsConst;
1221  for (unsigned i = 0, e = lhs.size(); i < e; i++)
1222  gcd = std::gcd(gcd, (uint64_t)std::abs(lhs[i]));
1223  // Simplify the numerator and the denominator.
1224  if (gcd != 1) {
1225  for (unsigned i = 0, e = floorDividend.size(); i < e; i++)
1226  floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd);
1227  }
1228  int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
1229 
1230  // Construct the AffineExpr form of the floordiv to store in localExprs.
1231 
1232  AffineExpr dividendExpr = getAffineExprFromFlatForm(
1233  floorDividend, numDims, numSymbols, localExprs, context);
1234  AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context);
1235  AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr);
1236  int loc;
1237  if ((loc = findLocalId(floorDivExpr)) == -1) {
1238  addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
1239  // Set result at top of stack to "lhs - rhsConst * q".
1240  lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
1241  } else {
1242  // Reuse the existing local id.
1243  lhs[getLocalVarStartIndex() + loc] = -rhsConst;
1244  }
1245 }
1246 
1248  visitDivExpr(expr, /*isCeil=*/true);
1249 }
1251  visitDivExpr(expr, /*isCeil=*/false);
1252 }
1253 
1255  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1256  auto &eq = operandExprStack.back();
1257  assert(expr.getPosition() < numDims && "Inconsistent number of dims");
1258  eq[getDimStartIndex() + expr.getPosition()] = 1;
1259 }
1260 
1262  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1263  auto &eq = operandExprStack.back();
1264  assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1265  eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1266 }
1267 
1269  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1270  auto &eq = operandExprStack.back();
1271  eq[getConstantIndex()] = expr.getValue();
1272 }
1273 
1274 void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1275  AffineExpr expr, SmallVectorImpl<int64_t> &result,
1276  unsigned long resultSize) {
1277  assert(result.size() == resultSize &&
1278  "`result` vector passed is not of correct size");
1279  int loc;
1280  if ((loc = findLocalId(expr)) == -1)
1281  addLocalIdSemiAffine(expr);
1282  std::fill(result.begin(), result.end(), 0);
1283  if (loc == -1)
1284  result[getLocalVarStartIndex() + numLocals - 1] = 1;
1285  else
1286  result[getLocalVarStartIndex() + loc] = 1;
1287 }
1288 
1289 // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
1290 // A floordiv is thus flattened by introducing a new local variable q, and
1291 // replacing that expression with 'q' while adding the constraints
1292 // c * q <= expr <= c * q + c - 1 to localVarCst (done by
1293 // FlatAffineConstraints::addLocalFloorDiv).
1294 //
1295 // A ceildiv is similarly flattened:
1296 // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
1297 //
1298 // In case of semi affine division expressions, t = expr floordiv symbolic_expr
1299 // or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
1300 // floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
1301 // `localExprs`.
1302 void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1303  bool isCeil) {
1304  assert(operandExprStack.size() >= 2);
1305 
1306  MLIRContext *context = expr.getContext();
1308  operandExprStack.pop_back();
1310 
1311  // Flatten semi affine division expressions by introducing a local
1312  // variable in place of the quotient, and the affine expression corresponding
1313  // to the quantifier is added to `localExprs`.
1314  if (!expr.getRHS().isa<AffineConstantExpr>()) {
1316  localExprs, context);
1318  localExprs, context);
1319  AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1320  addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
1321  return;
1322  }
1323 
1324  // This is a pure affine expr; the RHS is a positive constant.
1325  int64_t rhsConst = rhs[getConstantIndex()];
1326  // TODO: handle division by zero at the same time the issue is
1327  // fixed at other places.
1328  assert(rhsConst > 0 && "RHS constant has to be positive");
1329 
1330  // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1331  // common divisors of the numerator and denominator.
1332  uint64_t gcd = std::abs(rhsConst);
1333  for (unsigned i = 0, e = lhs.size(); i < e; i++)
1334  gcd = std::gcd(gcd, (uint64_t)std::abs(lhs[i]));
1335  // Simplify the numerator and the denominator.
1336  if (gcd != 1) {
1337  for (unsigned i = 0, e = lhs.size(); i < e; i++)
1338  lhs[i] = lhs[i] / static_cast<int64_t>(gcd);
1339  }
1340  int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1341  // If the divisor becomes 1, the updated LHS is the result. (The
1342  // divisor can't be negative since rhsConst is positive).
1343  if (divisor == 1)
1344  return;
1345 
1346  // If the divisor cannot be simplified to one, we will have to retain
1347  // the ceil/floor expr (simplified up until here). Add an existential
1348  // quantifier to express its result, i.e., expr1 div expr2 is replaced
1349  // by a new identifier, q.
1350  AffineExpr a =
1352  AffineExpr b = getAffineConstantExpr(divisor, context);
1353 
1354  int loc;
1355  AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1356  if ((loc = findLocalId(divExpr)) == -1) {
1357  if (!isCeil) {
1358  SmallVector<int64_t, 8> dividend(lhs);
1359  addLocalFloorDivId(dividend, divisor, divExpr);
1360  } else {
1361  // lhs ceildiv c <=> (lhs + c - 1) floordiv c
1362  SmallVector<int64_t, 8> dividend(lhs);
1363  dividend.back() += divisor - 1;
1364  addLocalFloorDivId(dividend, divisor, divExpr);
1365  }
1366  }
1367  // Set the expression on stack to the local var introduced to capture the
1368  // result of the division (floor or ceil).
1369  std::fill(lhs.begin(), lhs.end(), 0);
1370  if (loc == -1)
1371  lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1372  else
1373  lhs[getLocalVarStartIndex() + loc] = 1;
1374 }
1375 
1376 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1377 // The local identifier added is always a floordiv of a pure add/mul affine
1378 // function of other identifiers, coefficients of which are specified in
1379 // dividend and with respect to a positive constant divisor. localExpr is the
1380 // simplified tree expression (AffineExpr) corresponding to the quantifier.
1382  int64_t divisor,
1383  AffineExpr localExpr) {
1384  assert(divisor > 0 && "positive constant divisor expected");
1385  for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1386  subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1387  localExprs.push_back(localExpr);
1388  numLocals++;
1389  // dividend and divisor are not used here; an override of this method uses it.
1390 }
1391 
1393  for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1394  subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1395  localExprs.push_back(localExpr);
1396  ++numLocals;
1397 }
1398 
1399 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1401  if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
1402  return -1;
1403  return it - localExprs.begin();
1404 }
1405 
1406 /// Simplify the affine expression by flattening it and reconstructing it.
1408  unsigned numSymbols) {
1409  // Simplify semi-affine expressions separately.
1410  if (!expr.isPureAffine())
1411  expr = simplifySemiAffine(expr);
1412 
1413  SimpleAffineExprFlattener flattener(numDims, numSymbols);
1414  flattener.walkPostOrder(expr);
1415  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1416  if (!expr.isPureAffine() &&
1417  expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1418  flattener.localExprs,
1419  expr.getContext()))
1420  return expr;
1421  AffineExpr simplifiedExpr =
1422  expr.isPureAffine()
1423  ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1424  flattener.localExprs, expr.getContext())
1425  : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1426  flattener.localExprs,
1427  expr.getContext());
1428 
1429  flattener.operandExprStack.pop_back();
1430  assert(flattener.operandExprStack.empty());
1431  return simplifiedExpr;
1432 }
static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos, AffineExprKind opKind)
Divides the given expression by the given symbol at position symbolPos.
Definition: AffineExpr.cpp:395
static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs)
Simplify a multiply expression. Return nullptr if it can't be simplified.
Definition: AffineExpr.cpp:665
static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:834
static AffineExpr simplifySemiAffine(AffineExpr expr)
Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv operations when the second...
Definition: AffineExpr.cpp:452
static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs)
Simplify add expression. Return nullptr if it can't be simplified.
Definition: AffineExpr.cpp:537
static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs a semi-affine expression from a flat ArrayRef.
Definition: AffineExpr.cpp:947
static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:791
static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:735
static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:491
static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, AffineExprKind opKind)
Returns true if the expression is divisible by the given symbol with position symbolPos.
Definition: AffineExpr.cpp:336
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Affine binary operation expression.
Definition: AffineExpr.h:207
AffineExpr getLHS() const
Definition: AffineExpr.cpp:317
AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
Definition: AffineExpr.cpp:315
AffineExpr getRHS() const
Definition: AffineExpr.cpp:320
An integer constant appearing in affine expression.
Definition: AffineExpr.h:232
AffineConstantExpr(AffineExpr::ImplType *ptr=nullptr)
Definition: AffineExpr.cpp:517
int64_t getValue() const
Definition: AffineExpr.cpp:519
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
AffineDimExpr(AffineExpr::ImplType *ptr)
Definition: AffineExpr.cpp:324
unsigned getPosition() const
Definition: AffineExpr.cpp:325
Base class for AffineExpr visitors/walkers.
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements) const
This method substitutes any uses of dimensions and symbols (e.g.
Definition: AffineExpr.cpp:66
AffineExpr shiftDims(unsigned numDims, unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineExpr.cpp:110
AffineExpr operator+(int64_t v) const
Definition: AffineExpr.cpp:652
bool isSymbolicOrConstant() const
Returns true if this expression is made out of only symbols and constants, i.e., it does not involve ...
Definition: AffineExpr.cpp:165
AffineExpr operator*(int64_t v) const
Definition: AffineExpr.cpp:712
bool operator==(AffineExpr other) const
Definition: AffineExpr.h:76
bool isPureAffine() const
Returns true if this is a pure affine expression, i.e., multiplication, floordiv, ceildiv,...
Definition: AffineExpr.cpp:189
AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
Definition: AffineExpr.cpp:122
AffineExpr operator-() const
Definition: AffineExpr.cpp:725
U cast() const
Definition: AffineExpr.h:291
void walk(std::function< void(AffineExpr)> callback) const
Walk all of the AffineExprs in this expression in postorder.
Definition: AffineExpr.cpp:30
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:778
ImplType * expr
Definition: AffineExpr.h:198
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:27
bool isMultipleOf(int64_t factor) const
Return true if the affine expression is a multiple of 'factor'.
Definition: AffineExpr.cpp:260
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
Definition: AffineExpr.cpp:220
AffineExpr compose(AffineMap map) const
Compose with an AffineMap.
Definition: AffineExpr.cpp:889
constexpr bool isa() const
Definition: AffineExpr.h:270
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
Definition: AffineExpr.cpp:293
bool isFunctionOfSymbol(unsigned position) const
Return true if the affine expression involves AffineSymbolExpr position.
Definition: AffineExpr.cpp:304
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
Definition: AffineExpr.cpp:99
AffineExpr operator%(uint64_t v) const
Definition: AffineExpr.cpp:877
MLIRContext * getContext() const
Definition: AffineExpr.cpp:25
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
Definition: AffineExpr.cpp:158
AffineExpr replaceSymbols(ArrayRef< AffineExpr > symReplacements) const
Symbol-only version of replaceDimsAndSymbols.
Definition: AffineExpr.cpp:104
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:821
void print(raw_ostream &os) const
U dyn_cast() const
Definition: AffineExpr.h:281
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:43
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:332
A symbolic identifier appearing in an affine expression.
Definition: AffineExpr.h:224
AffineSymbolExpr(AffineExpr::ImplType *ptr)
Definition: AffineExpr.cpp:506
unsigned getPosition() const
Definition: AffineExpr.cpp:508
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
StorageUniquer & getAffineUniquer()
Returns the storage uniquer used for creating affine constructs.
virtual void addLocalFloorDivId(ArrayRef< int64_t > dividend, int64_t divisor, AffineExpr localExpr)
void visitFloorDivExpr(AffineBinaryOpExpr expr)
void visitAddExpr(AffineBinaryOpExpr expr)
std::vector< SmallVector< int64_t, 8 > > operandExprStack
void visitDimExpr(AffineDimExpr expr)
void visitConstantExpr(AffineConstantExpr expr)
void visitSymbolExpr(AffineSymbolExpr expr)
virtual void addLocalIdSemiAffine(AffineExpr localExpr)
Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul expr) when the rhs is a symbo...
SmallVector< AffineExpr, 4 > localExprs
void visitCeilDivExpr(AffineBinaryOpExpr expr)
void visitModExpr(AffineBinaryOpExpr expr)
SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols)
void visitMulExpr(AffineBinaryOpExpr expr)
A utility class to get or create instances of "storage classes".
Storage * get(function_ref< void(Storage *)> initFn, TypeID id, Args &&...args)
Gets a uniqued instance of 'Storage'.
Detect if any of the given parameter types has a sub-element handler.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:223
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt gcd(const MPInt &a, const MPInt &b)
Definition: MPInt.h:399
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
Definition: MPInt.h:370
Include the generated interface declarations.
int64_t floorDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's floordiv operation on constants.
Definition: MathExtras.h:33
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
Definition: MathExtras.h:23
AffineExprKind
Definition: AffineExpr.h:40
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ Constant
Constant integer.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:47
AffineExpr getAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs an affine expression from a flat ArrayRef.
Definition: AffineExpr.cpp:904
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:527
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:502
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:512
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
Definition: MathExtras.h:45
A binary operation appearing in an affine expression.
An integer constant appearing in affine expression.
A dimensional or symbolic identifier appearing in an affine expression.
Base storage class appearing in an affine expression.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.