MLIR  16.0.0git
AffineExpr.cpp
Go to the documentation of this file.
1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "AffineExprDetail.h"
12 #include "mlir/IR/AffineExpr.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/IntegerSet.h"
17 #include "mlir/Support/TypeID.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include <numeric>
20 
21 using namespace mlir;
22 using namespace mlir::detail;
23 
24 MLIRContext *AffineExpr::getContext() const { return expr->context; }
25 
26 AffineExprKind AffineExpr::getKind() const { return expr->kind; }
27 
28 /// Walk all of the AffineExprs in this subgraph in postorder.
29 void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
30  struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
31  std::function<void(AffineExpr)> callback;
32 
33  AffineExprWalker(std::function<void(AffineExpr)> callback)
34  : callback(std::move(callback)) {}
35 
36  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
37  void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
38  void visitDimExpr(AffineDimExpr expr) { callback(expr); }
39  void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
40  };
41 
42  AffineExprWalker(std::move(callback)).walkPostOrder(*this);
43 }
44 
45 // Dispatch affine expression construction based on kind.
47  AffineExpr rhs) {
48  if (kind == AffineExprKind::Add)
49  return lhs + rhs;
50  if (kind == AffineExprKind::Mul)
51  return lhs * rhs;
52  if (kind == AffineExprKind::FloorDiv)
53  return lhs.floorDiv(rhs);
54  if (kind == AffineExprKind::CeilDiv)
55  return lhs.ceilDiv(rhs);
56  if (kind == AffineExprKind::Mod)
57  return lhs % rhs;
58 
59  llvm_unreachable("unknown binary operation on affine expressions");
60 }
61 
62 /// This method substitutes any uses of dimensions and symbols (e.g.
63 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
66  ArrayRef<AffineExpr> symReplacements) const {
67  switch (getKind()) {
69  return *this;
70  case AffineExprKind::DimId: {
71  unsigned dimId = cast<AffineDimExpr>().getPosition();
72  if (dimId >= dimReplacements.size())
73  return *this;
74  return dimReplacements[dimId];
75  }
77  unsigned symId = cast<AffineSymbolExpr>().getPosition();
78  if (symId >= symReplacements.size())
79  return *this;
80  return symReplacements[symId];
81  }
87  auto binOp = cast<AffineBinaryOpExpr>();
88  auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
89  auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
90  auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
91  if (newLHS == lhs && newRHS == rhs)
92  return *this;
93  return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
94  }
95  llvm_unreachable("Unknown AffineExpr");
96 }
97 
99  return replaceDimsAndSymbols(dimReplacements, {});
100 }
101 
104  return replaceDimsAndSymbols({}, symReplacements);
105 }
106 
107 /// Replace dims[offset ... numDims)
108 /// by dims[offset + shift ... shift + numDims).
///
/// Builds a per-position replacement list in `dims` and applies it through
/// replaceDimsAndSymbols with an empty symbol replacement list.
/// NOTE(review): the declaration of `dims` is not visible in this listing;
/// presumably it is an initially empty vector of AffineExpr -- confirm.
109 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
110  unsigned offset) const {
 // Positions below `offset` are replaced by the dimension at the same position.
112  for (unsigned idx = 0; idx < offset; ++idx)
113  dims.push_back(getAffineDimExpr(idx, getContext()));
 // Positions in [offset, numDims) are replaced by the dimension that is
 // `shift` positions higher.
114  for (unsigned idx = offset; idx < numDims; ++idx)
115  dims.push_back(getAffineDimExpr(idx + shift, getContext()));
116  return replaceDimsAndSymbols(dims, {});
117 }
118 
119 /// Replace symbols[offset ... numSymbols)
120 /// by symbols[offset + shift ... shift + numSymbols).
///
/// Builds a per-position replacement list in `symbols` and applies it through
/// replaceDimsAndSymbols with an empty dimension replacement list.
/// NOTE(review): the declaration of `symbols` is not visible in this listing;
/// presumably it is an initially empty vector of AffineExpr -- confirm.
121 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
122  unsigned offset) const {
 // Positions below `offset` are replaced by the symbol at the same position.
124  for (unsigned idx = 0; idx < offset; ++idx)
125  symbols.push_back(getAffineSymbolExpr(idx, getContext()));
 // Positions in [offset, numSymbols) are replaced by the symbol that is
 // `shift` positions higher.
126  for (unsigned idx = offset; idx < numSymbols; ++idx)
127  symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
128  return replaceDimsAndSymbols({}, symbols);
129 }
130 
131 /// Sparse replace method. Return the modified expression tree.
134  auto it = map.find(*this);
135  if (it != map.end())
136  return it->second;
137  switch (getKind()) {
138  default:
139  return *this;
140  case AffineExprKind::Add:
141  case AffineExprKind::Mul:
144  case AffineExprKind::Mod:
145  auto binOp = cast<AffineBinaryOpExpr>();
146  auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
147  auto newLHS = lhs.replace(map);
148  auto newRHS = rhs.replace(map);
149  if (newLHS == lhs && newRHS == rhs)
150  return *this;
151  return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
152  }
153  llvm_unreachable("Unknown AffineExpr");
154 }
155 
156 /// Sparse replace method. Return the modified expression tree.
159  map.insert(std::make_pair(expr, replacement));
160  return replace(map);
161 }
162 /// Returns true if this expression is made out of only symbols and
163 /// constants (no dimensional identifiers).
165  switch (getKind()) {
167  return true;
169  return false;
171  return true;
172 
173  case AffineExprKind::Add:
174  case AffineExprKind::Mul:
177  case AffineExprKind::Mod: {
178  auto expr = this->cast<AffineBinaryOpExpr>();
179  return expr.getLHS().isSymbolicOrConstant() &&
180  expr.getRHS().isSymbolicOrConstant();
181  }
182  }
183  llvm_unreachable("Unknown AffineExpr");
184 }
185 
186 /// Returns true if this is a pure affine expression, i.e., multiplication,
187 /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
189  switch (getKind()) {
193  return true;
194  case AffineExprKind::Add: {
195  auto op = cast<AffineBinaryOpExpr>();
196  return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
197  }
198 
199  case AffineExprKind::Mul: {
200  // TODO: Canonicalize the constants in binary operators to the RHS when
201  // possible, allowing this to merge into the next case.
202  auto op = cast<AffineBinaryOpExpr>();
203  return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
204  (op.getLHS().template isa<AffineConstantExpr>() ||
205  op.getRHS().template isa<AffineConstantExpr>());
206  }
209  case AffineExprKind::Mod: {
210  auto op = cast<AffineBinaryOpExpr>();
211  return op.getLHS().isPureAffine() &&
212  op.getRHS().template isa<AffineConstantExpr>();
213  }
214  }
215  llvm_unreachable("Unknown AffineExpr");
216 }
217 
218 // Returns the greatest known integral divisor of this affine expression.
220  AffineBinaryOpExpr binExpr(nullptr);
221  switch (getKind()) {
223  [[fallthrough]];
227  return 1;
229  return std::abs(this->cast<AffineConstantExpr>().getValue());
230  case AffineExprKind::Mul: {
231  binExpr = this->cast<AffineBinaryOpExpr>();
232  return binExpr.getLHS().getLargestKnownDivisor() *
233  binExpr.getRHS().getLargestKnownDivisor();
234  }
235  case AffineExprKind::Add:
236  [[fallthrough]];
237  case AffineExprKind::Mod: {
238  binExpr = cast<AffineBinaryOpExpr>();
239  return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
240  (uint64_t)binExpr.getRHS().getLargestKnownDivisor());
241  }
242  }
243  llvm_unreachable("Unknown AffineExpr");
244 }
245 
/// Returns whether this expression can be shown to be a multiple of `factor`,
/// using constant values and the largest known divisors of the operands.
/// NOTE(review): several `case` labels of the switch below are not visible in
/// this listing; confirm which expression kinds each branch covers.
246 bool AffineExpr::isMultipleOf(int64_t factor) const {
247  AffineBinaryOpExpr binExpr(nullptr);
 // Scratch values for the operand divisors; they are assigned inside the
 // condition of the Mul case below.
248  uint64_t l, u;
249  switch (getKind()) {
251  [[fallthrough]];
 // `factor * factor == 1` holds exactly when factor is 1 or -1.
253  return factor * factor == 1;
255  return cast<AffineConstantExpr>().getValue() % factor == 0;
256  case AffineExprKind::Mul: {
257  binExpr = cast<AffineBinaryOpExpr>();
258  // It's probably not worth optimizing this further (to avoid traversing
259  // the whole sub-tree underneath): that would require a version of
260  // isMultipleOf that, on a 'false' return, also returns the largest known
 // divisor.
 // Short-circuit evaluation guarantees that `l` and `u` have both been
 // assigned by the time `l * u` is evaluated.
261  return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
262  (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
263  (l * u) % factor == 0;
264  }
265  case AffineExprKind::Add:
268  case AffineExprKind::Mod: {
269  binExpr = cast<AffineBinaryOpExpr>();
 // The gcd of the operands' known divisors is tested against `factor`.
270  return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
271  (uint64_t)binExpr.getRHS().getLargestKnownDivisor()) %
272  factor ==
273  0;
274  }
275  }
276  llvm_unreachable("Unknown AffineExpr");
277 }
278 
279 bool AffineExpr::isFunctionOfDim(unsigned position) const {
280  if (getKind() == AffineExprKind::DimId) {
281  return *this == mlir::getAffineDimExpr(position, getContext());
282  }
283  if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
284  return expr.getLHS().isFunctionOfDim(position) ||
285  expr.getRHS().isFunctionOfDim(position);
286  }
287  return false;
288 }
289 
290 bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
291  if (getKind() == AffineExprKind::SymbolId) {
292  return *this == mlir::getAffineSymbolExpr(position, getContext());
293  }
294  if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
295  return expr.getLHS().isFunctionOfSymbol(position) ||
296  expr.getRHS().isFunctionOfSymbol(position);
297  }
298  return false;
299 }
300 
302  : AffineExpr(ptr) {}
304  return static_cast<ImplType *>(expr)->lhs;
305 }
307  return static_cast<ImplType *>(expr)->rhs;
308 }
309 
311 unsigned AffineDimExpr::getPosition() const {
312  return static_cast<ImplType *>(expr)->position;
313 }
314 
315 /// Returns true if the expression is divisible by the given symbol with
316 /// position `symbolPos`. The argument `opKind` specifies what kind of
317 /// division or mod operation called this function. It helps in implementing the
318 /// commutative property of the floordiv and ceildiv operations. If the argument
319 /// `opKind` is floordiv and `expr` is also a binary expression of a floordiv
320 /// operation, then the commutative property can be used; otherwise, the floordiv
321 /// operation is not divisible. The same argument holds for the ceildiv operation.
322 static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
323  AffineExprKind opKind) {
324  // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
325  assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
326  opKind == AffineExprKind::CeilDiv) &&
327  "unexpected opKind");
328  switch (expr.getKind()) {
330  return expr.cast<AffineConstantExpr>().getValue() == 0;
332  return false;
334  return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos);
335  // Checks divisibility by the given symbol for both operands.
336  case AffineExprKind::Add: {
337  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
338  return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
339  isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
340  }
341  // Checks divisibility by the given symbol for both operands. Consider the
342  // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
343  // this is a division by s1 and both the operands of modulo are divisible by
344  // s1 but it is not divisible by s1 always. The third argument is
345  // `AffineExprKind::Mod` for this reason.
346  case AffineExprKind::Mod: {
347  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
348  return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
350  isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
352  }
353  // Checks if any of the operand divisible by the given symbol.
354  case AffineExprKind::Mul: {
355  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
356  return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
357  isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
358  }
359  // Floordiv and ceildiv are divisible by the given symbol when the first
360  // operand is divisible, and the affine expression kind of the argument expr
361  // is same as the argument `opKind`. This can be inferred from commutative
362  // property of floordiv and ceildiv operations and are as follow:
363  // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
364  // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
365  // It will fail if operations are not same. For example:
366  // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
369  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
370  if (opKind != expr.getKind())
371  return false;
372  return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
373  }
374  }
375  llvm_unreachable("Unknown AffineExpr");
376 }
377 
378 /// Divides the given expression by the symbol at position `symbolPos`. It
379 /// assumes the divisibility condition has been checked before it is called. A
380 /// null expression is returned whenever the divisibility condition fails.
/// NOTE(review): several `case` labels of the switch below are not visible in
/// this listing; confirm which expression kinds each branch covers.
381 static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
382  AffineExprKind opKind) {
383  // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
384  assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
385  opKind == AffineExprKind::CeilDiv) &&
386  "unexpected opKind");
387  switch (expr.getKind()) {
 // A constant yields the constant 0 when its value is 0 and a null expression
 // otherwise.
389  if (expr.cast<AffineConstantExpr>().getValue() != 0)
390  return nullptr;
391  return getAffineConstantExpr(0, expr.getContext());
 // NOTE(review): presumably the dimension case (never divisible) followed by
 // the symbol case (quotient is the constant 1) -- confirm against the labels.
393  return nullptr;
395  return getAffineConstantExpr(1, expr.getContext());
396  // Dividing both operands by the given symbol.
397  case AffineExprKind::Add: {
398  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
399  return getAffineBinaryOpExpr(
400  expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
401  symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
402  }
403  // Dividing both operands by the given symbol.
 // The recursive calls pass this expression's own kind (Mod) as `opKind`
 // rather than the caller's.
404  case AffineExprKind::Mod: {
405  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
406  return getAffineBinaryOpExpr(
407  expr.getKind(),
408  symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
409  symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
410  }
411  // Dividing any of the operand by the given symbol.
 // When the LHS is not divisible by the symbol the RHS is divided instead;
 // otherwise the LHS is divided and the RHS is kept as is.
412  case AffineExprKind::Mul: {
413  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
414  if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
415  return binaryExpr.getLHS() *
416  symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
417  return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
418  binaryExpr.getRHS();
419  }
420  // Dividing first operand only by the given symbol.
 // NOTE(review): the labels for this branch are missing from the listing;
 // presumably FloorDiv and CeilDiv -- confirm.
423  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
424  return getAffineBinaryOpExpr(
425  expr.getKind(),
426  symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
427  binaryExpr.getRHS());
428  }
429  }
430  llvm_unreachable("Unknown AffineExpr");
431 }
432 
433 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
434 /// operations when the second operand simplifies to a symbol and the first
435 /// operand is divisible by that symbol. It can be applied to any semi-affine
436 /// expression. Returned expression can either be a semi-affine or pure affine
437 /// expression.
/// NOTE(review): the signature line and several `case` labels are not visible
/// in this listing; confirm them against the full source.
439  switch (expr.getKind()) {
 // Expressions reaching this return are handed back unchanged.
443  return expr;
 // Add and Mul are rebuilt from their recursively simplified operands.
444  case AffineExprKind::Add:
445  case AffineExprKind::Mul: {
446  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
447  return getAffineBinaryOpExpr(expr.getKind(),
448  simplifySemiAffine(binaryExpr.getLHS()),
449  simplifySemiAffine(binaryExpr.getRHS()));
450  }
451  // Check if the simplification of the second operand is a symbol, and the
452  // first operand is divisible by it. If the operation is a modulo, a constant
453  // zero expression is returned. In the case of floordiv and ceildiv, the
454  // symbol from the simplification of the second operand divides the first
455  // operand. Otherwise, simplification is not possible.
458  case AffineExprKind::Mod: {
459  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
460  AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS());
461  AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS());
 // NOTE(review): this simplifies the RHS a second time although the result
 // is already available in `sRHS`; `sRHS.dyn_cast<AffineSymbolExpr>()` looks
 // like it would avoid the repeated traversal -- confirm.
