MLIR  18.0.0git
AffineExpr.cpp
Go to the documentation of this file.
1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "AffineExprDetail.h"
12 #include "mlir/IR/AffineExpr.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/IntegerSet.h"
17 #include "mlir/Support/TypeID.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include <numeric>
20 #include <optional>
21 
22 using namespace mlir;
23 using namespace mlir::detail;
24 
25 MLIRContext *AffineExpr::getContext() const { return expr->context; }
26 
27 AffineExprKind AffineExpr::getKind() const { return expr->kind; }
28 
29 /// Walk all of the AffineExprs in this subgraph in postorder.
30 void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
31  struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
32  std::function<void(AffineExpr)> callback;
33 
34  AffineExprWalker(std::function<void(AffineExpr)> callback)
35  : callback(std::move(callback)) {}
36 
37  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
38  void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
39  void visitDimExpr(AffineDimExpr expr) { callback(expr); }
40  void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
41  };
42 
43  AffineExprWalker(std::move(callback)).walkPostOrder(*this);
44 }
45 
46 // Dispatch affine expression construction based on kind.
48  AffineExpr rhs) {
49  if (kind == AffineExprKind::Add)
50  return lhs + rhs;
51  if (kind == AffineExprKind::Mul)
52  return lhs * rhs;
53  if (kind == AffineExprKind::FloorDiv)
54  return lhs.floorDiv(rhs);
55  if (kind == AffineExprKind::CeilDiv)
56  return lhs.ceilDiv(rhs);
57  if (kind == AffineExprKind::Mod)
58  return lhs % rhs;
59 
60  llvm_unreachable("unknown binary operation on affine expressions");
61 }
62 
63 /// This method substitutes any uses of dimensions and symbols (e.g.
64 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
67  ArrayRef<AffineExpr> symReplacements) const {
68  switch (getKind()) {
70  return *this;
71  case AffineExprKind::DimId: {
72  unsigned dimId = cast<AffineDimExpr>().getPosition();
73  if (dimId >= dimReplacements.size())
74  return *this;
75  return dimReplacements[dimId];
76  }
78  unsigned symId = cast<AffineSymbolExpr>().getPosition();
79  if (symId >= symReplacements.size())
80  return *this;
81  return symReplacements[symId];
82  }
88  auto binOp = cast<AffineBinaryOpExpr>();
89  auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
90  auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
91  auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
92  if (newLHS == lhs && newRHS == rhs)
93  return *this;
94  return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
95  }
96  llvm_unreachable("Unknown AffineExpr");
97 }
98 
100  return replaceDimsAndSymbols(dimReplacements, {});
101 }
102 
105  return replaceDimsAndSymbols({}, symReplacements);
106 }
107 
108 /// Replace dims[offset ... numDims)
109 /// by dims[offset + shift ... shift + numDims).
110 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
111  unsigned offset) const {
113  for (unsigned idx = 0; idx < offset; ++idx)
114  dims.push_back(getAffineDimExpr(idx, getContext()));
115  for (unsigned idx = offset; idx < numDims; ++idx)
116  dims.push_back(getAffineDimExpr(idx + shift, getContext()));
117  return replaceDimsAndSymbols(dims, {});
118 }
119 
120 /// Replace symbols[offset ... numSymbols)
121 /// by symbols[offset + shift ... shift + numSymbols).
122 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
123  unsigned offset) const {
125  for (unsigned idx = 0; idx < offset; ++idx)
126  symbols.push_back(getAffineSymbolExpr(idx, getContext()));
127  for (unsigned idx = offset; idx < numSymbols; ++idx)
128  symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
129  return replaceDimsAndSymbols({}, symbols);
130 }
131 
132 /// Sparse replace method. Return the modified expression tree.
135  auto it = map.find(*this);
136  if (it != map.end())
137  return it->second;
138  switch (getKind()) {
139  default:
140  return *this;
141  case AffineExprKind::Add:
142  case AffineExprKind::Mul:
145  case AffineExprKind::Mod:
146  auto binOp = cast<AffineBinaryOpExpr>();
147  auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
148  auto newLHS = lhs.replace(map);
149  auto newRHS = rhs.replace(map);
150  if (newLHS == lhs && newRHS == rhs)
151  return *this;
152  return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
153  }
154  llvm_unreachable("Unknown AffineExpr");
155 }
156 
157 /// Sparse replace method. Return the modified expression tree.
160  map.insert(std::make_pair(expr, replacement));
161  return replace(map);
162 }
163 /// Returns true if this expression is made out of only symbols and
164 /// constants (no dimensional identifiers).
166  switch (getKind()) {
168  return true;
170  return false;
172  return true;
173 
174  case AffineExprKind::Add:
175  case AffineExprKind::Mul:
178  case AffineExprKind::Mod: {
179  auto expr = this->cast<AffineBinaryOpExpr>();
180  return expr.getLHS().isSymbolicOrConstant() &&
181  expr.getRHS().isSymbolicOrConstant();
182  }
183  }
184  llvm_unreachable("Unknown AffineExpr");
185 }
186 
187 /// Returns true if this is a pure affine expression, i.e., multiplication,
188 /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
190  switch (getKind()) {
194  return true;
195  case AffineExprKind::Add: {
196  auto op = cast<AffineBinaryOpExpr>();
197  return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
198  }
199 
200  case AffineExprKind::Mul: {
201  // TODO: Canonicalize the constants in binary operators to the RHS when
202  // possible, allowing this to merge into the next case.
203  auto op = cast<AffineBinaryOpExpr>();
204  return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
205  (op.getLHS().template isa<AffineConstantExpr>() ||
206  op.getRHS().template isa<AffineConstantExpr>());
207  }
210  case AffineExprKind::Mod: {
211  auto op = cast<AffineBinaryOpExpr>();
212  return op.getLHS().isPureAffine() &&
213  op.getRHS().template isa<AffineConstantExpr>();
214  }
215  }
216  llvm_unreachable("Unknown AffineExpr");
217 }
218 
219 // Returns the greatest known integral divisor of this affine expression.
221  AffineBinaryOpExpr binExpr(nullptr);
222  switch (getKind()) {
224  [[fallthrough]];
226  return 1;
228  [[fallthrough]];
230  // If the RHS is a constant and divides the known divisor on the LHS, the
231  // quotient is a known divisor of the expression.
232  binExpr = this->cast<AffineBinaryOpExpr>();
233  auto rhs = binExpr.getRHS().dyn_cast<AffineConstantExpr>();
234  // Leave alone undefined expressions.
235  if (rhs && rhs.getValue() != 0) {
236  int64_t lhsDiv = binExpr.getLHS().getLargestKnownDivisor();
237  if (lhsDiv % rhs.getValue() == 0)
238  return lhsDiv / rhs.getValue();
239  }
240  return 1;
241  }
243  return std::abs(this->cast<AffineConstantExpr>().getValue());
244  case AffineExprKind::Mul: {
245  binExpr = this->cast<AffineBinaryOpExpr>();
246  return binExpr.getLHS().getLargestKnownDivisor() *
247  binExpr.getRHS().getLargestKnownDivisor();
248  }
249  case AffineExprKind::Add:
250  [[fallthrough]];
251  case AffineExprKind::Mod: {
252  binExpr = cast<AffineBinaryOpExpr>();
253  return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
254  (uint64_t)binExpr.getRHS().getLargestKnownDivisor());
255  }
256  }
257  llvm_unreachable("Unknown AffineExpr");
258 }
259 
260 bool AffineExpr::isMultipleOf(int64_t factor) const {
261  AffineBinaryOpExpr binExpr(nullptr);
262  uint64_t l, u;
263  switch (getKind()) {
265  [[fallthrough]];
267  return factor * factor == 1;
269  return cast<AffineConstantExpr>().getValue() % factor == 0;
270  case AffineExprKind::Mul: {
271  binExpr = cast<AffineBinaryOpExpr>();
272  // It's probably not worth optimizing this further (to not traverse the
273  // whole sub-tree under it - that would require a version of isMultipleOf
274  // that on a 'false' return also returns the largest known divisor).
275  return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
276  (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
277  (l * u) % factor == 0;
278  }
279  case AffineExprKind::Add:
282  case AffineExprKind::Mod: {
283  binExpr = cast<AffineBinaryOpExpr>();
284  return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
285  (uint64_t)binExpr.getRHS().getLargestKnownDivisor()) %
286  factor ==
287  0;
288  }
289  }
290  llvm_unreachable("Unknown AffineExpr");
291 }
292 
293 bool AffineExpr::isFunctionOfDim(unsigned position) const {
294  if (getKind() == AffineExprKind::DimId) {
295  return *this == mlir::getAffineDimExpr(position, getContext());
296  }
297  if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
298  return expr.getLHS().isFunctionOfDim(position) ||
299  expr.getRHS().isFunctionOfDim(position);
300  }
301  return false;
302 }
303 
304 bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
305  if (getKind() == AffineExprKind::SymbolId) {
306  return *this == mlir::getAffineSymbolExpr(position, getContext());
307  }
308  if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
309  return expr.getLHS().isFunctionOfSymbol(position) ||
310  expr.getRHS().isFunctionOfSymbol(position);
311  }
312  return false;
313 }
314 
316  : AffineExpr(ptr) {}
318  return static_cast<ImplType *>(expr)->lhs;
319 }
321  return static_cast<ImplType *>(expr)->rhs;
322 }
323 
325 unsigned AffineDimExpr::getPosition() const {
326  return static_cast<ImplType *>(expr)->position;
327 }
328 
329 /// Returns true if the expression is divisible by the given symbol with
330 /// position `symbolPos`. The argument `opKind` specifies here what kind of
331 /// division or mod operation called this division. It helps in implementing the
332 /// commutative property of the floordiv and ceildiv operations. If the argument
333 /// `opKind` is floordiv and `expr` is also a binary expression of a floordiv
334 /// operation, then the commutative property can be used otherwise, the floordiv
335 /// operation is not divisible. The same argument holds for ceildiv operation.
336 static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
337  AffineExprKind opKind) {
338  // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
339  assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
340  opKind == AffineExprKind::CeilDiv) &&
341  "unexpected opKind");
342  switch (expr.getKind()) {
344  return expr.cast<AffineConstantExpr>().getValue() == 0;
346  return false;
348  return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos);
349  // Checks divisibility by the given symbol for both operands.
350  case AffineExprKind::Add: {
351  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
352  return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
353  isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
354  }
355  // Checks divisibility by the given symbol for both operands. Consider the
356  // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
357  // this is a division by s1 and both the operands of modulo are divisible by
358  // s1 but it is not divisible by s1 always. The third argument is
359  // `AffineExprKind::Mod` for this reason.
360  case AffineExprKind::Mod: {
361  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
362  return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
364  isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
366  }
367  // Checks if any of the operand divisible by the given symbol.
368  case AffineExprKind::Mul: {
369  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
370  return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
371  isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
372  }
373  // Floordiv and ceildiv are divisible by the given symbol when the first
374  // operand is divisible, and the affine expression kind of the argument expr
375  // is same as the argument `opKind`. This can be inferred from commutative
376  // property of floordiv and ceildiv operations and are as follow:
377  // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
378  // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv exp2
379  // It will fail if operations are not same. For example:
380  // (exp1 ceildiv exp2) floordiv exp3 cannot be simplified.
383  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
384  if (opKind != expr.getKind())
385  return false;
386  return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
387  }
388  }
389  llvm_unreachable("Unknown AffineExpr");
390 }
391 
392 /// Divides the given expression by the given symbol at position `symbolPos`. It
393 /// assumes the divisibility condition has been checked before it is called. A
394 /// null expression is returned whenever the divisibility condition fails.
395 static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
396  AffineExprKind opKind) {
397  // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
398  assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
399  opKind == AffineExprKind::CeilDiv) &&
400  "unexpected opKind");
401  switch (expr.getKind()) {
403  if (expr.cast<AffineConstantExpr>().getValue() != 0)
404  return nullptr;
405  return getAffineConstantExpr(0, expr.getContext());
407  return nullptr;
409  return getAffineConstantExpr(1, expr.getContext());
410  // Dividing both operands by the given symbol.
411  case AffineExprKind::Add: {
412  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
413  return getAffineBinaryOpExpr(
414  expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
415  symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
416  }
417  // Dividing both operands by the given symbol.
418  case AffineExprKind::Mod: {
419  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
420  return getAffineBinaryOpExpr(
421  expr.getKind(),
422  symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
423  symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
424  }
425  // Dividing any of the operand by the given symbol.
426  case AffineExprKind::Mul: {
427  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
428  if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
429  return binaryExpr.getLHS() *
430  symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
431  return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
432  binaryExpr.getRHS();
433  }
434  // Dividing first operand only by the given symbol.
437  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
438  return getAffineBinaryOpExpr(
439  expr.getKind(),
440  symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
441  binaryExpr.getRHS());
442  }
443  }
444  llvm_unreachable("Unknown AffineExpr");
445 }
446 
447 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
448 /// operations when the second operand simplifies to a symbol and the first
449 /// operand is divisible by that symbol. It can be applied to any semi-affine
450 /// expression. Returned expression can either be a semi-affine or pure affine
451 /// expression.
453  switch (expr.getKind()) {
457  return expr;
458  case AffineExprKind::Add:
459  case AffineExprKind::Mul: {
460  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
461  return getAffineBinaryOpExpr(expr.getKind(),
462  simplifySemiAffine(binaryExpr.getLHS()),
463  simplifySemiAffine(binaryExpr.getRHS()));
464  }
465  // Check if the simplification of the second operand is a symbol, and the
466  // first operand is divisible by it. If the operation is a modulo, a constant
467  // zero expression is returned. In the case of floordiv and ceildiv, the
468  // symbol from the simplification of the second operand divides the first
469  // operand. Otherwise, simplification is not possible.
472  case AffineExprKind::Mod: {
473  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
474  AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS());
475  AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS());
476  AffineSymbolExpr symbolExpr =
478  if (!symbolExpr)
479  return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
480  unsigned symbolPos = symbolExpr.getPosition();
481  if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
482  return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
483  if (expr.getKind() == AffineExprKind::Mod)
484  return getAffineConstantExpr(0, expr.getContext());
485  return symbolicDivide(sLHS, symbolPos, expr.getKind());
486  }
487  }
488  llvm_unreachable("Unknown AffineExpr");
489 }
490 
491 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
492  MLIRContext *context) {
493  auto assignCtx = [context](AffineDimExprStorage *storage) {
494  storage->context = context;
495  };
496 
497  StorageUniquer &uniquer = context->getAffineUniquer();
498  return uniquer.get<AffineDimExprStorage>(
499  assignCtx, static_cast<unsigned>(kind), position);
500 }
501 
502 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
503  return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
504 }
505 
507  : AffineExpr(ptr) {}
509  return static_cast<ImplType *>(expr)->position;
510 }
511 
512 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
513  return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
514  ;
515 }
516 
518  : AffineExpr(ptr) {}
520  return static_cast<ImplType *>(expr)->constant;
521 }
522 
523 bool AffineExpr::operator==(int64_t v) const {
524  return *this == getAffineConstantExpr(v, getContext());
525 }
526 
528  auto assignCtx = [context](AffineConstantExprStorage *storage) {
529  storage->context = context;
530  };
531 
532  StorageUniquer &uniquer = context->getAffineUniquer();
533  return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
534 }
535 
538  MLIRContext *context) {
539  return llvm::to_vector(llvm::map_range(constants, [&](int64_t constant) {
540  return getAffineConstantExpr(constant, context);
541  }));
542 }
543 
544 /// Simplify add expression. Return nullptr if it can't be simplified.
546  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
547  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
548  // Fold if both LHS, RHS are a constant.
549  if (lhsConst && rhsConst)
550  return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
551  lhs.getContext());
552 
553  // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
554  // If only one of them is a symbolic expressions, make it the RHS.
555  if (lhs.isa<AffineConstantExpr>() ||
556  (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
557  return rhs + lhs;
558  }
559 
560  // At this point, if there was a constant, it would be on the right.
561 
562  // Addition with a zero is a noop, return the other input.
563  if (rhsConst) {
564  if (rhsConst.getValue() == 0)
565  return lhs;
566  }
567  // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
568  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
569  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
570  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
571  return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
572  }
573 
574  // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
575  // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
576  // respective multiplicands.
577  std::optional<int64_t> rLhsConst, rRhsConst;
578  AffineExpr firstExpr, secondExpr;
579  AffineConstantExpr rLhsConstExpr;
580  auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>();
581  if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
582  (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
583  rLhsConst = rLhsConstExpr.getValue();
584  firstExpr = lBinOpExpr.getLHS();
585  } else {
586  rLhsConst = 1;
587  firstExpr = lhs;
588  }
589 
590  auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
591  AffineConstantExpr rRhsConstExpr;
592  if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
593  (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
594  rRhsConst = rRhsConstExpr.getValue();
595  secondExpr = rBinOpExpr.getLHS();
596  } else {
597  rRhsConst = 1;
598  secondExpr = rhs;
599  }
600 
601  if (rLhsConst && rRhsConst && firstExpr == secondExpr)
602  return getAffineBinaryOpExpr(
603  AffineExprKind::Mul, firstExpr,
604  getAffineConstantExpr(*rLhsConst + *rRhsConst, lhs.getContext()));
605 
606  // When doing successive additions, bring constant to the right: turn (d0 + 2)
607  // + d1 into (d0 + d1) + 2.
608  if (lBin && lBin.getKind() == AffineExprKind::Add) {
609  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
610  return lBin.getLHS() + rhs + lrhs;
611  }
612  }
613 
614  // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
615  // q may be a constant or symbolic expression. This leads to a much more
616  // efficient form when 'c' is a power of two, and in general a more compact
617  // and readable form.
618 
619  // Process '(expr floordiv c) * (-c)'.
620  if (!rBinOpExpr)
621  return nullptr;
622 
623  auto lrhs = rBinOpExpr.getLHS();
624  auto rrhs = rBinOpExpr.getRHS();
625 
626  AffineExpr llrhs, rlrhs;
627 
628  // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
629  // symbolic expression.
630  auto lrhsBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
631  // Check rrhsConstOpExpr = -1.
632  auto rrhsConstOpExpr = rrhs.dyn_cast<AffineConstantExpr>();
633  if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
634  lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
635  // Check llrhs = expr floordiv q.
636  llrhs = lrhsBinOpExpr.getLHS();
637  // Check rlrhs = q.
638  rlrhs = lrhsBinOpExpr.getRHS();
639  auto llrhsBinOpExpr = llrhs.dyn_cast<AffineBinaryOpExpr>();
640  if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
641  return nullptr;
642  if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
643  return lhs % rlrhs;
644  }
645 
646  // Process lrhs, which is 'expr floordiv c'.
647  AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
648  if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
649  return nullptr;
650 
651  llrhs = lrBinOpExpr.getLHS();
652  rlrhs = lrBinOpExpr.getRHS();
653 
654  if (lhs == llrhs && rlrhs == -rrhs) {
655  return lhs % rlrhs;
656  }
657  return nullptr;
658 }
659 
661  return *this + getAffineConstantExpr(v, getContext());
662 }
664  if (auto simplified = simplifyAdd(*this, other))
665  return simplified;
666 
668  return uniquer.get<AffineBinaryOpExprStorage>(
669  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
670 }
671 
672 /// Simplify a multiply expression. Return nullptr if it can't be simplified.
674  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
675  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
676 
677  if (lhsConst && rhsConst)
678  return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
679  lhs.getContext());
680 
681  assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());
682 
683  // Canonicalize the mul expression so that the constant/symbolic term is the
684  // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
685  // constant. (Note that a constant is trivially symbolic).
686  if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) {
687  // At least one of them has to be symbolic.
688  return rhs * lhs;
689  }
690 
691  // At this point, if there was a constant, it would be on the right.
692 
693  // Multiplication with a one is a noop, return the other input.
694  if (rhsConst) {
695  if (rhsConst.getValue() == 1)
696  return lhs;
697  // Multiplication with zero.
698  if (rhsConst.getValue() == 0)
699  return rhsConst;
700  }
701 
702  // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
703  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
704  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
705  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
706  return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
707  }
708 
709  // When doing successive multiplication, bring constant to the right: turn (d0
710  // * 2) * d1 into (d0 * d1) * 2.
711  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
712  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
713  return (lBin.getLHS() * rhs) * lrhs;
714  }
715  }
716 
717  return nullptr;
718 }
719 
721  return *this * getAffineConstantExpr(v, getContext());
722 }
724  if (auto simplified = simplifyMul(*this, other))
725  return simplified;
726 
728  return uniquer.get<AffineBinaryOpExprStorage>(
729  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
730 }
731 
732 // Unary minus, delegate to operator*.
734  return *this * getAffineConstantExpr(-1, getContext());
735 }
736 
737 // Delegate to operator+.
738 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
740  return *this + (-other);
741 }
742 
744  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
745  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
746 
747  // mlir floordiv by zero or negative numbers is undefined and preserved as is.
748  if (!rhsConst || rhsConst.getValue() < 1)
749  return nullptr;
750 
751  if (lhsConst)
752  return getAffineConstantExpr(
753  floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
754 
755  // Fold floordiv of a multiply with a constant that is a multiple of the
756  // divisor. Eg: (i * 128) floordiv 64 = i * 2.
757  if (rhsConst == 1)
758  return lhs;
759 
760  // Simplify (expr * const) floordiv divConst when expr is known to be a
761  // multiple of divConst.
762  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
763  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
764  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
765  // rhsConst is known to be a positive constant.
766  if (lrhs.getValue() % rhsConst.getValue() == 0)
767  return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
768  }
769  }
770 
771  // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
772  // known to be a multiple of divConst.
773  if (lBin && lBin.getKind() == AffineExprKind::Add) {
774  int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
775  int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
776  // rhsConst is known to be a positive constant.
777  if (llhsDiv % rhsConst.getValue() == 0 ||
778  lrhsDiv % rhsConst.getValue() == 0)
779  return lBin.getLHS().floorDiv(rhsConst.getValue()) +
780  lBin.getRHS().floorDiv(rhsConst.getValue());
781  }
782 
783  return nullptr;
784 }
785 
786 AffineExpr AffineExpr::floorDiv(uint64_t v) const {
788 }
790  if (auto simplified = simplifyFloorDiv(*this, other))
791  return simplified;
792 
794  return uniquer.get<AffineBinaryOpExprStorage>(
795  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
796  other);
797 }
798 
800  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
801  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
802 
803  if (!rhsConst || rhsConst.getValue() < 1)
804  return nullptr;
805 
806  if (lhsConst)
807  return getAffineConstantExpr(
808  ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
809 
810  // Fold ceildiv of a multiply with a constant that is a multiple of the
811  // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
812  if (rhsConst.getValue() == 1)
813  return lhs;
814 
815  // Simplify (expr * const) ceildiv divConst when const is known to be a
816  // multiple of divConst.
817  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
818  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
819  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
820  // rhsConst is known to be a positive constant.
821  if (lrhs.getValue() % rhsConst.getValue() == 0)
822  return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
823  }
824  }
825 
826  return nullptr;
827 }
828 
829 AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
831 }
833  if (auto simplified = simplifyCeilDiv(*this, other))
834  return simplified;
835 
837  return uniquer.get<AffineBinaryOpExprStorage>(
838  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
839  other);
840 }
841 
843  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
844  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
845 
846  // mod w.r.t zero or negative numbers is undefined and preserved as is.
847  if (!rhsConst || rhsConst.getValue() < 1)
848  return nullptr;
849 
850  if (lhsConst)
851  return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
852  lhs.getContext());
853 
854  // Fold modulo of an expression that is known to be a multiple of a constant
855  // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
856  // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
857  if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
858  return getAffineConstantExpr(0, lhs.getContext());
859 
860  // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
861  // known to be a multiple of divConst.
862  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
863  if (lBin && lBin.getKind() == AffineExprKind::Add) {
864  int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
865  int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
866  // rhsConst is known to be a positive constant.
867  if (llhsDiv % rhsConst.getValue() == 0)
868  return lBin.getRHS() % rhsConst.getValue();
869  if (lrhsDiv % rhsConst.getValue() == 0)
870  return lBin.getLHS() % rhsConst.getValue();
871  }
872 
873  // Simplify (e % a) % b to e % b when b evenly divides a
874  if (lBin && lBin.getKind() == AffineExprKind::Mod) {
875  auto intermediate = lBin.getRHS().dyn_cast<AffineConstantExpr>();
876  if (intermediate && intermediate.getValue() >= 1 &&
877  mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
878  return lBin.getLHS() % rhsConst.getValue();
879  }
880  }
881 
882  return nullptr;
883 }
884 
886  return *this % getAffineConstantExpr(v, getContext());
887 }
889  if (auto simplified = simplifyMod(*this, other))
890  return simplified;
891 
893  return uniquer.get<AffineBinaryOpExprStorage>(
894  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
895 }
896 
898  SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
899  map.getResults().end());
900  return replaceDimsAndSymbols(dimReplacements, {});
901 }
902 raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
903  expr.print(os);
904  return os;
905 }
906 
907 /// Constructs an affine expression from a flat ArrayRef. If there are local
908 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
909 /// products expression, `localExprs` is expected to have the AffineExpr
910 /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
911 /// in the format [dims, symbols, locals, constant term].
913  unsigned numDims,
914  unsigned numSymbols,
915  ArrayRef<AffineExpr> localExprs,
916  MLIRContext *context) {
917  // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
918  assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
919  "unexpected number of local expressions");
920 
921  auto expr = getAffineConstantExpr(0, context);
922  // Dimensions and symbols.
923  for (unsigned j = 0; j < numDims + numSymbols; j++) {
924  if (flatExprs[j] == 0)
925  continue;
926  auto id = j < numDims ? getAffineDimExpr(j, context)
927  : getAffineSymbolExpr(j - numDims, context);
928  expr = expr + id * flatExprs[j];
929  }
930 
931  // Local identifiers.
932  for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
933  j++) {
934  if (flatExprs[j] == 0)
935  continue;
936  auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
937  expr = expr + term;
938  }
939 
940  // Constant term.
941  int64_t constTerm = flatExprs[flatExprs.size() - 1];
942  if (constTerm != 0)
943  expr = expr + constTerm;
944  return expr;
945 }
946 
947 /// Constructs a semi-affine expression from a flat ArrayRef. If there are
948 /// local identifiers (neither dimensional nor symbolic) that appear in the sum
949 /// of products expression, `localExprs` is expected to have the AffineExprs for
950 /// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in
951 /// the format [dims, symbols, locals, constant term]. The semi-affine
952 /// expression is constructed in the sorted order of dimension and symbol
953 /// position numbers. Note: local expressions/ids are used for mod, div as well
954 /// as symbolic RHS terms for terms that are not pure affine.
956  unsigned numDims,
957  unsigned numSymbols,
958  ArrayRef<AffineExpr> localExprs,
959  MLIRContext *context) {
960  assert(!flatExprs.empty() && "flatExprs cannot be empty");
961 
962  // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
963  assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
964  "unexpected number of local expressions");
965 
966  AffineExpr expr = getAffineConstantExpr(0, context);
967 
968  // We design indices as a pair which help us present the semi-affine map as
969  // sum of product where terms are sorted based on dimension or symbol
970  // position: <keyA, keyB> for expressions of the form dimension * symbol,
971  // where keyA is the position number of the dimension and keyB is the
972  // position number of the symbol. For dimensional expressions we set the index
973  // as (position number of the dimension, -1), as we want dimensional
974  // expressions to appear before symbolic and product of dimensional and
975  // symbolic expressions having the dimension with the same position number.
976  // For symbolic expression set the index as (position number of the symbol,
977  // maximum of last dimension and symbol position) number. For example, we want
978  // the expression we are constructing to look something like: d0 + d0 * s0 +
979  // s0 + d1*s1 + s1.
980 
981  // Stores the affine expression corresponding to a given index.
983  // Stores the constant coefficient value corresponding to a given
984  // dimension, symbol or a non-pure affine expression stored in `localExprs`.
985  DenseMap<std::pair<unsigned, signed>, int64_t> coefficients;
986  // Stores the indices as defined above, and later sorted to produce
987  // the semi-affine expression in the desired form.
989 
990  // Example: expression = d0 + d0 * s0 + 2 * s0.
991  // indices = [{0,-1}, {0, 0}, {0, 1}]
992  // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
993  // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
994 
995  // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
996  auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
997  AffineExpr expr) {
998  assert(!llvm::is_contained(indices, index) &&
999  "Key is already present in indices vector and overwriting will "
1000  "happen in `indexToExprMap` and `coefficients`!");
1001 
1002  indices.push_back(index);
1003  coefficients.insert({index, coefficient});
1004  indexToExprMap.insert({index, expr});
1005  };
1006 
1007  // Design indices for dimensional or symbolic terms, and store the indices,
1008  // constant coefficient corresponding to the indices in `coefficients` map,
1009  // and affine expression corresponding to indices in `indexToExprMap` map.
1010 
1011  // Ensure we do not have duplicate keys in `indexToExpr` map.
1012  unsigned offsetSym = 0;
1013  signed offsetDim = -1;
1014  for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
1015  if (flatExprs[j] == 0)
1016  continue;
1017  // For symbolic expression set the index as <position number
1018  // of the symbol, max(dimCount, symCount)> number,
1019  // as we want symbolic expressions with the same positional number to
1020  // appear after dimensional expressions having the same positional number.
1021  std::pair<unsigned, signed> indexEntry(
1022  j - numDims, std::max(numDims, numSymbols) + offsetSym++);
1023  addEntry(indexEntry, flatExprs[j],
1024  getAffineSymbolExpr(j - numDims, context));
1025  }
1026 
1027  // Denotes semi-affine product, modulo or division terms, which has been added
1028  // to the `indexToExpr` map.
1029  SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1,
1030  false);
1031  unsigned lhsPos, rhsPos;
1032  // Construct indices for product terms involving dimension, symbol or constant
1033  // as lhs/rhs, and store the indices, constant coefficient corresponding to
1034  // the indices in `coefficients` map, and affine expression corresponding to
1035  // in indices in `indexToExprMap` map.
1036  for (const auto &it : llvm::enumerate(localExprs)) {
1037  AffineExpr expr = it.value();
1038  if (flatExprs[numDims + numSymbols + it.index()] == 0)
1039  continue;
1040  AffineExpr lhs = expr.cast<AffineBinaryOpExpr>().getLHS();
1041  AffineExpr rhs = expr.cast<AffineBinaryOpExpr>().getRHS();
1042  if (!((lhs.isa<AffineDimExpr>() || lhs.isa<AffineSymbolExpr>()) &&
1043  (rhs.isa<AffineDimExpr>() || rhs.isa<AffineSymbolExpr>() ||
1044  rhs.isa<AffineConstantExpr>()))) {
1045  continue;
1046  }
1047  if (rhs.isa<AffineConstantExpr>()) {
1048  // For product/modulo/division expressions, when rhs of modulo/division
1049  // expression is constant, we put 0 in place of keyB, because we want
1050  // them to appear earlier in the semi-affine expression we are
1051  // constructing. When rhs is constant, we place 0 in place of keyB.
1052  if (lhs.isa<AffineDimExpr>()) {
1053  lhsPos = lhs.cast<AffineDimExpr>().getPosition();
1054  std::pair<unsigned, signed> indexEntry(lhsPos, offsetDim--);
1055  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1056  expr);
1057  } else {
1058  lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
1059  std::pair<unsigned, signed> indexEntry(
1060  lhsPos, std::max(numDims, numSymbols) + offsetSym++);
1061  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1062  expr);
1063  }
1064  } else if (lhs.isa<AffineDimExpr>()) {
1065  // For product/modulo/division expressions having lhs as dimension and rhs
1066  // as symbol, we order the terms in the semi-affine expression based on
1067  // the pair: <keyA, keyB> for expressions of the form dimension * symbol,
1068  // where keyA is the position number of the dimension and keyB is the
1069  // position number of the symbol.
1070  lhsPos = lhs.cast<AffineDimExpr>().getPosition();
1071  rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
1072  std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1073  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1074  } else {
1075  // For product/modulo/division expressions having both lhs and rhs as
1076  // symbol, we design indices as a pair: <keyA, keyB> for expressions
1077  // of the form dimension * symbol, where keyA is the position number of
1078  // the dimension and keyB is the position number of the symbol.
1079  lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
1080  rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
1081  std::pair<unsigned, signed> indexEntry(
1082  lhsPos, std::max(numDims, numSymbols) + offsetSym++);
1083  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1084  }
1085  addedToMap[it.index()] = true;
1086  }
1087 
1088  for (unsigned j = 0; j < numDims; ++j) {
1089  if (flatExprs[j] == 0)
1090  continue;
1091  // For dimensional expressions we set the index as <position number of the
1092  // dimension, 0>, as we want dimensional expressions to appear before
1093  // symbolic ones and products of dimensional and symbolic expressions
1094  // having the dimension with the same position number.
1095  std::pair<unsigned, signed> indexEntry(j, offsetDim--);
1096  addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context));
1097  }
1098 
1099  // Constructing the simplified semi-affine sum of product/division/mod
1100  // expression from the flattened form in the desired sorted order of indices
1101  // of the various individual product/division/mod expressions.
1102  llvm::sort(indices);
1103  for (const std::pair<unsigned, unsigned> index : indices) {
1104  assert(indexToExprMap.lookup(index) &&
1105  "cannot find key in `indexToExprMap` map");
1106  expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index);
1107  }
1108 
1109  // Local identifiers.
1110  for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1111  j++) {
1112  // If the coefficient of the local expression is 0, continue as we need not
1113  // add it in out final expression.
1114  if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols])
1115  continue;
1116  auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1117  expr = expr + term;
1118  }
1119 
1120  // Constant term.
1121  int64_t constTerm = flatExprs.back();
1122  if (constTerm != 0)
1123  expr = expr + constTerm;
1124  return expr;
1125 }
1126 
1128  unsigned numSymbols)
1129  : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
1130  operandExprStack.reserve(8);
1131 }
1132 
1133 // In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1134 //
1135 // In case of semi affine multiplication expressions, t = expr * symbolic_expr,
1136 // introduce a local variable p (= expr * symbolic_expr), and the affine
1137 // expression expr * symbolic_expr is added to `localExprs`.
1139  assert(operandExprStack.size() >= 2);
1141  operandExprStack.pop_back();
1143 
1144  // Flatten semi-affine multiplication expressions by introducing a local
1145  // variable in place of the product; the affine expression
1146  // corresponding to the quantifier is added to `localExprs`.
1147  if (!expr.getRHS().isa<AffineConstantExpr>()) {
1148  MLIRContext *context = expr.getContext();
1150  localExprs, context);
1152  localExprs, context);
1153  addLocalVariableSemiAffine(a * b, lhs, lhs.size());
1154  return;
1155  }
1156 
1157  // Get the RHS constant.
1158  auto rhsConst = rhs[getConstantIndex()];
1159  for (unsigned i = 0, e = lhs.size(); i < e; i++) {
1160  lhs[i] *= rhsConst;
1161  }
1162 }
1163 
1165  assert(operandExprStack.size() >= 2);
1166  const auto &rhs = operandExprStack.back();
1167  auto &lhs = operandExprStack[operandExprStack.size() - 2];
1168  assert(lhs.size() == rhs.size());
1169  // Update the LHS in place.
1170  for (unsigned i = 0, e = rhs.size(); i < e; i++) {
1171  lhs[i] += rhs[i];
1172  }
1173  // Pop off the RHS.
1174  operandExprStack.pop_back();
1175 }
1176 
1177 //
1178 // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
1179 //
1180 // A mod expression "expr mod c" is thus flattened by introducing a new local
1181 // variable q (= expr floordiv c), such that expr mod c is replaced with
1182 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
1183 //
1184 // In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
1185 // introduce a local variable m (= expr mod symbolic_expr), and the affine
1186 // expression expr mod symbolic_expr is added to `localExprs`.
1188  assert(operandExprStack.size() >= 2);
1189 
1191  operandExprStack.pop_back();
1193  MLIRContext *context = expr.getContext();
1194 
1195  // Flatten semi affine modulo expressions by introducing a local
1196  // variable in place of the modulo value, and the affine expression
1197  // corresponding to the quantifier is added to `localExprs`.
1198  if (!expr.getRHS().isa<AffineConstantExpr>()) {
1199  AffineExpr dividendExpr = getAffineExprFromFlatForm(
1200  lhs, numDims, numSymbols, localExprs, context);
1202  localExprs, context);
1203  AffineExpr modExpr = dividendExpr % divisorExpr;
1204  addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
1205  return;
1206  }
1207 
1208  int64_t rhsConst = rhs[getConstantIndex()];
1209  // TODO: handle modulo by zero case when this issue is fixed
1210  // at the other places in the IR.
1211  assert(rhsConst > 0 && "RHS constant has to be positive");
1212 
1213  // Check if the LHS expression is a multiple of modulo factor.
1214  unsigned i, e;
1215  for (i = 0, e = lhs.size(); i < e; i++)
1216  if (lhs[i] % rhsConst != 0)
1217  break;
1218  // If yes, modulo expression here simplifies to zero.
1219  if (i == lhs.size()) {
1220  std::fill(lhs.begin(), lhs.end(), 0);
1221  return;
1222  }
1223 
1224  // Add a local variable for the quotient, i.e., expr % c is replaced by
1225  // (expr - q * c) where q = expr floordiv c. Do this while canceling out
1226  // the GCD of expr and c.
1227  SmallVector<int64_t, 8> floorDividend(lhs);
1228  uint64_t gcd = rhsConst;
1229  for (unsigned i = 0, e = lhs.size(); i < e; i++)
1230  gcd = std::gcd(gcd, (uint64_t)std::abs(lhs[i]));
1231  // Simplify the numerator and the denominator.
1232  if (gcd != 1) {
1233  for (unsigned i = 0, e = floorDividend.size(); i < e; i++)
1234  floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd);
1235  }
1236  int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
1237 
1238  // Construct the AffineExpr form of the floordiv to store in localExprs.
1239 
1240  AffineExpr dividendExpr = getAffineExprFromFlatForm(
1241  floorDividend, numDims, numSymbols, localExprs, context);
1242  AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context);
1243  AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr);
1244  int loc;
1245  if ((loc = findLocalId(floorDivExpr)) == -1) {
1246  addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
1247  // Set result at top of stack to "lhs - rhsConst * q".
1248  lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
1249  } else {
1250  // Reuse the existing local id.
1251  lhs[getLocalVarStartIndex() + loc] = -rhsConst;
1252  }
1253 }
1254 
1256  visitDivExpr(expr, /*isCeil=*/true);
1257 }
1259  visitDivExpr(expr, /*isCeil=*/false);
1260 }
1261 
1263  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1264  auto &eq = operandExprStack.back();
1265  assert(expr.getPosition() < numDims && "Inconsistent number of dims");
1266  eq[getDimStartIndex() + expr.getPosition()] = 1;
1267 }
1268 
1270  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1271  auto &eq = operandExprStack.back();
1272  assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1273  eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1274 }
1275 
1277  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1278  auto &eq = operandExprStack.back();
1279  eq[getConstantIndex()] = expr.getValue();
1280 }
1281 
1282 void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1283  AffineExpr expr, SmallVectorImpl<int64_t> &result,
1284  unsigned long resultSize) {
1285  assert(result.size() == resultSize &&
1286  "`result` vector passed is not of correct size");
1287  int loc;
1288  if ((loc = findLocalId(expr)) == -1)
1289  addLocalIdSemiAffine(expr);
1290  std::fill(result.begin(), result.end(), 0);
1291  if (loc == -1)
1292  result[getLocalVarStartIndex() + numLocals - 1] = 1;
1293  else
1294  result[getLocalVarStartIndex() + loc] = 1;
1295 }
1296 
1297 // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
1298 // A floordiv is thus flattened by introducing a new local variable q, and
1299 // replacing that expression with 'q' while adding the constraints
1300 // c * q <= expr <= c * q + c - 1 to localVarCst (done by
1301 // IntegerRelation::addLocalFloorDiv).
1302 //
1303 // A ceildiv is similarly flattened:
1304 // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
1305 //
1306 // In case of semi affine division expressions, t = expr floordiv symbolic_expr
1307 // or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
1308 // floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
1309 // `localExprs`.
1310 void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1311  bool isCeil) {
1312  assert(operandExprStack.size() >= 2);
1313 
1314  MLIRContext *context = expr.getContext();
1316  operandExprStack.pop_back();
1318 
1319  // Flatten semi affine division expressions by introducing a local
1320  // variable in place of the quotient, and the affine expression corresponding
1321  // to the quantifier is added to `localExprs`.
1322  if (!expr.getRHS().isa<AffineConstantExpr>()) {
1324  localExprs, context);
1326  localExprs, context);
1327  AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1328  addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
1329  return;
1330  }
1331 
1332  // This is a pure affine expr; the RHS is a positive constant.
1333  int64_t rhsConst = rhs[getConstantIndex()];
1334  // TODO: handle division by zero at the same time the issue is
1335  // fixed at other places.
1336  assert(rhsConst > 0 && "RHS constant has to be positive");
1337 
1338  // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1339  // common divisors of the numerator and denominator.
1340  uint64_t gcd = std::abs(rhsConst);
1341  for (unsigned i = 0, e = lhs.size(); i < e; i++)
1342  gcd = std::gcd(gcd, (uint64_t)std::abs(lhs[i]));
1343  // Simplify the numerator and the denominator.
1344  if (gcd != 1) {
1345  for (unsigned i = 0, e = lhs.size(); i < e; i++)
1346  lhs[i] = lhs[i] / static_cast<int64_t>(gcd);
1347  }
1348  int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1349  // If the divisor becomes 1, the updated LHS is the result. (The
1350  // divisor can't be negative since rhsConst is positive).
1351  if (divisor == 1)
1352  return;
1353 
1354  // If the divisor cannot be simplified to one, we will have to retain
1355  // the ceil/floor expr (simplified up until here). Add an existential
1356  // quantifier to express its result, i.e., expr1 div expr2 is replaced
1357  // by a new identifier, q.
1358  AffineExpr a =
1360  AffineExpr b = getAffineConstantExpr(divisor, context);
1361 
1362  int loc;
1363  AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1364  if ((loc = findLocalId(divExpr)) == -1) {
1365  if (!isCeil) {
1366  SmallVector<int64_t, 8> dividend(lhs);
1367  addLocalFloorDivId(dividend, divisor, divExpr);
1368  } else {
1369  // lhs ceildiv c <=> (lhs + c - 1) floordiv c
1370  SmallVector<int64_t, 8> dividend(lhs);
1371  dividend.back() += divisor - 1;
1372  addLocalFloorDivId(dividend, divisor, divExpr);
1373  }
1374  }
1375  // Set the expression on stack to the local var introduced to capture the
1376  // result of the division (floor or ceil).
1377  std::fill(lhs.begin(), lhs.end(), 0);
1378  if (loc == -1)
1379  lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1380  else
1381  lhs[getLocalVarStartIndex() + loc] = 1;
1382 }
1383 
1384 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1385 // The local identifier added is always a floordiv of a pure add/mul affine
1386 // function of other identifiers, coefficients of which are specified in
1387 // dividend and with respect to a positive constant divisor. localExpr is the
1388 // simplified tree expression (AffineExpr) corresponding to the quantifier.
1390  int64_t divisor,
1391  AffineExpr localExpr) {
1392  assert(divisor > 0 && "positive constant divisor expected");
1393  for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1394  subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1395  localExprs.push_back(localExpr);
1396  numLocals++;
1397  // dividend and divisor are not used here; an override of this method uses it.
1398 }
1399 
1401  for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1402  subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1403  localExprs.push_back(localExpr);
1404  ++numLocals;
1405 }
1406 
1407 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1409  if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
1410  return -1;
1411  return it - localExprs.begin();
1412 }
1413 
1414 /// Simplify the affine expression by flattening it and reconstructing it.
1416  unsigned numSymbols) {
1417  // Simplify semi-affine expressions separately.
1418  if (!expr.isPureAffine())
1419  expr = simplifySemiAffine(expr);
1420 
1421  SimpleAffineExprFlattener flattener(numDims, numSymbols);
1422  flattener.walkPostOrder(expr);
1423  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1424  if (!expr.isPureAffine() &&
1425  expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1426  flattener.localExprs,
1427  expr.getContext()))
1428  return expr;
1429  AffineExpr simplifiedExpr =
1430  expr.isPureAffine()
1431  ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1432  flattener.localExprs, expr.getContext())
1433  : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1434  flattener.localExprs,
1435  expr.getContext());
1436 
1437  flattener.operandExprStack.pop_back();
1438  assert(flattener.operandExprStack.empty());
1439  return simplifiedExpr;
1440 }
static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos, AffineExprKind opKind)
Divides the given expression by the given symbol at position symbolPos.
Definition: AffineExpr.cpp:395
static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs)
Simplify a multiply expression. Return nullptr if it can't be simplified.
Definition: AffineExpr.cpp:673
static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:842
static AffineExpr simplifySemiAffine(AffineExpr expr)
Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv operations when the second...
Definition: AffineExpr.cpp:452
static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs)
Simplify add expression. Return nullptr if it can't be simplified.
Definition: AffineExpr.cpp:545
static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs a semi-affine expression from a flat ArrayRef.
Definition: AffineExpr.cpp:955
static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:799
static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:743
static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:491
static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, AffineExprKind opKind)
Returns true if the expression is divisible by the given symbol with position symbolPos.
Definition: AffineExpr.cpp:336
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Affine binary operation expression.
Definition: AffineExpr.h:207
AffineExpr getLHS() const
Definition: AffineExpr.cpp:317
AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
Definition: AffineExpr.cpp:315
AffineExpr getRHS() const
Definition: AffineExpr.cpp:320
An integer constant appearing in affine expression.
Definition: AffineExpr.h:232
AffineConstantExpr(AffineExpr::ImplType *ptr=nullptr)
Definition: AffineExpr.cpp:517
int64_t getValue() const
Definition: AffineExpr.cpp:519
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
AffineDimExpr(AffineExpr::ImplType *ptr)
Definition: AffineExpr.cpp:324
unsigned getPosition() const
Definition: AffineExpr.cpp:325
Base class for AffineExpr visitors/walkers.
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements) const
This method substitutes any uses of dimensions and symbols (e.g.
Definition: AffineExpr.cpp:66
AffineExpr shiftDims(unsigned numDims, unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineExpr.cpp:110
AffineExpr operator+(int64_t v) const
Definition: AffineExpr.cpp:660
bool isSymbolicOrConstant() const
Returns true if this expression is made out of only symbols and constants, i.e., it does not involve ...
Definition: AffineExpr.cpp:165
AffineExpr operator*(int64_t v) const
Definition: AffineExpr.cpp:720
bool operator==(AffineExpr other) const
Definition: AffineExpr.h:76
bool isPureAffine() const
Returns true if this is a pure affine expression, i.e., multiplication, floordiv, ceildiv,...
Definition: AffineExpr.cpp:189
AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
Definition: AffineExpr.cpp:122
AffineExpr operator-() const
Definition: AffineExpr.cpp:733
U cast() const
Definition: AffineExpr.h:293
void walk(std::function< void(AffineExpr)> callback) const
Walk all of the AffineExprs in this expression in postorder.
Definition: AffineExpr.cpp:30
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:786
ImplType * expr
Definition: AffineExpr.h:198
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:27
bool isMultipleOf(int64_t factor) const
Return true if the affine expression is a multiple of 'factor'.
Definition: AffineExpr.cpp:260
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
Definition: AffineExpr.cpp:220
AffineExpr compose(AffineMap map) const
Compose with an AffineMap.
Definition: AffineExpr.cpp:897
constexpr bool isa() const
Definition: AffineExpr.h:272
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
Definition: AffineExpr.cpp:293
bool isFunctionOfSymbol(unsigned position) const
Return true if the affine expression involves AffineSymbolExpr position.
Definition: AffineExpr.cpp:304
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
Definition: AffineExpr.cpp:99
AffineExpr operator%(uint64_t v) const
Definition: AffineExpr.cpp:885
MLIRContext * getContext() const
Definition: AffineExpr.cpp:25
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
Definition: AffineExpr.cpp:158
AffineExpr replaceSymbols(ArrayRef< AffineExpr > symReplacements) const
Symbol-only version of replaceDimsAndSymbols.
Definition: AffineExpr.cpp:104
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:829
void print(raw_ostream &os) const
U dyn_cast() const
Definition: AffineExpr.h:283
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:44
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:350
A symbolic identifier appearing in an affine expression.
Definition: AffineExpr.h:224
AffineSymbolExpr(AffineExpr::ImplType *ptr)
Definition: AffineExpr.cpp:506
unsigned getPosition() const
Definition: AffineExpr.cpp:508
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
StorageUniquer & getAffineUniquer()
Returns the storage uniquer used for creating affine constructs.
virtual void addLocalFloorDivId(ArrayRef< int64_t > dividend, int64_t divisor, AffineExpr localExpr)
void visitFloorDivExpr(AffineBinaryOpExpr expr)
void visitAddExpr(AffineBinaryOpExpr expr)
std::vector< SmallVector< int64_t, 8 > > operandExprStack
void visitDimExpr(AffineDimExpr expr)
void visitConstantExpr(AffineConstantExpr expr)
void visitSymbolExpr(AffineSymbolExpr expr)
virtual void addLocalIdSemiAffine(AffineExpr localExpr)
Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul expr) when the rhs is a symbo...
SmallVector< AffineExpr, 4 > localExprs
void visitCeilDivExpr(AffineBinaryOpExpr expr)
void visitModExpr(AffineBinaryOpExpr expr)
SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols)
void visitMulExpr(AffineBinaryOpExpr expr)
A utility class to get or create instances of "storage classes".
Storage * get(function_ref< void(Storage *)> initFn, TypeID id, Args &&...args)
Gets a uniqued instance of 'Storage'.
Detect if any of the given parameter types has a sub-element handler.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt gcd(const MPInt &a, const MPInt &b)
Definition: MPInt.h:399
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
Definition: MPInt.h:370
This header declares functions that assist transformations in the MemRef dialect.
int64_t floorDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's floordiv operation on constants.
Definition: MathExtras.h:33
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
Definition: MathExtras.h:23
AffineExprKind
Definition: AffineExpr.h:40
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ Constant
Constant integer.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:47
AffineExpr getAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs an affine expression from a flat ArrayRef.
Definition: AffineExpr.cpp:912
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:527
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
SmallVector< AffineExpr > getAffineConstantExprs(ArrayRef< int64_t > constants, MLIRContext *context)
Definition: AffineExpr.cpp:537
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:502
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:512
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
Definition: MathExtras.h:45
A binary operation appearing in an affine expression.
An integer constant appearing in affine expression.
A dimensional or symbolic identifier appearing in an affine expression.
Base storage class appearing in an affine expression.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.