MLIR  15.0.0git
AffineExpr.cpp
Go to the documentation of this file.
1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/IR/AffineExpr.h"
12 #include "AffineExprDetail.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/IntegerSet.h"
17 #include "mlir/Support/TypeID.h"
18 #include "llvm/ADT/STLExtras.h"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 
23 MLIRContext *AffineExpr::getContext() const { return expr->context; }
24 
25 AffineExprKind AffineExpr::getKind() const { return expr->kind; }
26 
27 /// Walk all of the AffineExprs in this subgraph in postorder.
28 void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
29  struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
30  std::function<void(AffineExpr)> callback;
31 
32  AffineExprWalker(std::function<void(AffineExpr)> callback)
33  : callback(std::move(callback)) {}
34 
35  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
36  void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
37  void visitDimExpr(AffineDimExpr expr) { callback(expr); }
38  void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
39  };
40 
41  AffineExprWalker(std::move(callback)).walkPostOrder(*this);
42 }
43 
44 // Dispatch affine expression construction based on kind.
// Builds `lhs <kind> rhs` by delegating to the matching entry point
// (operator+, operator*, floorDiv, ceilDiv, operator%), so the result is
// exactly what that entry point returns, including any simplification it
// performs. Only the five binary kinds are accepted; any other kind reaches
// llvm_unreachable.
46  AffineExpr rhs) {
47  if (kind == AffineExprKind::Add)
48  return lhs + rhs;
49  if (kind == AffineExprKind::Mul)
50  return lhs * rhs;
51  if (kind == AffineExprKind::FloorDiv)
52  return lhs.floorDiv(rhs);
53  if (kind == AffineExprKind::CeilDiv)
54  return lhs.ceilDiv(rhs);
55  if (kind == AffineExprKind::Mod)
56  return lhs % rhs;
57 
58  llvm_unreachable("unknown binary operation on affine expressions");
59 }
60 
61 /// This method substitutes any uses of dimensions and symbols (e.g.
62 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
/// Dims/symbols whose position is not covered by the corresponding
/// replacement array are left untouched. A binary node whose operands come
/// back unchanged is returned as-is; otherwise it is rebuilt through
/// getAffineBinaryOpExpr, so the rebuilt node may come out simplified.
65  ArrayRef<AffineExpr> symReplacements) const {
66  switch (getKind()) {
68  return *this;
69  case AffineExprKind::DimId: {
70  unsigned dimId = cast<AffineDimExpr>().getPosition();
// Out-of-range positions are kept rather than treated as an error.
71  if (dimId >= dimReplacements.size())
72  return *this;
73  return dimReplacements[dimId];
74  }
76  unsigned symId = cast<AffineSymbolExpr>().getPosition();
77  if (symId >= symReplacements.size())
78  return *this;
79  return symReplacements[symId];
80  }
86  auto binOp = cast<AffineBinaryOpExpr>();
87  auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
88  auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
89  auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
// Reuse the existing node when nothing below it changed.
90  if (newLHS == lhs && newRHS == rhs)
91  return *this;
92  return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
93  }
94  llvm_unreachable("Unknown AffineExpr");
95 }
96 
// Dims-only replacement: an empty symbol list keeps every symbol as-is.
98  return replaceDimsAndSymbols(dimReplacements, {});
99 }
100 
// Symbols-only replacement: an empty dim list keeps every dim as-is.
103  return replaceDimsAndSymbols({}, symReplacements);
104 }
105 
106 /// Replace dims[offset ... numDims)
107 /// by dims[offset + shift ... shift + numDims).
/// Dims below `offset` map to themselves. Dims at or beyond `numDims` are not
/// in the replacement list and are therefore left unchanged, as are all
/// symbols.
108 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
109  unsigned offset) const {
// Identity mapping for the prefix that must not move.
111  for (unsigned idx = 0; idx < offset; ++idx)
112  dims.push_back(getAffineDimExpr(idx, getContext()));
113  for (unsigned idx = offset; idx < numDims; ++idx)
114  dims.push_back(getAffineDimExpr(idx + shift, getContext()));
115  return replaceDimsAndSymbols(dims, {});
116 }
117 
118 /// Replace symbols[offset ... numSymbols)
119 /// by symbols[offset + shift ... shift + numSymbols).
/// Symbols below `offset` map to themselves. Symbols at or beyond
/// `numSymbols` are not in the replacement list and are therefore left
/// unchanged, as are all dims.
120 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
121  unsigned offset) const {
// Identity mapping for the prefix that must not move.
123  for (unsigned idx = 0; idx < offset; ++idx)
124  symbols.push_back(getAffineSymbolExpr(idx, getContext()));
125  for (unsigned idx = offset; idx < numSymbols; ++idx)
126  symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
127  return replaceDimsAndSymbols({}, symbols);
128 }
129 
130 /// Sparse replace method. Return the modified expression tree.
/// A match on a whole subtree wins: the map is consulted for this node before
/// recursing, and a replacement value is returned as-is (it is not itself
/// scanned again). Leaves without a mapping are returned unchanged.
133  auto it = map.find(*this);
134  if (it != map.end())
135  return it->second;
136  switch (getKind()) {
137  default:
138  return *this;
139  case AffineExprKind::Add:
140  case AffineExprKind::Mul:
143  case AffineExprKind::Mod:
144  auto binOp = cast<AffineBinaryOpExpr>();
145  auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
146  auto newLHS = lhs.replace(map);
147  auto newRHS = rhs.replace(map);
// Reuse the existing node when nothing below it changed.
148  if (newLHS == lhs && newRHS == rhs)
149  return *this;
150  return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
151  }
152  llvm_unreachable("Unknown AffineExpr");
153 }
154 
155 /// Sparse replace method. Return the modified expression tree.
/// Single-substitution convenience: inserts `expr -> replacement` into a
/// local map and delegates to the map-based overload.
158  map.insert(std::make_pair(expr, replacement));
159  return replace(map);
160 }
161 /// Returns true if this expression is made out of only symbols and
162 /// constants (no dimensional identifiers).
/// A binary expression qualifies only if both of its operands do.
164  switch (getKind()) {
166  return true;
168  return false;
170  return true;
171 
172  case AffineExprKind::Add:
173  case AffineExprKind::Mul:
176  case AffineExprKind::Mod: {
177  auto expr = this->cast<AffineBinaryOpExpr>();
178  return expr.getLHS().isSymbolicOrConstant() &&
179  expr.getRHS().isSymbolicOrConstant();
180  }
181  }
182  llvm_unreachable("Unknown AffineExpr");
183 }
184 
185 /// Returns true if this is a pure affine expression, i.e., multiplication,
186 /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
/// For multiplication either operand may be the constant; for floordiv,
/// ceildiv and mod the constant must be the right-hand side.
188  switch (getKind()) {
192  return true;
193  case AffineExprKind::Add: {
194  auto op = cast<AffineBinaryOpExpr>();
195  return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
196  }
197 
198  case AffineExprKind::Mul: {
199  // TODO: Canonicalize the constants in binary operators to the RHS when
200  // possible, allowing this to merge into the next case.
201  auto op = cast<AffineBinaryOpExpr>();
202  return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
203  (op.getLHS().template isa<AffineConstantExpr>() ||
204  op.getRHS().template isa<AffineConstantExpr>());
205  }
208  case AffineExprKind::Mod: {
209  auto op = cast<AffineBinaryOpExpr>();
210  return op.getLHS().isPureAffine() &&
211  op.getRHS().template isa<AffineConstantExpr>();
212  }
213  }
214  llvm_unreachable("Unknown AffineExpr");
215 }
216 
217 // Returns the greatest known integral divisor of this affine expression.
// This is a lower bound on the true largest divisor. A constant contributes
// |value|. A product's divisor is the product of its operands' divisors.
// For Add and Mod, the GCD of the operands' divisors divides the result.
// NOTE(review): std::abs(INT64_MIN) and the Mul product can overflow int64_t;
// confirm callers never reach those values.
219  AffineBinaryOpExpr binExpr(nullptr);
220  switch (getKind()) {
222  LLVM_FALLTHROUGH;
226  return 1;
228  return std::abs(this->cast<AffineConstantExpr>().getValue());
229  case AffineExprKind::Mul: {
230  binExpr = this->cast<AffineBinaryOpExpr>();
231  return binExpr.getLHS().getLargestKnownDivisor() *
232  binExpr.getRHS().getLargestKnownDivisor();
233  }
234  case AffineExprKind::Add:
235  LLVM_FALLTHROUGH;
236  case AffineExprKind::Mod: {
237  binExpr = cast<AffineBinaryOpExpr>();
238  return llvm::GreatestCommonDivisor64(
239  binExpr.getLHS().getLargestKnownDivisor(),
240  binExpr.getRHS().getLargestKnownDivisor());
241  }
242  }
243  llvm_unreachable("Unknown AffineExpr");
244 }
245 
/// Returns true if this expression is provably a multiple of `factor`.
/// The check is conservative: a false result means "not provably a multiple".
/// A lone dim or symbol is only a multiple of +1/-1.
/// NOTE(review): `factor == 0` divides by zero in the constant and binary
/// branches. In the binary branches the divisor is uint64_t, so a negative
/// `factor` is converted to a huge unsigned value, unlike the signed constant
/// branch. Confirm callers only pass positive factors.
246 bool AffineExpr::isMultipleOf(int64_t factor) const {
247  AffineBinaryOpExpr binExpr(nullptr);
248  uint64_t l, u;
249  switch (getKind()) {
251  LLVM_FALLTHROUGH;
253  return factor * factor == 1;
255  return cast<AffineConstantExpr>().getValue() % factor == 0;
256  case AffineExprKind::Mul: {
257  binExpr = cast<AffineBinaryOpExpr>();
258  // It's probably not worth optimizing this further (i.e. avoiding the
259  // traversal of the whole sub-tree under it); that would require a version
260  // of isMultipleOf that also returns the largest known divisor on 'false'.
// `u` is only read when the first two checks fail, so short-circuiting
// never reads it uninitialized.
261  return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
262  (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
263  (l * u) % factor == 0;
264  }
265  case AffineExprKind::Add:
268  case AffineExprKind::Mod: {
269  binExpr = cast<AffineBinaryOpExpr>();
270  return llvm::GreatestCommonDivisor64(
271  binExpr.getLHS().getLargestKnownDivisor(),
272  binExpr.getRHS().getLargestKnownDivisor()) %
273  factor ==
274  0;
275  }
276  }
277  llvm_unreachable("Unknown AffineExpr");
278 }
279 
280 bool AffineExpr::isFunctionOfDim(unsigned position) const {
281  if (getKind() == AffineExprKind::DimId) {
282  return *this == mlir::getAffineDimExpr(position, getContext());
283  }
284  if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
285  return expr.getLHS().isFunctionOfDim(position) ||
286  expr.getRHS().isFunctionOfDim(position);
287  }
288  return false;
289 }
290 
291 bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
292  if (getKind() == AffineExprKind::SymbolId) {
293  return *this == mlir::getAffineSymbolExpr(position, getContext());
294  }
295  if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
296  return expr.getLHS().isFunctionOfSymbol(position) ||
297  expr.getRHS().isFunctionOfSymbol(position);
298  }
299  return false;
300 }
301 
303  : AffineExpr(ptr) {}
// Operand accessors: read the lhs/rhs fields of the binary-op storage.
305  return static_cast<ImplType *>(expr)->lhs;
306 }
308  return static_cast<ImplType *>(expr)->rhs;
309 }
310 
312 unsigned AffineDimExpr::getPosition() const {
313  return static_cast<ImplType *>(expr)->position;
314 }
315 
316 /// Returns true if the expression is divisible by the given symbol with
317 /// position `symbolPos`. The argument `opKind` specifies here what kind of
318 /// division or mod operation called this division. It helps in implementing the
319 /// commutative property of the floordiv and ceildiv operations. If the argument
320 /// `opKind` is floordiv and `expr` is also a binary expression of a floordiv
321 /// operation, then the commutative property can be used; otherwise the floordiv
322 /// is not treated as divisible. The same argument holds for ceildiv operation.
323 static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
324  AffineExprKind opKind) {
325  // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
326  assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
327  opKind == AffineExprKind::CeilDiv) &&
328  "unexpected opKind");
329  switch (expr.getKind()) {
// Only the constant 0 is a multiple of an arbitrary symbol.
331  return expr.cast<AffineConstantExpr>().getValue() == 0;
333  return false;
335  return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos);
336  // Checks divisibility by the given symbol for both operands.
337  case AffineExprKind::Add: {
338  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
339  return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
340  isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
341  }
342  // Checks divisibility by the given symbol for both operands. Consider the
343  // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
344  // this is a division by s1 and both the operands of modulo are divisible by
345  // s1 but it is not divisible by s1 always. The third argument is
346  // `AffineExprKind::Mod` for this reason.
347  case AffineExprKind::Mod: {
348  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
349  return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
351  isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
353  }
354  // Checks if either operand is divisible by the given symbol.
355  case AffineExprKind::Mul: {
356  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
357  return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
358  isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
359  }
360  // Floordiv and ceildiv are divisible by the given symbol when the first
361  // operand is divisible, and the affine expression kind of the argument expr
362  // is same as the argument `opKind`. This can be inferred from commutative
363  // property of floordiv and ceildiv operations and are as follow:
364  // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
365  // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv exp2
366  // It will fail if operations are not same. For example:
367  // (exp1 ceildiv exp2) floordiv exp3 can not be simplified.
370  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
371  if (opKind != expr.getKind())
372  return false;
373  return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
374  }
375  }
376  llvm_unreachable("Unknown AffineExpr");
377 }
378 
379 /// Divides the given expression by the given symbol at position `symbolPos`. It
380 /// assumes the caller has already checked divisibility (isDivisibleBySymbol). A
381 /// null expression is returned whenever the divisibility condition fails.
/// NOTE(review): the Add, Mod, Mul and floordiv/ceildiv branches feed
/// recursive results straight into getAffineBinaryOpExpr / operator*; a null
/// sub-result from a non-divisible operand would be passed on as an operand,
/// so the divisibility precondition matters.
382 static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
383  AffineExprKind opKind) {
384  // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
385  assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
386  opKind == AffineExprKind::CeilDiv) &&
387  "unexpected opKind");
388  switch (expr.getKind()) {
390  if (expr.cast<AffineConstantExpr>().getValue() != 0)
391  return nullptr;
392  return getAffineConstantExpr(0, expr.getContext());
394  return nullptr;
// The matching symbol divided by itself is 1.
396  return getAffineConstantExpr(1, expr.getContext());
397  // Dividing both operands by the given symbol.
398  case AffineExprKind::Add: {
399  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
400  return getAffineBinaryOpExpr(
401  expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
402  symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
403  }
404  // Dividing both operands by the given symbol.
405  case AffineExprKind::Mod: {
406  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
407  return getAffineBinaryOpExpr(
408  expr.getKind(),
409  symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
410  symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
411  }
412  // Dividing only the operand that is divisible by the given symbol.
413  case AffineExprKind::Mul: {
414  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
415  if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
416  return binaryExpr.getLHS() *
417  symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
418  return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
419  binaryExpr.getRHS();
420  }
421  // Dividing first operand only by the given symbol.
424  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
425  return getAffineBinaryOpExpr(
426  expr.getKind(),
427  symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
428  binaryExpr.getRHS());
429  }
430  }
431  llvm_unreachable("Unknown AffineExpr");
432 }
433 
434 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
435 /// operations when the second operand simplifies to a symbol and the first
436 /// operand is divisible by that symbol. It can be applied to any semi-affine
437 /// expression. Returned expression can either be a semi-affine or pure affine
438 /// expression.
440  switch (expr.getKind()) {
444  return expr;
445  case AffineExprKind::Add:
446  case AffineExprKind::Mul: {
447  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
448  return getAffineBinaryOpExpr(expr.getKind(),
449  simplifySemiAffine(binaryExpr.getLHS()),
450  simplifySemiAffine(binaryExpr.getRHS()));
451  }
452  // Check if the simplification of the second operand is a symbol, and the
453  // first operand is divisible by it. If the operation is a modulo, a constant
454  // zero expression is returned. In the case of floordiv and ceildiv, the
455  // symbol from the simplification of the second operand divides the first
456  // operand. Otherwise, simplification is not possible.
459  case AffineExprKind::Mod: {
460  AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
461  AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS());
462  AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS());
// NOTE(review): this recomputes simplifySemiAffine on the RHS; `sRHS` above
// already holds the same value.
463  AffineSymbolExpr symbolExpr =
464  simplifySemiAffine(binaryExpr.getRHS()).dyn_cast<AffineSymbolExpr>();
465  if (!symbolExpr)
466  return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
467  unsigned symbolPos = symbolExpr.getPosition();
// NOTE(review): divisibility is checked on the *unsimplified* LHS, but the
// division below is applied to `sLHS`.
468  if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
469  return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
470  if (expr.getKind() == AffineExprKind::Mod)
471  return getAffineConstantExpr(0, expr.getContext());
// NOTE(review): symbolicDivide returns a null expression when its operand is
// not divisible, and that null is returned from here unchanged. Consider
// falling back to getAffineBinaryOpExpr(kind, sLHS, sRHS) in that case.
472  return symbolicDivide(sLHS, symbolPos, expr.getKind());
473  }
474  }
475  llvm_unreachable("Unknown AffineExpr");
476 }
477 
478 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
479  MLIRContext *context) {
480  auto assignCtx = [context](AffineDimExprStorage *storage) {
481  storage->context = context;
482  };
483 
484  StorageUniquer &uniquer = context->getAffineUniquer();
485  return uniquer.get<AffineDimExprStorage>(
486  assignCtx, static_cast<unsigned>(kind), position);
487 }
488 
489 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
490  return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
491 }
492 
494  : AffineExpr(ptr) {}
// Reads the position stored in this expression's storage object.
496  return static_cast<ImplType *>(expr)->position;
497 }
498 
499 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
500  return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
501  ;
502 }
503 
505  : AffineExpr(ptr) {}
// Reads the constant value stored in this expression's storage object.
507  return static_cast<ImplType *>(expr)->constant;
508 }
509 
510 bool AffineExpr::operator==(int64_t v) const {
511  return *this == getAffineConstantExpr(v, getContext());
512 }
513 
// Init callback handed to the uniquer; it records the owning context on the
// storage object.
515  auto assignCtx = [context](AffineConstantExprStorage *storage) {
516  storage->context = context;
517  };
518 
// Constants are uniqued by value within the context's affine uniquer.
519  StorageUniquer &uniquer = context->getAffineUniquer();
520  return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
521 }
522 
523 /// Simplify add expression. Return nullptr if it can't be simplified.
/// Several rules rebuild the result through operator+ / operator% /
/// getAffineBinaryOpExpr, so simplification can apply recursively.
525  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
526  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
527  // Fold if both LHS, RHS are a constant.
528  if (lhsConst && rhsConst)
529  return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
530  lhs.getContext());
531 
532  // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
533  // If only one of them is a symbolic expression, make it the RHS.
534  if (lhs.isa<AffineConstantExpr>() ||
535  (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
536  return rhs + lhs;
537  }
538 
539  // At this point, if there was a constant, it would be on the right.
540 
541  // Addition with a zero is a noop, return the other input.
542  if (rhsConst) {
543  if (rhsConst.getValue() == 0)
544  return lhs;
545  }
546  // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
547  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
548  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
549  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
550  return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
551  }
552 
553  // Detect "c1 * expr + c2 * expr" as "(c1 + c2) * expr".
554  // c1 is rLhsConst, c2 is rRhsConst; firstExpr, secondExpr are their
555  // respective multiplicands. A term without a constant factor counts as 1.
556  Optional<int64_t> rLhsConst, rRhsConst;
557  AffineExpr firstExpr, secondExpr;
558  AffineConstantExpr rLhsConstExpr;
559  auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>();
560  if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
561  (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
562  rLhsConst = rLhsConstExpr.getValue();
563  firstExpr = lBinOpExpr.getLHS();
564  } else {
565  rLhsConst = 1;
566  firstExpr = lhs;
567  }
568 
569  auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
570  AffineConstantExpr rRhsConstExpr;
571  if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
572  (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
573  rRhsConst = rRhsConstExpr.getValue();
574  secondExpr = rBinOpExpr.getLHS();
575  } else {
576  rRhsConst = 1;
577  secondExpr = rhs;
578  }
579 
// NOTE(review): both Optionals are assigned on every path above, so only the
// `firstExpr == secondExpr` part of this condition can actually fail.
580  if (rLhsConst && rRhsConst && firstExpr == secondExpr)
581  return getAffineBinaryOpExpr(
582  AffineExprKind::Mul, firstExpr,
583  getAffineConstantExpr(rLhsConst.getValue() + rRhsConst.getValue(),
584  lhs.getContext()));
585 
586  // When doing successive additions, bring constant to the right: turn (d0 + 2)
587  // + d1 into (d0 + d1) + 2.
588  if (lBin && lBin.getKind() == AffineExprKind::Add) {
589  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
590  return lBin.getLHS() + rhs + lrhs;
591  }
592  }
593 
594  // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
595  // q may be a constant or symbolic expression. This leads to a much more
596  // efficient form when 'q' is a power of two, and in general a more compact
597  // and readable form.
598 
599  // Process '(expr floordiv c) * (-c)'.
600  if (!rBinOpExpr)
601  return nullptr;
602 
603  auto lrhs = rBinOpExpr.getLHS();
604  auto rrhs = rBinOpExpr.getRHS();
605 
606  AffineExpr llrhs, rlrhs;
607 
608  // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
609  // symbolic expression.
610  auto lrhsBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
611  // Check rrhsConstOpExpr = -1.
612  auto rrhsConstOpExpr = rrhs.dyn_cast<AffineConstantExpr>();
613  if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
614  lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
615  // Check llrhs = expr floordiv q.
616  llrhs = lrhsBinOpExpr.getLHS();
617  // Check rlrhs = q.
618  rlrhs = lrhsBinOpExpr.getRHS();
619  auto llrhsBinOpExpr = llrhs.dyn_cast<AffineBinaryOpExpr>();
620  if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
621  return nullptr;
622  if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
623  return lhs % rlrhs;
624  }
625 
626  // Process lrhs, which is 'expr floordiv c'.
627  AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
628  if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
629  return nullptr;
630 
631  llrhs = lrBinOpExpr.getLHS();
632  rlrhs = lrBinOpExpr.getRHS();
633 
634  if (lhs == llrhs && rlrhs == -rrhs) {
635  return lhs % rlrhs;
636  }
637  return nullptr;
638 }
639 
// Adding an integer wraps it in a constant expression and reuses the
// AffineExpr overload.
641  return *this + getAffineConstantExpr(v, getContext());
642 }
// Try algebraic simplification first; only when simplifyAdd returns null is
// a uniqued Add node requested from the affine uniquer.
644  if (auto simplified = simplifyAdd(*this, other))
645  return simplified;
646 
648  return uniquer.get<AffineBinaryOpExprStorage>(
649  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
650 }
651 
652 /// Simplify a multiply expression. Return nullptr if it can't be simplified.
/// Asserts that at least one operand is symbolic or constant; a product of
/// two dim-dependent expressions trips the assert in debug builds.
654  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
655  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
656 
657  if (lhsConst && rhsConst)
658  return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
659  lhs.getContext());
660 
661  assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());
662 
663  // Canonicalize the mul expression so that the constant/symbolic term is the
664  // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
665  // constant. (Note that a constant is trivially symbolic).
666  if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) {
667  // At least one of them has to be symbolic.
668  return rhs * lhs;
669  }
670 
671  // At this point, if there was a constant, it would be on the right.
672 
673  // Multiplication with a one is a noop, return the other input.
674  if (rhsConst) {
675  if (rhsConst.getValue() == 1)
676  return lhs;
677  // Multiplication with zero.
678  if (rhsConst.getValue() == 0)
679  return rhsConst;
680  }
681 
682  // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
683  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
684  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
685  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
686  return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
687  }
688 
689  // When doing successive multiplication, bring constant to the right: turn (d0
690  // * 2) * s0 into (d0 * s0) * 2 (rhs is symbolic here, see above).
691  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
692  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
693  return (lBin.getLHS() * rhs) * lrhs;
694  }
695  }
696 
697  return nullptr;
698 }
699 
// Multiplying by an integer wraps it in a constant expression and reuses the
// AffineExpr overload.
701  return *this * getAffineConstantExpr(v, getContext());
702 }
// Try algebraic simplification first; only when simplifyMul returns null is
// a uniqued Mul node requested from the affine uniquer.
704  if (auto simplified = simplifyMul(*this, other))
705  return simplified;
706 
708  return uniquer.get<AffineBinaryOpExprStorage>(
709  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
710 }
711 
712 // Unary minus, delegate to operator*.
// -e is represented as e * -1, so it goes through simplifyMul.
714  return *this * getAffineConstantExpr(-1, getContext());
715 }
716 
717 // Delegate to operator+.
718 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
// a - b is built as a + (-b), where -b is b * -1 (see unary minus).
720  return *this + (-other);
721 }
722 
// Simplifies `lhs floordiv rhs`; returns nullptr when no rule applies. Only
// positive constant divisors are simplified.
724  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
725  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
726 
727  // mlir floordiv by zero or negative numbers is undefined and preserved as is.
728  if (!rhsConst || rhsConst.getValue() < 1)
729  return nullptr;
730 
731  if (lhsConst)
732  return getAffineConstantExpr(
733  floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
734 
735  // Division by one is a no-op.
736  // (Compares through AffineExpr::operator==(int64_t).)
737  if (rhsConst == 1)
738  return lhs;
739 
740  // Simplify (expr * const) floordiv divConst when const is a multiple of
741  // divConst. Eg: (i * 128) floordiv 64 = i * 2.
742  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
743  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
744  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
745  // rhsConst is known to be a positive constant.
746  if (lrhs.getValue() % rhsConst.getValue() == 0)
747  return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
748  }
749  }
750 
751  // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
752  // known to be a multiple of divConst.
753  if (lBin && lBin.getKind() == AffineExprKind::Add) {
754  int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
755  int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
756  // rhsConst is known to be a positive constant.
757  if (llhsDiv % rhsConst.getValue() == 0 ||
758  lrhsDiv % rhsConst.getValue() == 0)
759  return lBin.getLHS().floorDiv(rhsConst.getValue()) +
760  lBin.getRHS().floorDiv(rhsConst.getValue());
761  }
762 
763  return nullptr;
764 }
765 
// NOTE(review): the body of this overload is not visible in this listing;
// presumably it wraps `v` in a constant expression. Confirm how values above
// INT64_MAX are handled.
766 AffineExpr AffineExpr::floorDiv(uint64_t v) const {
768 }
// Try algebraic simplification first; only when simplifyFloorDiv returns null
// is a uniqued FloorDiv node requested from the affine uniquer.
770  if (auto simplified = simplifyFloorDiv(*this, other))
771  return simplified;
772 
774  return uniquer.get<AffineBinaryOpExprStorage>(
775  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
776  other);
777 }
778 
// Simplifies `lhs ceildiv rhs`; returns nullptr when no rule applies.
780  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
781  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
782 
// Non-constant, zero, or negative divisors are preserved as is.
783  if (!rhsConst || rhsConst.getValue() < 1)
784  return nullptr;
785 
786  if (lhsConst)
787  return getAffineConstantExpr(
788  ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
789 
790  // Division by one is a no-op.
791  //
792  if (rhsConst.getValue() == 1)
793  return lhs;
794 
795  // Simplify (expr * const) ceildiv divConst when const is known to be a
796  // multiple of divConst. Eg: (i * 128) ceildiv 64 = i * 2.
797  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
798  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
799  if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
800  // rhsConst is known to be a positive constant.
801  if (lrhs.getValue() % rhsConst.getValue() == 0)
802  return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
803  }
804  }
805 
806  return nullptr;
807 }
808 
// NOTE(review): the body of this overload is not visible in this listing;
// presumably it wraps `v` in a constant expression. Confirm how values above
// INT64_MAX are handled.
809 AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
811 }
// Try algebraic simplification first; only when simplifyCeilDiv returns null
// is a uniqued CeilDiv node requested from the affine uniquer.
813  if (auto simplified = simplifyCeilDiv(*this, other))
814  return simplified;
815 
817  return uniquer.get<AffineBinaryOpExprStorage>(
818  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
819  other);
820 }
821 
// Simplifies `lhs mod rhs`; returns nullptr when no rule applies. Only
// positive constant moduli are simplified.
823  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
824  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
825 
826  // mod w.r.t zero or negative numbers is undefined and preserved as is.
827  if (!rhsConst || rhsConst.getValue() < 1)
828  return nullptr;
829 
830  if (lhsConst)
831  return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
832  lhs.getContext());
833 
834  // Fold modulo of an expression that is known to be a multiple of a constant
835  // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
836  // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
837  if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
838  return getAffineConstantExpr(0, lhs.getContext());
839 
840  // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
841  // known to be a multiple of divConst.
842  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
843  if (lBin && lBin.getKind() == AffineExprKind::Add) {
844  int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
845  int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
846  // rhsConst is known to be a positive constant.
847  if (llhsDiv % rhsConst.getValue() == 0)
848  return lBin.getRHS() % rhsConst.getValue();
849  if (lrhsDiv % rhsConst.getValue() == 0)
850  return lBin.getLHS() % rhsConst.getValue();
851  }
852 
853  // Simplify (e % a) % b to e % b when b evenly divides a.
854  if (lBin && lBin.getKind() == AffineExprKind::Mod) {
855  auto intermediate = lBin.getRHS().dyn_cast<AffineConstantExpr>();
856  if (intermediate && intermediate.getValue() >= 1 &&
857  mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
858  return lBin.getLHS() % rhsConst.getValue();
859  }
860  }
861 
862  return nullptr;
863 }
864 
// Taking `mod v` wraps `v` in a constant expression and reuses the
// AffineExpr overload.
866  return *this % getAffineConstantExpr(v, getContext());
867 }
// Try algebraic simplification first; only when simplifyMod returns null is
// a uniqued Mod node requested from the affine uniquer.
869  if (auto simplified = simplifyMod(*this, other))
870  return simplified;
871 
873  return uniquer.get<AffineBinaryOpExprStorage>(
874  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
875 }
876 
// Dim i of this expression is replaced by result i of `map`; symbols are
// left unchanged (empty symbol replacement list).
878  SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
879  map.getResults().end());
880  return replaceDimsAndSymbols(dimReplacements, {});
881 }
882 raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
883  expr.print(os);
884  return os;
885 }
886 
887 /// Constructs an affine expression from a flat ArrayRef. If there are local
888 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
889 /// products expression, `localExprs` is expected to have the AffineExpr
890 /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
891 /// in the format [dims, symbols, locals, constant term].
/// For example, with numDims = 2, numSymbols = 1 and no locals, the flat form
/// [1, 0, 2, 5] describes d0 + 2 * s0 + 5. Zero coefficients are skipped.
/// NOTE(review): unlike the semi-affine variant below, there is no check that
/// `flatExprs` is non-empty; the unsigned `size() - ...` arithmetic would
/// wrap for an empty array.
893  unsigned numDims,
894  unsigned numSymbols,
895  ArrayRef<AffineExpr> localExprs,
896  MLIRContext *context) {
897  // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
898  assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
899  "unexpected number of local expressions");
900 
901  auto expr = getAffineConstantExpr(0, context);
902  // Dimensions and symbols.
903  for (unsigned j = 0; j < numDims + numSymbols; j++) {
904  if (flatExprs[j] == 0)
905  continue;
// Symbol positions are re-based to start at 0 after the dims.
906  auto id = j < numDims ? getAffineDimExpr(j, context)
907  : getAffineSymbolExpr(j - numDims, context);
908  expr = expr + id * flatExprs[j];
909  }
910 
911  // Local identifiers.
912  for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
913  j++) {
914  if (flatExprs[j] == 0)
915  continue;
916  auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
917  expr = expr + term;
918  }
919 
920  // Constant term.
921  int64_t constTerm = flatExprs[flatExprs.size() - 1];
922  if (constTerm != 0)
923  expr = expr + constTerm;
924  return expr;
925 }
926 
927 /// Constructs a semi-affine expression from a flat ArrayRef. If there are
928 /// local identifiers (neither dimensional nor symbolic) that appear in the sum
929 /// of products expression, `localExprs` is expected to have the AffineExprs for
930 /// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in
931 /// the format [dims, symbols, locals, constant term]. The semi-affine
932 /// expression is constructed in the sorted order of dimension and symbol
933 /// position numbers. Note: local expressions/ids are used for mod, div as well
934 /// as symbolic RHS terms for terms that are not pure affine.
936  unsigned numDims,
937  unsigned numSymbols,
938  ArrayRef<AffineExpr> localExprs,
939  MLIRContext *context) {
940  assert(!flatExprs.empty() && "flatExprs cannot be empty");
941 
942  // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
943  assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
944  "unexpected number of local expressions");
945 
946  AffineExpr expr = getAffineConstantExpr(0, context);
947 
948  // We design indices as a pair which help us present the semi-affine map as
949  // sum of product where terms are sorted based on dimension or symbol
950  // position: <keyA, keyB> for expressions of the form dimension * symbol,
951  // where keyA is the position number of the dimension and keyB is the
952  // position number of the symbol. For dimensional expressions we set the index
953  // as (position number of the dimension, -1), as we want dimensional
954  // expressions to appear before symbolic and product of dimensional and
955  // symbolic expressions having the dimension with the same position number.
956  // For symbolic expression set the index as (position number of the symbol,
957  // maximum of last dimension and symbol position) number. For example, we want
958  // the expression we are constructing to look something like: d0 + d0 * s0 +
959  // s0 + d1*s1 + s1.
960 
961  // Stores the affine expression corresponding to a given index.
963  // Stores the constant coefficient value corresponding to a given
964  // dimension, symbol or a non-pure affine expression stored in `localExprs`.
965  DenseMap<std::pair<unsigned, signed>, int64_t> coefficients;
966  // Stores the indices as defined above, and later sorted to produce
967  // the semi-affine expression in the desired form.
969 
970  // Example: expression = d0 + d0 * s0 + 2 * s0.
971  // indices = [{0,-1}, {0, 0}, {0, 1}]
972  // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
973  // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
974 
975  // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
976  auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
977  AffineExpr expr) {
978  assert(std::find(indices.begin(), indices.end(), index) == indices.end() &&
979  "Key is already present in indices vector and overwriting will "
980  "happen in `indexToExprMap` and `coefficients`!");
981 
982  indices.push_back(index);
983  coefficients.insert({index, coefficient});
984  indexToExprMap.insert({index, expr});
985  };
986 
987  // Design indices for dimensional or symbolic terms, and store the indices,
988  // constant coefficient corresponding to the indices in `coefficients` map,
989  // and affine expression corresponding to indices in `indexToExprMap` map.
990 
991  for (unsigned j = 0; j < numDims; ++j) {
992  if (flatExprs[j] == 0)
993  continue;
994  // For dimensional expressions we set the index as <position number of the
995  // dimension, 0>, as we want dimensional expressions to appear before
996  // symbolic ones and products of dimensional and symbolic expressions
997  // having the dimension with the same position number.
998  std::pair<unsigned, signed> indexEntry(j, -1);
999  addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context));
1000  }
1001  for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
1002  if (flatExprs[j] == 0)
1003  continue;
1004  // For symbolic expression set the index as <position number
1005  // of the symbol, max(dimCount, symCount)> number,
1006  // as we want symbolic expressions with the same positional number to
1007  // appear after dimensional expressions having the same positional number.
1008  std::pair<unsigned, signed> indexEntry(j - numDims,
1009  std::max(numDims, numSymbols));
1010  addEntry(indexEntry, flatExprs[j],
1011  getAffineSymbolExpr(j - numDims, context));
1012  }
1013 
1014  // Denotes semi-affine product, modulo or division terms, which has been added
1015  // to the `indexToExpr` map.
1016  SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1,
1017  false);
1018  unsigned lhsPos, rhsPos;
1019  // Construct indices for product terms involving dimension, symbol or constant
1020  // as lhs/rhs, and store the indices, constant coefficient corresponding to
1021  // the indices in `coefficients` map, and affine expression corresponding to
1022  // in indices in `indexToExprMap` map.
1023  for (const auto &it : llvm::enumerate(localExprs)) {
1024  AffineExpr expr = it.value();
1025  if (flatExprs[numDims + numSymbols + it.index()] == 0)
1026  continue;
1027  AffineExpr lhs = expr.cast<AffineBinaryOpExpr>().getLHS();
1028  AffineExpr rhs = expr.cast<AffineBinaryOpExpr>().getRHS();
1029  if (!((lhs.isa<AffineDimExpr>() || lhs.isa<AffineSymbolExpr>()) &&
1030  (rhs.isa<AffineDimExpr>() || rhs.isa<AffineSymbolExpr>() ||
1031  rhs.isa<AffineConstantExpr>()))) {
1032  continue;
1033  }
1034  if (rhs.isa<AffineConstantExpr>()) {
1035  // For product/modulo/division expressions, when rhs of modulo/division
1036  // expression is constant, we put 0 in place of keyB, because we want
1037  // them to appear earlier in the semi-affine expression we are
1038  // constructing. When rhs is constant, we place 0 in place of keyB.
1039  if (lhs.isa<AffineDimExpr>()) {
1040  lhsPos = lhs.cast<AffineDimExpr>().getPosition();
1041  std::pair<unsigned, signed> indexEntry(lhsPos, -1);
1042  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1043  expr);
1044  } else {
1045  lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
1046  std::pair<unsigned, signed> indexEntry(lhsPos,
1047  std::max(numDims, numSymbols));
1048  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1049  expr);
1050  }
1051  } else if (lhs.isa<AffineDimExpr>()) {
1052  // For product/modulo/division expressions having lhs as dimension and rhs
1053  // as symbol, we order the terms in the semi-affine expression based on
1054  // the pair: <keyA, keyB> for expressions of the form dimension * symbol,
1055  // where keyA is the position number of the dimension and keyB is the
1056  // position number of the symbol.
1057  lhsPos = lhs.cast<AffineDimExpr>().getPosition();
1058  rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
1059  std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1060  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1061  } else {
1062  // For product/modulo/division expressions having both lhs and rhs as
1063  // symbol, we design indices as a pair: <keyA, keyB> for expressions
1064  // of the form dimension * symbol, where keyA is the position number of
1065  // the dimension and keyB is the position number of the symbol.
1066  lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
1067  rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
1068  std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1069  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1070  }
1071  addedToMap[it.index()] = true;
1072  }
1073 
1074  // Constructing the simplified semi-affine sum of product/division/mod
1075  // expression from the flattened form in the desired sorted order of indices
1076  // of the various individual product/division/mod expressions.
1077  std::sort(indices.begin(), indices.end());
1078  for (const std::pair<unsigned, unsigned> index : indices) {
1079  assert(indexToExprMap.lookup(index) &&
1080  "cannot find key in `indexToExprMap` map");
1081  expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index);
1082  }
1083 
1084  // Local identifiers.
1085  for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1086  j++) {
1087  // If the coefficient of the local expression is 0, continue as we need not
1088  // add it in out final expression.
1089  if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols])
1090  continue;
1091  auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1092  expr = expr + term;
1093  }
1094 
1095  // Constant term.
1096  int64_t constTerm = flatExprs.back();
1097  if (constTerm != 0)
1098  expr = expr + constTerm;
1099  return expr;
1100 }
1101 
1103  unsigned numSymbols)
1104  : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
1105  operandExprStack.reserve(8);
1106 }
1107 
1108 // In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1109 //
1110 // In case of semi affine multiplication expressions, t = expr * symbolic_expr,
1111 // introduce a local variable p (= expr * symbolic_expr), and the affine
1112 // expression expr * symbolic_expr is added to `localExprs`.
1114  assert(operandExprStack.size() >= 2);
1116  operandExprStack.pop_back();
1118 
1119  // Flatten semi-affine multiplication expressions by introducing a local
1120  // variable in place of the product; the affine expression
1121  // corresponding to the quantifier is added to `localExprs`.
1122  if (!expr.getRHS().isa<AffineConstantExpr>()) {
1123  MLIRContext *context = expr.getContext();
1125  localExprs, context);
1127  localExprs, context);
1128  addLocalVariableSemiAffine(a * b, lhs, lhs.size());
1129  return;
1130  }
1131 
1132  // Get the RHS constant.
1133  auto rhsConst = rhs[getConstantIndex()];
1134  for (unsigned i = 0, e = lhs.size(); i < e; i++) {
1135  lhs[i] *= rhsConst;
1136  }
1137 }
1138 
1140  assert(operandExprStack.size() >= 2);
1141  const auto &rhs = operandExprStack.back();
1142  auto &lhs = operandExprStack[operandExprStack.size() - 2];
1143  assert(lhs.size() == rhs.size());
1144  // Update the LHS in place.
1145  for (unsigned i = 0, e = rhs.size(); i < e; i++) {
1146  lhs[i] += rhs[i];
1147  }
1148  // Pop off the RHS.
1149  operandExprStack.pop_back();
1150 }
1151 
1152 //
1153 // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
1154 //
1155 // A mod expression "expr mod c" is thus flattened by introducing a new local
1156 // variable q (= expr floordiv c), such that expr mod c is replaced with
1157 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
1158 //
1159 // In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
1160 // introduce a local variable m (= expr mod symbolic_expr), and the affine
1161 // expression expr mod symbolic_expr is added to `localExprs`.
1163  assert(operandExprStack.size() >= 2);
1164 
1166  operandExprStack.pop_back();
1168  MLIRContext *context = expr.getContext();
1169 
1170  // Flatten semi affine modulo expressions by introducing a local
1171  // variable in place of the modulo value, and the affine expression
1172  // corresponding to the quantifier is added to `localExprs`.
1173  if (!expr.getRHS().isa<AffineConstantExpr>()) {
1174  AffineExpr dividendExpr = getAffineExprFromFlatForm(
1175  lhs, numDims, numSymbols, localExprs, context);
1177  localExprs, context);
1178  AffineExpr modExpr = dividendExpr % divisorExpr;
1179  addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
1180  return;
1181  }
1182 
1183  int64_t rhsConst = rhs[getConstantIndex()];
1184  // TODO: handle modulo by zero case when this issue is fixed
1185  // at the other places in the IR.
1186  assert(rhsConst > 0 && "RHS constant has to be positive");
1187 
1188  // Check if the LHS expression is a multiple of modulo factor.
1189  unsigned i, e;
1190  for (i = 0, e = lhs.size(); i < e; i++)
1191  if (lhs[i] % rhsConst != 0)
1192  break;
1193  // If yes, modulo expression here simplifies to zero.
1194  if (i == lhs.size()) {
1195  std::fill(lhs.begin(), lhs.end(), 0);
1196  return;
1197  }
1198 
1199  // Add a local variable for the quotient, i.e., expr % c is replaced by
1200  // (expr - q * c) where q = expr floordiv c. Do this while canceling out
1201  // the GCD of expr and c.
1202  SmallVector<int64_t, 8> floorDividend(lhs);
1203  uint64_t gcd = rhsConst;
1204  for (unsigned i = 0, e = lhs.size(); i < e; i++)
1205  gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
1206  // Simplify the numerator and the denominator.
1207  if (gcd != 1) {
1208  for (unsigned i = 0, e = floorDividend.size(); i < e; i++)
1209  floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd);
1210  }
1211  int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
1212 
1213  // Construct the AffineExpr form of the floordiv to store in localExprs.
1214 
1215  AffineExpr dividendExpr = getAffineExprFromFlatForm(
1216  floorDividend, numDims, numSymbols, localExprs, context);
1217  AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context);
1218  AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr);
1219  int loc;
1220  if ((loc = findLocalId(floorDivExpr)) == -1) {
1221  addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
1222  // Set result at top of stack to "lhs - rhsConst * q".
1223  lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
1224  } else {
1225  // Reuse the existing local id.
1226  lhs[getLocalVarStartIndex() + loc] = -rhsConst;
1227  }
1228 }
1229 
1231  visitDivExpr(expr, /*isCeil=*/true);
1232 }
1234  visitDivExpr(expr, /*isCeil=*/false);
1235 }
1236 
1238  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1239  auto &eq = operandExprStack.back();
1240  assert(expr.getPosition() < numDims && "Inconsistent number of dims");
1241  eq[getDimStartIndex() + expr.getPosition()] = 1;
1242 }
1243 
1245  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1246  auto &eq = operandExprStack.back();
1247  assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1248  eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1249 }
1250 
1252  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1253  auto &eq = operandExprStack.back();
1254  eq[getConstantIndex()] = expr.getValue();
1255 }
1256 
1257 void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1258  AffineExpr expr, SmallVectorImpl<int64_t> &result,
1259  unsigned long resultSize) {
1260  assert(result.size() == resultSize &&
1261  "`result` vector passed is not of correct size");
1262  int loc;
1263  if ((loc = findLocalId(expr)) == -1)
1264  addLocalIdSemiAffine(expr);
1265  std::fill(result.begin(), result.end(), 0);
1266  if (loc == -1)
1267  result[getLocalVarStartIndex() + numLocals - 1] = 1;
1268  else
1269  result[getLocalVarStartIndex() + loc] = 1;
1270 }
1271 
1272 // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
1273 // A floordiv is thus flattened by introducing a new local variable q, and
1274 // replacing that expression with 'q' while adding the constraints
1275 // c * q <= expr <= c * q + c - 1 to localVarCst (done by
1276 // FlatAffineConstraints::addLocalFloorDiv).
1277 //
1278 // A ceildiv is similarly flattened:
1279 // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
1280 //
1281 // In case of semi affine division expressions, t = expr floordiv symbolic_expr
1282 // or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
1283 // floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
1284 // `localExprs`.
1285 void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1286  bool isCeil) {
1287  assert(operandExprStack.size() >= 2);
1288 
1289  MLIRContext *context = expr.getContext();
1291  operandExprStack.pop_back();
1293 
1294  // Flatten semi affine division expressions by introducing a local
1295  // variable in place of the quotient, and the affine expression corresponding
1296  // to the quantifier is added to `localExprs`.
1297  if (!expr.getRHS().isa<AffineConstantExpr>()) {
1299  localExprs, context);
1301  localExprs, context);
1302  AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1303  addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
1304  return;
1305  }
1306 
1307  // This is a pure affine expr; the RHS is a positive constant.
1308  int64_t rhsConst = rhs[getConstantIndex()];
1309  // TODO: handle division by zero at the same time the issue is
1310  // fixed at other places.
1311  assert(rhsConst > 0 && "RHS constant has to be positive");
1312 
1313  // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1314  // common divisors of the numerator and denominator.
1315  uint64_t gcd = std::abs(rhsConst);
1316  for (unsigned i = 0, e = lhs.size(); i < e; i++)
1317  gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
1318  // Simplify the numerator and the denominator.
1319  if (gcd != 1) {
1320  for (unsigned i = 0, e = lhs.size(); i < e; i++)
1321  lhs[i] = lhs[i] / static_cast<int64_t>(gcd);
1322  }
1323  int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1324  // If the divisor becomes 1, the updated LHS is the result. (The
1325  // divisor can't be negative since rhsConst is positive).
1326  if (divisor == 1)
1327  return;
1328 
1329  // If the divisor cannot be simplified to one, we will have to retain
1330  // the ceil/floor expr (simplified up until here). Add an existential
1331  // quantifier to express its result, i.e., expr1 div expr2 is replaced
1332  // by a new identifier, q.
1333  AffineExpr a =
1335  AffineExpr b = getAffineConstantExpr(divisor, context);
1336 
1337  int loc;
1338  AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1339  if ((loc = findLocalId(divExpr)) == -1) {
1340  if (!isCeil) {
1341  SmallVector<int64_t, 8> dividend(lhs);
1342  addLocalFloorDivId(dividend, divisor, divExpr);
1343  } else {
1344  // lhs ceildiv c <=> (lhs + c - 1) floordiv c
1345  SmallVector<int64_t, 8> dividend(lhs);
1346  dividend.back() += divisor - 1;
1347  addLocalFloorDivId(dividend, divisor, divExpr);
1348  }
1349  }
1350  // Set the expression on stack to the local var introduced to capture the
1351  // result of the division (floor or ceil).
1352  std::fill(lhs.begin(), lhs.end(), 0);
1353  if (loc == -1)
1354  lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1355  else
1356  lhs[getLocalVarStartIndex() + loc] = 1;
1357 }
1358 
1359 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1360 // The local identifier added is always a floordiv of a pure add/mul affine
1361 // function of other identifiers, coefficients of which are specified in
1362 // dividend and with respect to a positive constant divisor. localExpr is the
1363 // simplified tree expression (AffineExpr) corresponding to the quantifier.
1365  int64_t divisor,
1366  AffineExpr localExpr) {
1367  assert(divisor > 0 && "positive constant divisor expected");
1368  for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1369  subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1370  localExprs.push_back(localExpr);
1371  numLocals++;
1372  // dividend and divisor are not used here; an override of this method uses it.
1373 }
1374 
1376  for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1377  subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1378  localExprs.push_back(localExpr);
1379  ++numLocals;
1380 }
1381 
1382 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1384  if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
1385  return -1;
1386  return it - localExprs.begin();
1387 }
1388 
1389 /// Simplify the affine expression by flattening it and reconstructing it.
1391  unsigned numSymbols) {
1392  // Simplify semi-affine expressions separately.
1393  if (!expr.isPureAffine())
1394  expr = simplifySemiAffine(expr);
1395 
1396  SimpleAffineExprFlattener flattener(numDims, numSymbols);
1397  flattener.walkPostOrder(expr);
1398  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1399  if (!expr.isPureAffine() &&
1400  expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1401  flattener.localExprs,
1402  expr.getContext()))
1403  return expr;
1404  AffineExpr simplifiedExpr =
1405  expr.isPureAffine()
1406  ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1407  flattener.localExprs, expr.getContext())
1408  : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1409  flattener.localExprs,
1410  expr.getContext());
1411 
1412  flattener.operandExprStack.pop_back();
1413  assert(flattener.operandExprStack.empty());
1414  return simplifiedExpr;
1415 }
Affine binary operation expression.
Definition: AffineExpr.h:207
Include the generated interface declarations.
static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs a semi-affine expression from a flat ArrayRef.
Definition: AffineExpr.cpp:935
StorageUniquer & getAffineUniquer()
Returns the storage uniquer used for creating affine constructs.
RHS of mod is always a constant or a symbolic expression with a positive value.
Base storage class appearing in an affine expression.
AffineExpr replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements) const
This method substitutes any uses of dimensions and symbols (e.g.
Definition: AffineExpr.cpp:64
Storage * get(function_ref< void(Storage *)> initFn, TypeID id, Args &&...args)
Gets a uniqued instance of 'Storage'.
AffineConstantExpr(AffineExpr::ImplType *ptr=nullptr)
Definition: AffineExpr.cpp:504
bool isPureAffine() const
Returns true if this is a pure affine expression, i.e., multiplication, floordiv, ceildiv...
Definition: AffineExpr.cpp:187
static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, AffineExprKind opKind)
Returns true if the expression is divisible by the given symbol with position symbolPos.
Definition: AffineExpr.cpp:323
AffineExpr compose(AffineMap map) const
Compose with an AffineMap.
Definition: AffineExpr.cpp:877
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
Definition: AffineExpr.cpp:280
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:514
int64_t getValue() const
Definition: AffineExpr.cpp:506
AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
Definition: AffineExpr.cpp:120
AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
Definition: AffineExpr.cpp:302
bool operator==(AffineExpr other) const
Definition: AffineExpr.h:76
AffineExpr shiftDims(unsigned numDims, unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineExpr.cpp:108
A binary operation appearing in an affine expression.
RetTy walkPostOrder(AffineExpr expr)
ImplType * expr
Definition: AffineExpr.h:198
AffineSymbolExpr(AffineExpr::ImplType *ptr)
Definition: AffineExpr.cpp:493
unsigned getPosition() const
Definition: AffineExpr.cpp:312
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
An integer constant appearing in affine expression.
Definition: AffineExpr.h:232
AffineExpr getAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs an affine expression from a flat ArrayRef.
Definition: AffineExpr.cpp:892
void walk(std::function< void(AffineExpr)> callback) const
Walk all of the AffineExprs in this expression in postorder.
Definition: AffineExpr.cpp:28
void visitConstantExpr(AffineConstantExpr expr)
AffineExpr getRHS() const
Definition: AffineExpr.cpp:307
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
bool isMultipleOf(int64_t factor) const
Return true if the affine expression is a multiple of 'factor'.
Definition: AffineExpr.cpp:246
Base class for AffineExpr visitors/walkers.
bool isSymbolicOrConstant() const
Returns true if this expression is made out of only symbols and constants, i.e., it does not involve ...
Definition: AffineExpr.cpp:163
static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos, AffineExprKind opKind)
Divides the given expression by the given symbol at position symbolPos.
Definition: AffineExpr.cpp:382
U dyn_cast() const
Definition: AffineExpr.h:281
AffineExpr operator*(int64_t v) const
Definition: AffineExpr.cpp:700
static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:822
AffineDimExpr(AffineExpr::ImplType *ptr)
Definition: AffineExpr.cpp:311
void visitSymbolExpr(AffineSymbolExpr expr)
static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:779
static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs)
Simplify a multiply expression. Return nullptr if it can't be simplified.
Definition: AffineExpr.cpp:653
AffineExpr getLHS() const
Definition: AffineExpr.cpp:304
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:499
virtual void addLocalIdSemiAffine(AffineExpr localExpr)
Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul expr) when the rhs is a symbo...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:234
Base type for affine expression.
Definition: AffineExpr.h:68
MLIRContext * getContext() const
Definition: AffineExpr.cpp:23
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
Definition: MathExtras.h:45
RHS of mul is always a constant or a symbolic expression.
void visitCeilDivExpr(AffineBinaryOpExpr expr)
virtual void addLocalFloorDivId(ArrayRef< int64_t > dividend, int64_t divisor, AffineExpr localExpr)
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
Definition: AffineExpr.cpp:218
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued...
Definition: AffineMap.h:41
U cast() const
Definition: AffineExpr.h:291
A utility class to get or create instances of "storage classes".
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:809
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:307
bool isFunctionOfSymbol(unsigned position) const
Return true if the affine expression involves AffineSymbolExpr position.
Definition: AffineExpr.cpp:291
Eliminates identifier at the specified position using Fourier-Motzkin variable elimination.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:489
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:766
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:45
RHS of ceildiv is always a constant or a symbolic expression.
bool isa() const
Definition: AffineExpr.h:270
unsigned getPosition() const
Definition: AffineExpr.cpp:495
void visitFloorDivExpr(AffineBinaryOpExpr expr)
AffineExpr replaceSymbols(ArrayRef< AffineExpr > symReplacements) const
Symbol-only version of replaceDimsAndSymbols.
Definition: AffineExpr.cpp:102
static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:723
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
Definition: AffineExpr.cpp:97
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
Definition: AffineExpr.cpp:156
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:25
A dimensional or symbolic identifier appearing in an affine expression.
static AffineExpr simplifySemiAffine(AffineExpr expr)
Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv operations when the second...
Definition: AffineExpr.cpp:439
void print(raw_ostream &os) const
Symbolic identifier.
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
void visitDimExpr(AffineDimExpr expr)
static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs)
Simplify add expression. Return nullptr if it can't be simplified.
Definition: AffineExpr.cpp:524
SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols)
AffineExprKind
Definition: AffineExpr.h:40
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
AffineExpr operator+(int64_t v) const
Definition: AffineExpr.cpp:640
An integer constant appearing in affine expression.
void visitMulExpr(AffineBinaryOpExpr expr)
void visitModExpr(AffineBinaryOpExpr expr)
static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:478
Dimensional identifier.
AffineExpr operator-() const
Definition: AffineExpr.cpp:713
AffineExpr operator%(uint64_t v) const
Definition: AffineExpr.cpp:865
void visitAddExpr(AffineBinaryOpExpr expr)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
std::vector< SmallVector< int64_t, 8 > > operandExprStack
A symbolic identifier appearing in an affine expression.
Definition: AffineExpr.h:224
SmallVector< AffineExpr, 4 > localExprs