462  AffineSymbolExpr symbolExpr =
463  simplifySemiAffine(binaryExpr.getRHS()).dyn_cast<AffineSymbolExpr>();
 // RHS did not simplify to a symbol: rebuild from the simplified operands.
464  if (!symbolExpr)
465  return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
466  unsigned symbolPos = symbolExpr.getPosition();
 // NOTE(review): divisibility is checked on the unsimplified LHS, while the
 // division below is applied to the simplified `sLHS` -- confirm that this
 // asymmetry is intended.
467  if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
468  return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
469  if (expr.getKind() == AffineExprKind::Mod)
470  return getAffineConstantExpr(0, expr.getContext());
471  return symbolicDivide(sLHS, symbolPos, expr.getKind());
472  }
473  }
474  llvm_unreachable("Unknown AffineExpr");
475 }
476 
477 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
478  MLIRContext *context) {
479  auto assignCtx = [context](AffineDimExprStorage *storage) {
480  storage->context = context;
481  };
482 
483  StorageUniquer &uniquer = context->getAffineUniquer();
484  return uniquer.get<AffineDimExprStorage>(
485  assignCtx, static_cast<unsigned>(kind), position);
486 }
487 
488 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
489  return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
490 }
491 
493  : AffineExpr(ptr) {}
495  return static_cast<ImplType *>(expr)->position;
496 }
497 
498 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
499  return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
500  ;
501 }
502 
504  : AffineExpr(ptr) {}
506  return static_cast<ImplType *>(expr)->constant;
507 }
508 
509 bool AffineExpr::operator==(int64_t v) const {
510  return *this == getAffineConstantExpr(v, getContext());
511 }
512 
514  auto assignCtx = [context](AffineConstantExprStorage *storage) {
515  storage->context = context;
516  };
517 
518  StorageUniquer &uniquer = context->getAffineUniquer();
519  return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
520 }
521 
522 /// Simplify add expression. Return nullptr if it can't be simplified.
524  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
525  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
526  // Fold if both LHS, RHS are a constant.
527  if (lhsConst && rhsConst)
528  return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
529  lhs.getContext());
530 
531  // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
532  // If only one of them is a symbolic expressions, make it the RHS.
533  if (lhs.isa<AffineConstantExpr>() ||
534  (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
535  return rhs + lhs;
536  }
537 
538  // At this point, if there was a constant, it would be on the right.
539 
540  // Addition with a zero is a noop, return the other input.
541  if (rhsConst) {
542  if (rhsConst.getValue() == 0)
543  return lhs;
544  }
545  // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
546  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
547  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
548  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
549  return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
550  }
551 
552  // Detect "c1 * expr + c2 * expr" as "(c1 + c2) * expr".
553  // c1 is rLhsConst, c2 is rRhsConst; firstExpr, secondExpr are their
554  // respective multiplicands.
555  Optional<int64_t> rLhsConst, rRhsConst;
556  AffineExpr firstExpr, secondExpr;
557  AffineConstantExpr rLhsConstExpr;
558  auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>();
559  if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
560  (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
561  rLhsConst = rLhsConstExpr.getValue();
562  firstExpr = lBinOpExpr.getLHS();
563  } else {
564  rLhsConst = 1;
565  firstExpr = lhs;
566  }
567 
568  auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
569  AffineConstantExpr rRhsConstExpr;
570  if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
571  (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
572  rRhsConst = rRhsConstExpr.getValue();
573  secondExpr = rBinOpExpr.getLHS();
574  } else {
575  rRhsConst = 1;
576  secondExpr = rhs;
577  }
578 
579  if (rLhsConst && rRhsConst && firstExpr == secondExpr)
580  return getAffineBinaryOpExpr(
581  AffineExprKind::Mul, firstExpr,
582  getAffineConstantExpr(*rLhsConst + *rRhsConst, lhs.getContext()));
583 
584  // When doing successive additions, bring constant to the right: turn (d0 + 2)
585  // + d1 into (d0 + d1) + 2.
586  if (lBin && lBin.getKind() == AffineExprKind::Add) {
587  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
588  return lBin.getLHS() + rhs + lrhs;
589  }
590  }
591 
592  // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
593  // q may be a constant or symbolic expression. This leads to a much more
594  // efficient form when 'q' is a power of two, and in general a more compact
595  // and readable form.
596 
597  // Process '(expr floordiv c) * (-c)'.
598  if (!rBinOpExpr)
599  return nullptr;
600 
601  auto lrhs = rBinOpExpr.getLHS();
602  auto rrhs = rBinOpExpr.getRHS();
603 
604  AffineExpr llrhs, rlrhs;
605 
606  // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
607  // symbolic expression.
608  auto lrhsBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
609  // Check rrhsConstOpExpr = -1.
610  auto rrhsConstOpExpr = rrhs.dyn_cast<AffineConstantExpr>();
611  if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
612  lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
613  // Check llrhs = expr floordiv q.
614  llrhs = lrhsBinOpExpr.getLHS();
615  // Check rlrhs = q.
616  rlrhs = lrhsBinOpExpr.getRHS();
617  auto llrhsBinOpExpr = llrhs.dyn_cast<AffineBinaryOpExpr>();
618  if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
619  return nullptr;
620  if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
621  return lhs % rlrhs;
622  }
623 
624  // Process lrhs, which is 'expr floordiv c'.
625  AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
626  if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
627  return nullptr;
628 
629  llrhs = lrBinOpExpr.getLHS();
630  rlrhs = lrBinOpExpr.getRHS();
631 
632  if (lhs == llrhs && rlrhs == -rrhs) {
633  return lhs % rlrhs;
634  }
635  return nullptr;
636 }
637 
639  return *this + getAffineConstantExpr(v, getContext());
640 }
642  if (auto simplified = simplifyAdd(*this, other))
643  return simplified;
644 
646  return uniquer.get<AffineBinaryOpExprStorage>(
647  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
648 }
649 
650 /// Simplify a multiply expression. Return nullptr if it can't be simplified.
652  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
653  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
654 
655  if (lhsConst && rhsConst)
656  return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
657  lhs.getContext());
658 
659  assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());
660 
661  // Canonicalize the mul expression so that the constant/symbolic term is the
662  // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
663  // constant. (Note that a constant is trivially symbolic).
664  if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) {
665  // At least one of them has to be symbolic.
666  return rhs * lhs;
667  }
668 
669  // At this point, if there was a constant, it would be on the right.
670 
671  // Multiplication with a one is a noop, return the other input.
672  if (rhsConst) {
673  if (rhsConst.getValue() == 1)
674  return lhs;
675  // Multiplication with zero.
676  if (rhsConst.getValue() == 0)
677  return rhsConst;
678  }
679 
680  // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
681  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
682  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
683  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
684  return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
685  }
686 
687  // When doing successive multiplication, bring constant to the right: turn (d0
688  // * 2) * d1 into (d0 * d1) * 2.
689  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
690  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
691  return (lBin.getLHS() * rhs) * lrhs;
692  }
693  }
694 
695  return nullptr;
696 }
697 
699  return *this * getAffineConstantExpr(v, getContext());
700 }
702  if (auto simplified = simplifyMul(*this, other))
703  return simplified;
704 
706  return uniquer.get<AffineBinaryOpExprStorage>(
707  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
708 }
709 
710 // Unary minus, delegate to operator*.
712  return *this * getAffineConstantExpr(-1, getContext());
713 }
714 
715 // Delegate to operator+.
716 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
718  return *this + (-other);
719 }
720 
722  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
723  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
724 
725  // mlir floordiv by zero or negative numbers is undefined and preserved as is.
726  if (!rhsConst || rhsConst.getValue() < 1)
727  return nullptr;
728 
729  if (lhsConst)
730  return getAffineConstantExpr(
731  floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
732 
733  // Fold floordiv of a multiply with a constant that is a multiple of the
734  // divisor. Eg: (i * 128) floordiv 64 = i * 2.
735  if (rhsConst == 1)
736  return lhs;
737 
738  // Simplify (expr * const) floordiv divConst when const is known to be a
739  // multiple of divConst.
740  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
741  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
742  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
743  // rhsConst is known to be a positive constant.
744  if (lrhs.getValue() % rhsConst.getValue() == 0)
745  return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
746  }
747  }
748 
749  // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
750  // known to be a multiple of divConst.
751  if (lBin && lBin.getKind() == AffineExprKind::Add) {
752  int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
753  int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
754  // rhsConst is known to be a positive constant.
755  if (llhsDiv % rhsConst.getValue() == 0 ||
756  lrhsDiv % rhsConst.getValue() == 0)
757  return lBin.getLHS().floorDiv(rhsConst.getValue()) +
758  lBin.getRHS().floorDiv(rhsConst.getValue());
759  }
760 
761  return nullptr;
762 }
763 
764 AffineExpr AffineExpr::floorDiv(uint64_t v) const {
766 }
768  if (auto simplified = simplifyFloorDiv(*this, other))
769  return simplified;
770 
772  return uniquer.get<AffineBinaryOpExprStorage>(
773  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
774  other);
775 }
776 
778  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
779  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
780 
781  if (!rhsConst || rhsConst.getValue() < 1)
782  return nullptr;
783 
784  if (lhsConst)
785  return getAffineConstantExpr(
786  ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
787 
788  // Fold ceildiv of a multiply with a constant that is a multiple of the
789  // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
790  if (rhsConst.getValue() == 1)
791  return lhs;
792 
793  // Simplify (expr * const) ceildiv divConst when const is known to be a
794  // multiple of divConst.
795  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
796  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
797  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
798  // rhsConst is known to be a positive constant.
799  if (lrhs.getValue() % rhsConst.getValue() == 0)
800  return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
801  }
802  }
803 
804  return nullptr;
805 }
806 
807 AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
809 }
811  if (auto simplified = simplifyCeilDiv(*this, other))
812  return simplified;
813 
815  return uniquer.get<AffineBinaryOpExprStorage>(
816  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
817  other);
818 }
819 
821  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
822  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
823 
824  // mod w.r.t zero or negative numbers is undefined and preserved as is.
825  if (!rhsConst || rhsConst.getValue() < 1)
826  return nullptr;
827 
828  if (lhsConst)
829  return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
830  lhs.getContext());
831 
832  // Fold modulo of an expression that is known to be a multiple of a constant
833  // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
834  // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
835  if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
836  return getAffineConstantExpr(0, lhs.getContext());
837 
838  // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
839  // known to be a multiple of divConst.
840  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
841  if (lBin && lBin.getKind() == AffineExprKind::Add) {
842  int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
843  int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
844  // rhsConst is known to be a positive constant.
845  if (llhsDiv % rhsConst.getValue() == 0)
846  return lBin.getRHS() % rhsConst.getValue();
847  if (lrhsDiv % rhsConst.getValue() == 0)
848  return lBin.getLHS() % rhsConst.getValue();
849  }
850 
851  // Simplify (e % a) % b to e % b when b evenly divides a
852  if (lBin && lBin.getKind() == AffineExprKind::Mod) {
853  auto intermediate = lBin.getRHS().dyn_cast<AffineConstantExpr>();
854  if (intermediate && intermediate.getValue() >= 1 &&
855  mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
856  return lBin.getLHS() % rhsConst.getValue();
857  }
858  }
859 
860  return nullptr;
861 }
862 
864  return *this % getAffineConstantExpr(v, getContext());
865 }
867  if (auto simplified = simplifyMod(*this, other))
868  return simplified;
869 
871  return uniquer.get<AffineBinaryOpExprStorage>(
872  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
873 }
874 
876  SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
877  map.getResults().end());
878  return replaceDimsAndSymbols(dimReplacements, {});
879 }
/// Streams `expr` to `os` by calling its print method, and returns `os` so
/// that further insertions can be chained.
880 raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
881  expr.print(os);
882  return os;
883 }
884 
885 /// Constructs an affine expression from a flat ArrayRef. If there are local
886 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
887 /// products expression, `localExprs` is expected to have the AffineExpr
888 /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
889 /// in the format [dims, symbols, locals, constant term].
891  unsigned numDims,
892  unsigned numSymbols,
893  ArrayRef<AffineExpr> localExprs,
894  MLIRContext *context) {
895  // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
896  assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
897  "unexpected number of local expressions");
898 
899  auto expr = getAffineConstantExpr(0, context);
900  // Dimensions and symbols.
901  for (unsigned j = 0; j < numDims + numSymbols; j++) {
902  if (flatExprs[j] == 0)
903  continue;
904  auto id = j < numDims ? getAffineDimExpr(j, context)
905  : getAffineSymbolExpr(j - numDims, context);
906  expr = expr + id * flatExprs[j];
907  }
908 
909  // Local identifiers.
910  for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
911  j++) {
912  if (flatExprs[j] == 0)
913  continue;
914  auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
915  expr = expr + term;
916  }
917 
918  // Constant term.
919  int64_t constTerm = flatExprs[flatExprs.size() - 1];
920  if (constTerm != 0)
921  expr = expr + constTerm;
922  return expr;
923 }
924 
925 /// Constructs a semi-affine expression from a flat ArrayRef. If there are
926 /// local identifiers (neither dimensional nor symbolic) that appear in the sum
927 /// of products expression, `localExprs` is expected to have the AffineExprs for
928 /// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in
929 /// the format [dims, symbols, locals, constant term]. The semi-affine
930 /// expression is constructed in the sorted order of dimension and symbol
931 /// position numbers. Note: local expressions/ids are used for mod, div as well
932 /// as symbolic RHS terms for terms that are not pure affine.
934  unsigned numDims,
935  unsigned numSymbols,
936  ArrayRef<AffineExpr> localExprs,
937  MLIRContext *context) {
938  assert(!flatExprs.empty() && "flatExprs cannot be empty");
939 
940  // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
941  assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
942  "unexpected number of local expressions");
943 
944  AffineExpr expr = getAffineConstantExpr(0, context);
945 
946  // We design indices as a pair which help us present the semi-affine map as
947  // sum of product where terms are sorted based on dimension or symbol
948  // position: <keyA, keyB> for expressions of the form dimension * symbol,
949  // where keyA is the position number of the dimension and keyB is the
950  // position number of the symbol. For dimensional expressions we set the index
951  // as (position number of the dimension, -1), as we want dimensional
952  // expressions to appear before symbolic and product of dimensional and
953  // symbolic expressions having the dimension with the same position number.
954  // For symbolic expression set the index as (position number of the symbol,
955  // maximum of last dimension and symbol position) number. For example, we want
956  // the expression we are constructing to look something like: d0 + d0 * s0 +
957  // s0 + d1*s1 + s1.
958 
959  // Stores the affine expression corresponding to a given index.
961  // Stores the constant coefficient value corresponding to a given
962  // dimension, symbol or a non-pure affine expression stored in `localExprs`.
963  DenseMap<std::pair<unsigned, signed>, int64_t> coefficients;
964  // Stores the indices as defined above, and later sorted to produce
965  // the semi-affine expression in the desired form.
967 
968  // Example: expression = d0 + d0 * s0 + 2 * s0.
969  // indices = [{0,-1}, {0, 0}, {0, 1}]
970  // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
971  // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
972 
973  // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
974  auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
975  AffineExpr expr) {
976  assert(!llvm::is_contained(indices, index) &&
977  "Key is already present in indices vector and overwriting will "
978  "happen in `indexToExprMap` and `coefficients`!");
979 
980  indices.push_back(index);
981  coefficients.insert({index, coefficient});
982  indexToExprMap.insert({index, expr});
983  };
984 
985  // Design indices for dimensional or symbolic terms, and store the indices,
986  // constant coefficient corresponding to the indices in `coefficients` map,
987  // and affine expression corresponding to indices in `indexToExprMap` map.
988 
989  for (unsigned j = 0; j < numDims; ++j) {
990  if (flatExprs[j] == 0)
991  continue;
992  // For dimensional expressions we set the index as <position number of the
993  // dimension, 0>, as we want dimensional expressions to appear before
994  // symbolic ones and products of dimensional and symbolic expressions
995  // having the dimension with the same position number.
996  std::pair<unsigned, signed> indexEntry(j, -1);
997  addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context));
998  }
999  for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
1000  if (flatExprs[j] == 0)
1001  continue;
1002  // For symbolic expression set the index as <position number
1003  // of the symbol, max(dimCount, symCount)> number,
1004  // as we want symbolic expressions with the same positional number to
1005  // appear after dimensional expressions having the same positional number.
1006  std::pair<unsigned, signed> indexEntry(j - numDims,
1007  std::max(numDims, numSymbols));
1008  addEntry(indexEntry, flatExprs[j],
1009  getAffineSymbolExpr(j - numDims, context));
1010  }
1011 
1012  // Denotes semi-affine product, modulo or division terms, which has been added
1013  // to the `indexToExpr` map.
1014  SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1,
1015  false);
1016  unsigned lhsPos, rhsPos;
1017  // Construct indices for product terms involving dimension, symbol or constant
1018  // as lhs/rhs, and store the indices, constant coefficient corresponding to
1019  // the indices in `coefficients` map, and affine expression corresponding to
1020  // in indices in `indexToExprMap` map.
1021  for (const auto &it : llvm::enumerate(localExprs)) {
1022  AffineExpr expr = it.value();
1023  if (flatExprs[numDims + numSymbols + it.index()] == 0)
1024  continue;
1025  AffineExpr lhs = expr.cast<AffineBinaryOpExpr>().getLHS();
1026  AffineExpr rhs = expr.cast<AffineBinaryOpExpr>().getRHS();
1027  if (!((lhs.isa<AffineDimExpr>() || lhs.isa<AffineSymbolExpr>()) &&
1028  (rhs.isa<AffineDimExpr>() || rhs.isa<AffineSymbolExpr>() ||
1029  rhs.isa<AffineConstantExpr>()))) {
1030  continue;
1031  }
1032  if (rhs.isa<AffineConstantExpr>()) {
1033  // For product/modulo/division expressions, when rhs of modulo/division
1034  // expression is constant, we put 0 in place of keyB, because we want
1035  // them to appear earlier in the semi-affine expression we are
1036  // constructing. When rhs is constant, we place 0 in place of keyB.
1037  if (lhs.isa<AffineDimExpr>()) {
1038  lhsPos = lhs.cast<AffineDimExpr>().getPosition();
1039  std::pair<unsigned, signed> indexEntry(lhsPos, -1);
1040  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1041  expr);
1042  } else {
1043  lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
1044  std::pair<unsigned, signed> indexEntry(lhsPos,
1045  std::max(numDims, numSymbols));
1046  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1047  expr);
1048  }
1049  } else if (lhs.isa<AffineDimExpr>()) {
1050  // For product/modulo/division expressions having lhs as dimension and rhs
1051  // as symbol, we order the terms in the semi-affine expression based on
1052  // the pair: <keyA, keyB> for expressions of the form dimension * symbol,
1053  // where keyA is the position number of the dimension and keyB is the
1054  // position number of the symbol.
1055  lhsPos = lhs.cast<AffineDimExpr>().getPosition();
1056  rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
1057  std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1058  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1059  } else {
1060  // For product/modulo/division expressions having both lhs and rhs as
1061  // symbol, we design indices as a pair: <keyA, keyB> for expressions
1062  // of the form dimension * symbol, where keyA is the position number of
1063  // the dimension and keyB is the position number of the symbol.
1064  lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
1065  rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
1066  std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1067  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1068  }
1069  addedToMap[it.index()] = true;
1070  }
1071 
1072  // Constructing the simplified semi-affine sum of product/division/mod
1073  // expression from the flattened form in the desired sorted order of indices
1074  // of the various individual product/division/mod expressions.
1075  llvm::sort(indices);
1076  for (const std::pair<unsigned, unsigned> index : indices) {
1077  assert(indexToExprMap.lookup(index) &&
1078  "cannot find key in `indexToExprMap` map");
1079  expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index);
1080  }
1081 
1082  // Local identifiers.
1083  for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1084  j++) {
1085  // If the coefficient of the local expression is 0, continue as we need not
1086  // add it in out final expression.
1087  if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols])
1088  continue;
1089  auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1090  expr = expr + term;
1091  }
1092 
1093  // Constant term.
1094  int64_t constTerm = flatExprs.back();
1095  if (constTerm != 0)
1096  expr = expr + constTerm;
1097  return expr;
1098 }
1099 
1101  unsigned numSymbols)
1102  : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
1103  operandExprStack.reserve(8);
1104 }
1105 
1106 // In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1107 //
1108 // In case of semi affine multiplication expressions, t = expr * symbolic_expr,
1109 // introduce a local variable p (= expr * symbolic_expr), and the affine
1110 // expression expr * symbolic_expr is added to `localExprs`.
1112  assert(operandExprStack.size() >= 2);
1114  operandExprStack.pop_back();
1116 
1117  // Flatten semi-affine multiplication expressions by introducing a local
1118  // variable in place of the product; the affine expression
1119  // corresponding to the quantifier is added to `localExprs`.
1120  if (!expr.getRHS().isa<AffineConstantExpr>()) {
1121  MLIRContext *context = expr.getContext();
1123  localExprs, context);
1125  localExprs, context);
1126  addLocalVariableSemiAffine(a * b, lhs, lhs.size());
1127  return;
1128  }
1129 
1130  // Get the RHS constant.
1131  auto rhsConst = rhs[getConstantIndex()];
1132  for (unsigned i = 0, e = lhs.size(); i < e; i++) {
1133  lhs[i] *= rhsConst;
1134  }
1135 }
1136 
1138  assert(operandExprStack.size() >= 2);
1139  const auto &rhs = operandExprStack.back();
1140  auto &lhs = operandExprStack[operandExprStack.size() - 2];
1141  assert(lhs.size() == rhs.size());
1142  // Update the LHS in place.
1143  for (unsigned i = 0, e = rhs.size(); i < e; i++) {
1144  lhs[i] += rhs[i];
1145  }
1146  // Pop off the RHS.
1147  operandExprStack.pop_back();
1148 }
1149 
1150 //
1151 // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
1152 //
1153 // A mod expression "expr mod c" is thus flattened by introducing a new local
1154 // variable q (= expr floordiv c), such that expr mod c is replaced with
1155 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
1156 //
1157 // In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
1158 // introduce a local variable m (= expr mod symbolic_expr), and the affine
1159 // expression expr mod symbolic_expr is added to `localExprs`.
1161  assert(operandExprStack.size() >= 2);
1162 
1164  operandExprStack.pop_back();
1166  MLIRContext *context = expr.getContext();
1167 
1168  // Flatten semi affine modulo expressions by introducing a local
1169  // variable in place of the modulo value, and the affine expression
1170  // corresponding to the quantifier is added to `localExprs`.
1171  if (!expr.getRHS().isa<AffineConstantExpr>()) {
1172  AffineExpr dividendExpr = getAffineExprFromFlatForm(
1173  lhs, numDims, numSymbols, localExprs, context);
1175  localExprs, context);
1176  AffineExpr modExpr = dividendExpr % divisorExpr;
1177  addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
1178  return;
1179  }
1180 
1181  int64_t rhsConst = rhs[getConstantIndex()];
1182  // TODO: handle modulo by zero case when this issue is fixed
1183  // at the other places in the IR.
1184  assert(rhsConst > 0 && "RHS constant has to be positive");
1185 
1186  // Check if the LHS expression is a multiple of modulo factor.
1187  unsigned i, e;
1188  for (i = 0, e = lhs.size(); i < e; i++)
1189  if (lhs[i] % rhsConst != 0)
1190  break;
1191  // If yes, modulo expression here simplifies to zero.
1192  if (i == lhs.size()) {
1193  std::fill(lhs.begin(), lhs.end(), 0);
1194  return;
1195  }
1196 
1197  // Add a local variable for the quotient, i.e., expr % c is replaced by
1198  // (expr - q * c) where q = expr floordiv c. Do this while canceling out
1199  // the GCD of expr and c.
1200  SmallVector<int64_t, 8> floorDividend(lhs);
1201  uint64_t gcd = rhsConst;
1202  for (unsigned i = 0, e = lhs.size(); i < e; i++)
1203  gcd = std::gcd(gcd, (uint64_t)std::abs(lhs[i]));
1204  // Simplify the numerator and the denominator.
1205  if (gcd != 1) {
1206  for (unsigned i = 0, e = floorDividend.size(); i < e; i++)
1207  floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd);
1208  }
1209  int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
1210 
1211  // Construct the AffineExpr form of the floordiv to store in localExprs.
1212 
1213  AffineExpr dividendExpr = getAffineExprFromFlatForm(
1214  floorDividend, numDims, numSymbols, localExprs, context);
1215  AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context);
1216  AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr);
1217  int loc;
1218  if ((loc = findLocalId(floorDivExpr)) == -1) {
1219  addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
1220  // Set result at top of stack to "lhs - rhsConst * q".
1221  lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
1222  } else {
1223  // Reuse the existing local id.
1224  lhs[getLocalVarStartIndex() + loc] = -rhsConst;
1225  }
1226 }
1227 
1229  visitDivExpr(expr, /*isCeil=*/true);
1230 }
1232  visitDivExpr(expr, /*isCeil=*/false);
1233 }
1234 
1236  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1237  auto &eq = operandExprStack.back();
1238  assert(expr.getPosition() < numDims && "Inconsistent number of dims");
1239  eq[getDimStartIndex() + expr.getPosition()] = 1;
1240 }
1241 
1243  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1244  auto &eq = operandExprStack.back();
1245  assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1246  eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1247 }
1248 
1250  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1251  auto &eq = operandExprStack.back();
1252  eq[getConstantIndex()] = expr.getValue();
1253 }
1254 
1255 void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1256  AffineExpr expr, SmallVectorImpl<int64_t> &result,
1257  unsigned long resultSize) {
1258  assert(result.size() == resultSize &&
1259  "`result` vector passed is not of correct size");
1260  int loc;
1261  if ((loc = findLocalId(expr)) == -1)
1262  addLocalIdSemiAffine(expr);
1263  std::fill(result.begin(), result.end(), 0);
1264  if (loc == -1)
1265  result[getLocalVarStartIndex() + numLocals - 1] = 1;
1266  else
1267  result[getLocalVarStartIndex() + loc] = 1;
1268 }
1269 
1270 // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
1271 // A floordiv is thus flattened by introducing a new local variable q, and
1272 // replacing that expression with 'q' while adding the constraints
1273 // c * q <= expr <= c * q + c - 1 to localVarCst (done by
1274 // FlatAffineConstraints::addLocalFloorDiv).
1275 //
1276 // A ceildiv is similarly flattened:
1277 // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
1278 //
1279 // In case of semi affine division expressions, t = expr floordiv symbolic_expr
1280 // or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
1281 // floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
1282 // `localExprs`.
1283 void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1284  bool isCeil) {
1285  assert(operandExprStack.size() >= 2);
1286 
1287  MLIRContext *context = expr.getContext();
1289  operandExprStack.pop_back();
1291 
1292  // Flatten semi affine division expressions by introducing a local
1293  // variable in place of the quotient, and the affine expression corresponding
1294  // to the quantifier is added to `localExprs`.
1295  if (!expr.getRHS().isa<AffineConstantExpr>()) {
1297  localExprs, context);
1299  localExprs, context);
1300  AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1301  addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
1302  return;
1303  }
1304 
1305  // This is a pure affine expr; the RHS is a positive constant.
1306  int64_t rhsConst = rhs[getConstantIndex()];
1307  // TODO: handle division by zero at the same time the issue is
1308  // fixed at other places.
1309  assert(rhsConst > 0 && "RHS constant has to be positive");
1310 
1311  // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1312  // common divisors of the numerator and denominator.
1313  uint64_t gcd = std::abs(rhsConst);
1314  for (unsigned i = 0, e = lhs.size(); i < e; i++)
1315  gcd = std::gcd(gcd, (uint64_t)std::abs(lhs[i]));
1316  // Simplify the numerator and the denominator.
1317  if (gcd != 1) {
1318  for (unsigned i = 0, e = lhs.size(); i < e; i++)
1319  lhs[i] = lhs[i] / static_cast<int64_t>(gcd);
1320  }
1321  int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1322  // If the divisor becomes 1, the updated LHS is the result. (The
1323  // divisor can't be negative since rhsConst is positive).
1324  if (divisor == 1)
1325  return;
1326 
1327  // If the divisor cannot be simplified to one, we will have to retain
1328  // the ceil/floor expr (simplified up until here). Add an existential
1329  // quantifier to express its result, i.e., expr1 div expr2 is replaced
1330  // by a new identifier, q.
1331  AffineExpr a =
1333  AffineExpr b = getAffineConstantExpr(divisor, context);
1334 
1335  int loc;
1336  AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1337  if ((loc = findLocalId(divExpr)) == -1) {
1338  if (!isCeil) {
1339  SmallVector<int64_t, 8> dividend(lhs);
1340  addLocalFloorDivId(dividend, divisor, divExpr);
1341  } else {
1342  // lhs ceildiv c <=> (lhs + c - 1) floordiv c
1343  SmallVector<int64_t, 8> dividend(lhs);
1344  dividend.back() += divisor - 1;
1345  addLocalFloorDivId(dividend, divisor, divExpr);
1346  }
1347  }
1348  // Set the expression on stack to the local var introduced to capture the
1349  // result of the division (floor or ceil).
1350  std::fill(lhs.begin(), lhs.end(), 0);
1351  if (loc == -1)
1352  lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1353  else
1354  lhs[getLocalVarStartIndex() + loc] = 1;
1355 }
1356 
1357 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1358 // The local identifier added is always a floordiv of a pure add/mul affine
1359 // function of other identifiers, coefficients of which are specified in
1360 // dividend and with respect to a positive constant divisor. localExpr is the
1361 // simplified tree expression (AffineExpr) corresponding to the quantifier.
1363  int64_t divisor,
1364  AffineExpr localExpr) {
1365  assert(divisor > 0 && "positive constant divisor expected");
1366  for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1367  subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1368  localExprs.push_back(localExpr);
1369  numLocals++;
1370  // dividend and divisor are not used here; an override of this method uses it.
1371 }
1372 
1374  for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1375  subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1376  localExprs.push_back(localExpr);
1377  ++numLocals;
1378 }
1379 
1380 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1382  if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
1383  return -1;
1384  return it - localExprs.begin();
1385 }
1386 
1387 /// Simplify the affine expression by flattening it and reconstructing it.
1389  unsigned numSymbols) {
1390  // Simplify semi-affine expressions separately.
1391  if (!expr.isPureAffine())
1392  expr = simplifySemiAffine(expr);
1393 
1394  SimpleAffineExprFlattener flattener(numDims, numSymbols);
1395  flattener.walkPostOrder(expr);
1396  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1397  if (!expr.isPureAffine() &&
1398  expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1399  flattener.localExprs,
1400  expr.getContext()))
1401  return expr;
1402  AffineExpr simplifiedExpr =
1403  expr.isPureAffine()
1404  ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1405  flattener.localExprs, expr.getContext())
1406  : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1407  flattener.localExprs,
1408  expr.getContext());
1409 
1410  flattener.operandExprStack.pop_back();
1411  assert(flattener.operandExprStack.empty());
1412  return simplifiedExpr;
1413 }
Affine binary operation expression.
Definition: AffineExpr.h:207
Include the generated interface declarations.
static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs a semi-affine expression from a flat ArrayRef.
Definition: AffineExpr.cpp:933
StorageUniquer & getAffineUniquer()
Returns the storage uniquer used for creating affine constructs.
RHS of mod is always a constant or a symbolic expression with a positive value.
Base storage class appearing in an affine expression.
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
Definition: MPInt.h:370
AffineExpr replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements) const
This method substitutes any uses of dimensions and symbols (e.g.
Definition: AffineExpr.cpp:65
Storage * get(function_ref< void(Storage *)> initFn, TypeID id, Args &&...args)
Gets a uniqued instance of &#39;Storage&#39;.
AffineConstantExpr(AffineExpr::ImplType *ptr=nullptr)
Definition: AffineExpr.cpp:503
bool isPureAffine() const
Returns true if this is a pure affine expression, i.e., multiplication, floordiv, ceildiv...
Definition: AffineExpr.cpp:188
static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, AffineExprKind opKind)
Returns true if the expression is divisible by the given symbol with position symbolPos.
Definition: AffineExpr.cpp:322
AffineExpr compose(AffineMap map) const
Compose with an AffineMap.
Definition: AffineExpr.cpp:875
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
Definition: AffineExpr.cpp:279
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:513
int64_t getValue() const
Definition: AffineExpr.cpp:505
AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
Definition: AffineExpr.cpp:121
AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
Definition: AffineExpr.cpp:301
bool operator==(AffineExpr other) const
Definition: AffineExpr.h:76
AffineExpr shiftDims(unsigned numDims, unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineExpr.cpp:109
A binary operation appearing in an affine expression.
RetTy walkPostOrder(AffineExpr expr)
ImplType * expr
Definition: AffineExpr.h:198
AffineSymbolExpr(AffineExpr::ImplType *ptr)
Definition: AffineExpr.cpp:492
unsigned getPosition() const
Definition: AffineExpr.cpp:311
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
An integer constant appearing in affine expression.
Definition: AffineExpr.h:232
AffineExpr getAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs an affine expression from a flat ArrayRef.
Definition: AffineExpr.cpp:890
void walk(std::function< void(AffineExpr)> callback) const
Walk all of the AffineExpr&#39;s in this expression in postorder.
Definition: AffineExpr.cpp:29
void visitConstantExpr(AffineConstantExpr expr)
AffineExpr getRHS() const
Definition: AffineExpr.cpp:306
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
bool isMultipleOf(int64_t factor) const
Return true if the affine expression is a multiple of &#39;factor&#39;.
Definition: AffineExpr.cpp:246
Base class for AffineExpr visitors/walkers.
bool isSymbolicOrConstant() const
Returns true if this expression is made out of only symbols and constants, i.e., it does not involve ...
Definition: AffineExpr.cpp:164
static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos, AffineExprKind opKind)
Divides the given expression by the given symbol at position symbolPos.
Definition: AffineExpr.cpp:381
U dyn_cast() const
Definition: AffineExpr.h:281
AffineExpr operator*(int64_t v) const
Definition: AffineExpr.cpp:698
static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:820
AffineDimExpr(AffineExpr::ImplType *ptr)
Definition: AffineExpr.cpp:310
void visitSymbolExpr(AffineSymbolExpr expr)
static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:777
static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs)
Simplify a multiply expression. Return nullptr if it can&#39;t be simplified.
Definition: AffineExpr.cpp:651
AffineExpr getLHS() const
Definition: AffineExpr.cpp:303
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:498
virtual void addLocalIdSemiAffine(AffineExpr localExpr)
Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul expr) when the rhs is a symbo...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:232
Base type for affine expression.
Definition: AffineExpr.h:68
MLIRContext * getContext() const
Definition: AffineExpr.cpp:24
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR&#39;s mod operation on constants.
Definition: MathExtras.h:45
RHS of mul is always a constant or a symbolic expression.
void visitCeilDivExpr(AffineBinaryOpExpr expr)
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt gcd(const MPInt &a, const MPInt &b)
Definition: MPInt.h:399
virtual void addLocalFloorDivId(ArrayRef< int64_t > dividend, int64_t divisor, AffineExpr localExpr)
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
Definition: AffineExpr.cpp:219
A multi-dimensional affine map Affine map&#39;s are immutable like Type&#39;s, and they are uniqued...
Definition: AffineMap.h:42
U cast() const
Definition: AffineExpr.h:291
A utility class to get or create instances of "storage classes".
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:807
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:319
bool isFunctionOfSymbol(unsigned position) const
Return true if the affine expression involves AffineSymbolExpr position.
Definition: AffineExpr.cpp:290
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:488
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:764
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:46
RHS of ceildiv is always a constant or a symbolic expression.
bool isa() const
Definition: AffineExpr.h:270
unsigned getPosition() const
Definition: AffineExpr.cpp:494
void visitFloorDivExpr(AffineBinaryOpExpr expr)
AffineExpr replaceSymbols(ArrayRef< AffineExpr > symReplacements) const
Symbol-only version of replaceDimsAndSymbols.
Definition: AffineExpr.cpp:103
static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:721
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
Definition: AffineExpr.cpp:98
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
Definition: AffineExpr.cpp:157
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:26
A dimensional or symbolic identifier appearing in an affine expression.
static AffineExpr simplifySemiAffine(AffineExpr expr)
Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv operations when the second...
Definition: AffineExpr.cpp:438
void print(raw_ostream &os) const
Symbolic identifier.
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
void visitDimExpr(AffineDimExpr expr)
static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs)
Simplify add expression. Return nullptr if it can&#39;t be simplified.
Definition: AffineExpr.cpp:523
SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols)
AffineExprKind
Definition: AffineExpr.h:40
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
AffineExpr operator+(int64_t v) const
Definition: AffineExpr.cpp:638
An integer constant appearing in affine expression.
void visitMulExpr(AffineBinaryOpExpr expr)
void visitModExpr(AffineBinaryOpExpr expr)
static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:477
Dimensional identifier.
AffineExpr operator-() const
Definition: AffineExpr.cpp:711
AffineExpr operator%(uint64_t v) const
Definition: AffineExpr.cpp:863
void visitAddExpr(AffineBinaryOpExpr expr)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
std::vector< SmallVector< int64_t, 8 > > operandExprStack
A symbolic identifier appearing in an affine expression.
Definition: AffineExpr.h:224
SmallVector< AffineExpr, 4 > localExprs