MLIR  22.0.0git
AffineExpr.cpp
Go to the documentation of this file.
1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <cmath>
10 #include <cstdint>
11 #include <utility>
12 
13 #include "AffineExprDetail.h"
14 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/IntegerSet.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/Support/MathExtras.h"
20 #include <numeric>
21 #include <optional>
22 
23 using namespace mlir;
24 using namespace mlir::detail;
25 
26 using llvm::divideCeilSigned;
27 using llvm::divideFloorSigned;
28 using llvm::divideSignedWouldOverflow;
29 using llvm::mod;
30 
31 MLIRContext *AffineExpr::getContext() const { return expr->context; }
32 
33 AffineExprKind AffineExpr::getKind() const { return expr->kind; }
34 
35 /// Walk all of the AffineExprs in `e` in postorder. This is a private factory
36 /// method to help handle lambda walk functions. Users should use the regular
37 /// (non-static) `walk` method.
38 template <typename WalkRetTy>
40  function_ref<WalkRetTy(AffineExpr)> callback) {
41  struct AffineExprWalker
42  : public AffineExprVisitor<AffineExprWalker, WalkRetTy> {
43  function_ref<WalkRetTy(AffineExpr)> callback;
44 
45  AffineExprWalker(function_ref<WalkRetTy(AffineExpr)> callback)
46  : callback(callback) {}
47 
48  WalkRetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
49  return callback(expr);
50  }
51  WalkRetTy visitConstantExpr(AffineConstantExpr expr) {
52  return callback(expr);
53  }
54  WalkRetTy visitDimExpr(AffineDimExpr expr) { return callback(expr); }
55  WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) { return callback(expr); }
56  };
57 
58  return AffineExprWalker(callback).walkPostOrder(e);
59 }
60 // Explicitly instantiate for the two supported return types.
61 template void mlir::AffineExpr::walk(AffineExpr e,
62  function_ref<void(AffineExpr)> callback);
63 template WalkResult
66 
67 // Dispatch affine expression construction based on kind.
69  AffineExpr rhs) {
71  return lhs + rhs;
73  return lhs * rhs;
75  return lhs.floorDiv(rhs);
77  return lhs.ceilDiv(rhs);
79  return lhs % rhs;
80 
81  llvm_unreachable("unknown binary operation on affine expressions");
82 }
83 
84 /// This method substitutes any uses of dimensions and symbols (e.g.
85 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
88  ArrayRef<AffineExpr> symReplacements) const {
89  switch (getKind()) {
91  return *this;
92  case AffineExprKind::DimId: {
93  unsigned dimId = llvm::cast<AffineDimExpr>(*this).getPosition();
94  if (dimId >= dimReplacements.size())
95  return *this;
96  return dimReplacements[dimId];
97  }
99  unsigned symId = llvm::cast<AffineSymbolExpr>(*this).getPosition();
100  if (symId >= symReplacements.size())
101  return *this;
102  return symReplacements[symId];
103  }
104  case AffineExprKind::Add:
105  case AffineExprKind::Mul:
108  case AffineExprKind::Mod:
109  auto binOp = llvm::cast<AffineBinaryOpExpr>(*this);
110  auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
111  auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
112  auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
113  if (newLHS == lhs && newRHS == rhs)
114  return *this;
115  return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
116  }
117  llvm_unreachable("Unknown AffineExpr");
118 }
119 
121  return replaceDimsAndSymbols(dimReplacements, {});
122 }
123 
126  return replaceDimsAndSymbols({}, symReplacements);
127 }
128 
129 /// Replace dims[offset ... numDims)
130 /// by dims[offset + shift ... shift + numDims).
131 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
132  unsigned offset) const {
134  for (unsigned idx = 0; idx < offset; ++idx)
135  dims.push_back(getAffineDimExpr(idx, getContext()));
136  for (unsigned idx = offset; idx < numDims; ++idx)
137  dims.push_back(getAffineDimExpr(idx + shift, getContext()));
138  return replaceDimsAndSymbols(dims, {});
139 }
140 
141 /// Replace symbols[offset ... numSymbols)
142 /// by symbols[offset + shift ... shift + numSymbols).
143 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
144  unsigned offset) const {
146  for (unsigned idx = 0; idx < offset; ++idx)
147  symbols.push_back(getAffineSymbolExpr(idx, getContext()));
148  for (unsigned idx = offset; idx < numSymbols; ++idx)
149  symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
150  return replaceDimsAndSymbols({}, symbols);
151 }
152 
153 /// Sparse replace method. Return the modified expression tree.
156  auto it = map.find(*this);
157  if (it != map.end())
158  return it->second;
159  switch (getKind()) {
160  default:
161  return *this;
162  case AffineExprKind::Add:
163  case AffineExprKind::Mul:
166  case AffineExprKind::Mod:
167  auto binOp = llvm::cast<AffineBinaryOpExpr>(*this);
168  auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
169  auto newLHS = lhs.replace(map);
170  auto newRHS = rhs.replace(map);
171  if (newLHS == lhs && newRHS == rhs)
172  return *this;
173  return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
174  }
175  llvm_unreachable("Unknown AffineExpr");
176 }
177 
178 /// Sparse replace method. Return the modified expression tree.
181  map.insert(std::make_pair(expr, replacement));
182  return replace(map);
183 }
184 /// Returns true if this expression is made out of only symbols and
185 /// constants (no dimensional identifiers).
187  switch (getKind()) {
189  return true;
191  return false;
193  return true;
194 
195  case AffineExprKind::Add:
196  case AffineExprKind::Mul:
199  case AffineExprKind::Mod: {
200  auto expr = llvm::cast<AffineBinaryOpExpr>(*this);
201  return expr.getLHS().isSymbolicOrConstant() &&
202  expr.getRHS().isSymbolicOrConstant();
203  }
204  }
205  llvm_unreachable("Unknown AffineExpr");
206 }
207 
208 /// Returns true if this is a pure affine expression, i.e., multiplication,
209 /// floordiv, ceildiv, and mod are only allowed w.r.t. constants.
211  switch (getKind()) {
215  return true;
216  case AffineExprKind::Add: {
217  auto op = llvm::cast<AffineBinaryOpExpr>(*this);
218  return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
219  }
220 
221  case AffineExprKind::Mul: {
222  // TODO: Canonicalize the constants in binary operators to the RHS when
223  // possible, allowing this to merge into the next case.
224  auto op = llvm::cast<AffineBinaryOpExpr>(*this);
225  return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
226  (llvm::isa<AffineConstantExpr>(op.getLHS()) ||
227  llvm::isa<AffineConstantExpr>(op.getRHS()));
228  }
231  case AffineExprKind::Mod: {
232  auto op = llvm::cast<AffineBinaryOpExpr>(*this);
233  return op.getLHS().isPureAffine() &&
234  llvm::isa<AffineConstantExpr>(op.getRHS());
235  }
236  }
237  llvm_unreachable("Unknown AffineExpr");
238 }
239 
240 // Returns the greatest known integral divisor of this affine expression.
242  AffineBinaryOpExpr binExpr(nullptr);
243  switch (getKind()) {
245  [[fallthrough]];
247  return 1;
249  [[fallthrough]];
251  // If the RHS is a constant and divides the known divisor on the LHS, the
252  // quotient is a known divisor of the expression.
253  binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
254  auto rhs = llvm::dyn_cast<AffineConstantExpr>(binExpr.getRHS());
255  // Leave alone undefined expressions.
256  if (rhs && rhs.getValue() != 0) {
257  int64_t lhsDiv = binExpr.getLHS().getLargestKnownDivisor();
258  if (lhsDiv % rhs.getValue() == 0)
259  return std::abs(lhsDiv / rhs.getValue());
260  }
261  return 1;
262  }
264  return std::abs(llvm::cast<AffineConstantExpr>(*this).getValue());
265  case AffineExprKind::Mul: {
266  binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
267  return binExpr.getLHS().getLargestKnownDivisor() *
268  binExpr.getRHS().getLargestKnownDivisor();
269  }
270  case AffineExprKind::Add:
271  [[fallthrough]];
272  case AffineExprKind::Mod: {
273  binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
274  return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
275  (uint64_t)binExpr.getRHS().getLargestKnownDivisor());
276  }
277  }
278  llvm_unreachable("Unknown AffineExpr");
279 }
280 
281 bool AffineExpr::isMultipleOf(int64_t factor) const {
282  AffineBinaryOpExpr binExpr(nullptr);
283  uint64_t l, u;
284  switch (getKind()) {
286  [[fallthrough]];
288  return factor * factor == 1;
290  return llvm::cast<AffineConstantExpr>(*this).getValue() % factor == 0;
291  case AffineExprKind::Mul: {
292  binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
293  // It's probably not worth optimizing this further (to not traverse the
294  // whole sub-tree under - it that would require a version of isMultipleOf
295  // that on a 'false' return also returns the largest known divisor).
296  return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
297  (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
298  (l * u) % factor == 0;
299  }
300  case AffineExprKind::Add:
303  case AffineExprKind::Mod: {
304  binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
305  return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
306  (uint64_t)binExpr.getRHS().getLargestKnownDivisor()) %
307  factor ==
308  0;
309  }
310  }
311  llvm_unreachable("Unknown AffineExpr");
312 }
313 
314 bool AffineExpr::isFunctionOfDim(unsigned position) const {
315  if (getKind() == AffineExprKind::DimId) {
316  return *this == mlir::getAffineDimExpr(position, getContext());
317  }
318  if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(*this)) {
319  return expr.getLHS().isFunctionOfDim(position) ||
320  expr.getRHS().isFunctionOfDim(position);
321  }
322  return false;
323 }
324 
325 bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
326  if (getKind() == AffineExprKind::SymbolId) {
327  return *this == mlir::getAffineSymbolExpr(position, getContext());
328  }
329  if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(*this)) {
330  return expr.getLHS().isFunctionOfSymbol(position) ||
331  expr.getRHS().isFunctionOfSymbol(position);
332  }
333  return false;
334 }
335 
337  : AffineExpr(ptr) {}
339  return static_cast<ImplType *>(expr)->lhs;
340 }
342  return static_cast<ImplType *>(expr)->rhs;
343 }
344 
346 unsigned AffineDimExpr::getPosition() const {
347  return static_cast<ImplType *>(expr)->position;
348 }
349 
350 /// Returns true if the expression is divisible by the given symbol with
351 /// position `symbolPos`. The argument `opKind` specifies what kind of
352 /// division or mod operation called this division. It helps in implementing the
353 /// commutative property of the floordiv and ceildiv operations. If the argument
354 /// `opKind` is floordiv and `expr` is also a binary expression of a floordiv
355 /// operation, then the commutative property can be used; otherwise, the floordiv
356 /// operation is not divisible. The same argument holds for the ceildiv operation.
357 static bool canSimplifyDivisionBySymbol(AffineExpr expr, unsigned symbolPos,
358  AffineExprKind opKind,
359  bool fromMul = false) {
360  // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
361  assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
362  opKind == AffineExprKind::CeilDiv) &&
363  "unexpected opKind");
364  switch (expr.getKind()) {
366  return cast<AffineConstantExpr>(expr).getValue() == 0;
368  return false;
370  return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos);
371  // Checks divisibility by the given symbol for both operands.
372  case AffineExprKind::Add: {
373  AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
374  return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
375  opKind) &&
376  canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
377  }
378  // Checks divisibility by the given symbol for both operands. Consider the
379  // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
380  // this is a division by s1 and both the operands of modulo are divisible by
381  // s1 but it is not divisible by s1 always. The third argument is
382  // `AffineExprKind::Mod` for this reason.
383  case AffineExprKind::Mod: {
384  AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
385  return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
387  canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos,
389  }
390  // Checks if any of the operand divisible by the given symbol.
391  case AffineExprKind::Mul: {
392  AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
393  return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind,
394  true) ||
395  canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind,
396  true);
397  }
398  // Floordiv and ceildiv are divisible by the given symbol when the first
399  // operand is divisible, and the affine expression kind of the argument expr
400  // is same as the argument `opKind`. This can be inferred from commutative
401  // property of floordiv and ceildiv operations and are as follow:
402  // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
403  // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
404  // It will fail 1.if operations are not same. For example:
405  // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 2.if there is a
406  // multiplication operation in the expression. For example:
407  // (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified.
410  AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
411  if (opKind != expr.getKind())
412  return false;
413  if (fromMul)
414  return false;
415  return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
416  expr.getKind());
417  }
418  }
419  llvm_unreachable("Unknown AffineExpr");
420 }
421 
422 /// Divides the given expression by the given symbol at position `symbolPos`. It
423 /// assumes the divisibility condition has been checked before it is called. A
424 /// null expression is returned whenever the divisibility condition fails.
425 static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
426  AffineExprKind opKind) {
427  // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
428  assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
429  opKind == AffineExprKind::CeilDiv) &&
430  "unexpected opKind");
431  switch (expr.getKind()) {
433  if (cast<AffineConstantExpr>(expr).getValue() != 0)
434  return nullptr;
435  return getAffineConstantExpr(0, expr.getContext());
437  return nullptr;
439  return getAffineConstantExpr(1, expr.getContext());
440  // Dividing both operands by the given symbol.
441  case AffineExprKind::Add: {
442  AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
443  return getAffineBinaryOpExpr(
444  expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
445  symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
446  }
447  // Dividing both operands by the given symbol.
448  case AffineExprKind::Mod: {
449  AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
450  return getAffineBinaryOpExpr(
451  expr.getKind(),
452  symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
453  symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
454  }
455  // Dividing any of the operand by the given symbol.
456  case AffineExprKind::Mul: {
457  AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
458  if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
459  return binaryExpr.getLHS() *
460  symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
461  return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
462  binaryExpr.getRHS();
463  }
464  // Dividing first operand only by the given symbol.
467  AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
468  return getAffineBinaryOpExpr(
469  expr.getKind(),
470  symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
471  binaryExpr.getRHS());
472  }
473  }
474  llvm_unreachable("Unknown AffineExpr");
475 }
476 
477 /// Populate `result` with all summand operands of given (potentially nested)
478 /// addition. If the given expression is not an addition, just populate the
479 /// expression itself.
480 /// Example: Add(Add(7, 8), Mul(9, 10)) will return [7, 8, Mul(9, 10)].
482  auto addExpr = dyn_cast<AffineBinaryOpExpr>(expr);
483  if (!addExpr || addExpr.getKind() != AffineExprKind::Add) {
484  result.push_back(expr);
485  return;
486  }
487  getSummandExprs(addExpr.getLHS(), result);
488  getSummandExprs(addExpr.getRHS(), result);
489 }
490 
491 /// Return "true" if `candidate` is a negated expression, i.e., Mul(-1, expr).
492 /// If so, also return the non-negated expression via `expr`.
493 static bool isNegatedAffineExpr(AffineExpr candidate, AffineExpr &expr) {
494  auto mulExpr = dyn_cast<AffineBinaryOpExpr>(candidate);
495  if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
496  return false;
497  if (auto lhs = dyn_cast<AffineConstantExpr>(mulExpr.getLHS())) {
498  if (lhs.getValue() == -1) {
499  expr = mulExpr.getRHS();
500  return true;
501  }
502  }
503  if (auto rhs = dyn_cast<AffineConstantExpr>(mulExpr.getRHS())) {
504  if (rhs.getValue() == -1) {
505  expr = mulExpr.getLHS();
506  return true;
507  }
508  }
509  return false;
510 }
511 
512 /// Return "true" if `lhs` % `rhs` is guaranteed to evaluate to zero based on
513 /// the fact that `lhs` contains another modulo expression that ensures that
514 /// `lhs` is divisible by `rhs`. This is a common pattern in the resulting IR
515 /// after loop peeling.
516 ///
517 /// Example: lhs = ub - ub % step
518 /// rhs = step
519 /// => (ub - ub % step) % step is guaranteed to evaluate to 0.
521  unsigned numDims, unsigned numSymbols) {
522  // TODO: Try to unify this function with `getBoundForAffineExpr`.
523  // Collect all summands in lhs.
524  SmallVector<AffineExpr> summands;
525  getSummandExprs(lhs, summands);
526  // Look for Mul(-1, Mod(x, rhs)) among the summands. If x matches the
527  // remaining summands, then lhs % rhs is guaranteed to evaluate to 0.
528  for (int64_t i = 0, e = summands.size(); i < e; ++i) {
529  AffineExpr current = summands[i];
530  AffineExpr beforeNegation;
531  if (!isNegatedAffineExpr(current, beforeNegation))
532  continue;
533  AffineBinaryOpExpr innerMod = dyn_cast<AffineBinaryOpExpr>(beforeNegation);
534  if (!innerMod || innerMod.getKind() != AffineExprKind::Mod)
535  continue;
536  if (innerMod.getRHS() != rhs)
537  continue;
538  // Sum all remaining summands and subtract x. If that expression can be
539  // simplified to zero, then the remaining summands and x are equal.
541  for (int64_t j = 0; j < e; ++j)
542  if (i != j)
543  diff = diff + summands[j];
544  diff = diff - innerMod.getLHS();
545  diff = simplifyAffineExpr(diff, numDims, numSymbols);
546  auto constExpr = dyn_cast<AffineConstantExpr>(diff);
547  if (constExpr && constExpr.getValue() == 0)
548  return true;
549  }
550  return false;
551 }
552 
553 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
554 /// operations when the second operand simplifies to a symbol and the first
555 /// operand is divisible by that symbol. It can be applied to any semi-affine
556 /// expression. Returned expression can either be a semi-affine or pure affine
557 /// expression.
558 static AffineExpr simplifySemiAffine(AffineExpr expr, unsigned numDims,
559  unsigned numSymbols) {
560  switch (expr.getKind()) {
564  return expr;
565  case AffineExprKind::Add:
566  case AffineExprKind::Mul: {
567  AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
568  return getAffineBinaryOpExpr(
569  expr.getKind(),
570  simplifySemiAffine(binaryExpr.getLHS(), numDims, numSymbols),
571  simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols));
572  }
573  // Check if the simplification of the second operand is a symbol, and the
574  // first operand is divisible by it. If the operation is a modulo, a constant
575  // zero expression is returned. In the case of floordiv and ceildiv, the
576  // symbol from the simplification of the second operand divides the first
577  // operand. Otherwise, simplification is not possible.
580  case AffineExprKind::Mod: {
581  AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
582  AffineExpr sLHS =
583  simplifySemiAffine(binaryExpr.getLHS(), numDims, numSymbols);
584  AffineExpr sRHS =
585  simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols);
586  if (isModOfModSubtraction(sLHS, sRHS, numDims, numSymbols))
587  return getAffineConstantExpr(0, expr.getContext());
588  AffineSymbolExpr symbolExpr = dyn_cast<AffineSymbolExpr>(
589  simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols));
590  if (!symbolExpr)
591  return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
592  unsigned symbolPos = symbolExpr.getPosition();
593  if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
594  expr.getKind()))
595  return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
596  if (expr.getKind() == AffineExprKind::Mod)
597  return getAffineConstantExpr(0, expr.getContext());
598  AffineExpr simplifiedQuotient =
599  symbolicDivide(sLHS, symbolPos, expr.getKind());
600  return simplifiedQuotient
601  ? simplifiedQuotient
602  : getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
603  }
604  }
605  llvm_unreachable("Unknown AffineExpr");
606 }
607 
609  MLIRContext *context) {
610  auto assignCtx = [context](AffineDimExprStorage *storage) {
611  storage->context = context;
612  };
613 
614  StorageUniquer &uniquer = context->getAffineUniquer();
615  return uniquer.get<AffineDimExprStorage>(
616  assignCtx, static_cast<unsigned>(kind), position);
617 }
618 
619 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
620  return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
621 }
622 
624  : AffineExpr(ptr) {}
626  return static_cast<ImplType *>(expr)->position;
627 }
628 
629 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
630  return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
631 }
632 
634  : AffineExpr(ptr) {}
636  return static_cast<ImplType *>(expr)->constant;
637 }
638 
639 bool AffineExpr::operator==(int64_t v) const {
640  return *this == getAffineConstantExpr(v, getContext());
641 }
642 
644  auto assignCtx = [context](AffineConstantExprStorage *storage) {
645  storage->context = context;
646  };
647 
648  StorageUniquer &uniquer = context->getAffineUniquer();
649  return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
650 }
651 
654  MLIRContext *context) {
655  return llvm::to_vector(llvm::map_range(constants, [&](int64_t constant) {
656  return getAffineConstantExpr(constant, context);
657  }));
658 }
659 
660 /// Simplify add expression. Return nullptr if it can't be simplified.
662  auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
663  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
664  // Fold if both LHS, RHS are a constant and the sum does not overflow.
665  if (lhsConst && rhsConst) {
666  int64_t sum;
667  if (llvm::AddOverflow(lhsConst.getValue(), rhsConst.getValue(), sum)) {
668  return nullptr;
669  }
670  return getAffineConstantExpr(sum, lhs.getContext());
671  }
672 
673  // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
674  // If only one of them is a symbolic expressions, make it the RHS.
675  if (isa<AffineConstantExpr>(lhs) ||
676  (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
677  return rhs + lhs;
678  }
679 
680  // At this point, if there was a constant, it would be on the right.
681 
682  // Addition with a zero is a noop, return the other input.
683  if (rhsConst) {
684  if (rhsConst.getValue() == 0)
685  return lhs;
686  }
687  // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
688  auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
689  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
690  if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS()))
691  return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
692  }
693 
694  // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
695  // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
696  // respective multiplicands.
697  std::optional<int64_t> rLhsConst, rRhsConst;
698  AffineExpr firstExpr, secondExpr;
699  AffineConstantExpr rLhsConstExpr;
700  auto lBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lhs);
701  if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
702  (rLhsConstExpr = dyn_cast<AffineConstantExpr>(lBinOpExpr.getRHS()))) {
703  rLhsConst = rLhsConstExpr.getValue();
704  firstExpr = lBinOpExpr.getLHS();
705  } else {
706  rLhsConst = 1;
707  firstExpr = lhs;
708  }
709 
710  auto rBinOpExpr = dyn_cast<AffineBinaryOpExpr>(rhs);
711  AffineConstantExpr rRhsConstExpr;
712  if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
713  (rRhsConstExpr = dyn_cast<AffineConstantExpr>(rBinOpExpr.getRHS()))) {
714  rRhsConst = rRhsConstExpr.getValue();
715  secondExpr = rBinOpExpr.getLHS();
716  } else {
717  rRhsConst = 1;
718  secondExpr = rhs;
719  }
720 
721  if (rLhsConst && rRhsConst && firstExpr == secondExpr)
722  return getAffineBinaryOpExpr(
723  AffineExprKind::Mul, firstExpr,
724  getAffineConstantExpr(*rLhsConst + *rRhsConst, lhs.getContext()));
725 
726  // When doing successive additions, bring constant to the right: turn (d0 + 2)
727  // + d1 into (d0 + d1) + 2.
728  if (lBin && lBin.getKind() == AffineExprKind::Add) {
729  if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
730  return lBin.getLHS() + rhs + lrhs;
731  }
732  }
733 
734  // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
735  // q may be a constant or symbolic expression. This leads to a much more
736  // efficient form when 'c' is a power of two, and in general a more compact
737  // and readable form.
738 
739  // Process '(expr floordiv c) * (-c)'.
740  if (!rBinOpExpr)
741  return nullptr;
742 
743  auto lrhs = rBinOpExpr.getLHS();
744  auto rrhs = rBinOpExpr.getRHS();
745 
746  AffineExpr llrhs, rlrhs;
747 
748  // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
749  // symbolic expression.
750  auto lrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lrhs);
751  // Check rrhsConstOpExpr = -1.
752  auto rrhsConstOpExpr = dyn_cast<AffineConstantExpr>(rrhs);
753  if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
754  lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
755  // Check llrhs = expr floordiv q.
756  llrhs = lrhsBinOpExpr.getLHS();
757  // Check rlrhs = q.
758  rlrhs = lrhsBinOpExpr.getRHS();
759  auto llrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(llrhs);
760  if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
761  return nullptr;
762  if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
763  return lhs % rlrhs;
764  }
765 
766  // Process lrhs, which is 'expr floordiv c'.
767  // expr + (expr // c * -c) = expr % c
768  AffineBinaryOpExpr lrBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lrhs);
769  if (!lrBinOpExpr || rhs.getKind() != AffineExprKind::Mul ||
770  lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
771  return nullptr;
772 
773  llrhs = lrBinOpExpr.getLHS();
774  rlrhs = lrBinOpExpr.getRHS();
775  auto rlrhsConstOpExpr = dyn_cast<AffineConstantExpr>(rlrhs);
776  // We don't support modulo with a negative RHS.
777  bool isPositiveRhs = rlrhsConstOpExpr && rlrhsConstOpExpr.getValue() > 0;
778 
779  if (isPositiveRhs && lhs == llrhs && rlrhs == -rrhs) {
780  return lhs % rlrhs;
781  }
782 
783  // Try simplify lhs's last operand with rhs. e.g:
784  // (s0 * 64 + s1) + (s1 // c * -c) --->
785  // s0 * 64 + (s1 + s1 // c * -c) -->
786  // s0 * 64 + s1 % c
787  if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Add) {
788  if (auto simplified = simplifyAdd(lBinOpExpr.getRHS(), rhs))
789  return lBinOpExpr.getLHS() + simplified;
790  }
791  return nullptr;
792 }
793 
794 /// Get the canonical order of the two arguments of a commutative expression.
795 static std::pair<AffineExpr, AffineExpr>
797  auto sym1 = dyn_cast<AffineSymbolExpr>(expr1);
798  auto sym2 = dyn_cast<AffineSymbolExpr>(expr2);
799  // Try to order by symbol/dim position first.
800  if (sym1 && sym2)
801  return sym1.getPosition() < sym2.getPosition() ? std::pair{expr1, expr2}
802  : std::pair{expr2, expr1};
803 
804  auto dim1 = dyn_cast<AffineDimExpr>(expr1);
805  auto dim2 = dyn_cast<AffineDimExpr>(expr2);
806  if (dim1 && dim2)
807  return dim1.getPosition() < dim2.getPosition() ? std::pair{expr1, expr2}
808  : std::pair{expr2, expr1};
809 
810  // Put dims before symbols.
811  if (dim1 && sym2)
812  return {dim1, sym2};
813 
814  if (sym1 && dim2)
815  return {dim2, sym1};
816 
817  // Otherwise, keep original order.
818  return {expr1, expr2};
819 }
820 
822  return *this + getAffineConstantExpr(v, getContext());
823 }
825  if (auto simplified = simplifyAdd(*this, other))
826  return simplified;
827 
828  auto [lhs, rhs] = orderCommutativeArgs(*this, other);
829 
831  return uniquer.get<AffineBinaryOpExprStorage>(
832  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), lhs, rhs);
833 }
834 
835 /// Simplify a multiply expression. Return nullptr if it can't be simplified.
837  auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
838  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
839 
840  if (lhsConst && rhsConst) {
841  int64_t product;
842  if (llvm::MulOverflow(lhsConst.getValue(), rhsConst.getValue(), product)) {
843  return nullptr;
844  }
846  }
847 
848  if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())
849  return nullptr;
850 
851  // Canonicalize the mul expression so that the constant/symbolic term is the
852  // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
853  // constant. (Note that a constant is trivially symbolic).
854  if (!rhs.isSymbolicOrConstant() || isa<AffineConstantExpr>(lhs)) {
855  // At least one of them has to be symbolic.
856  return rhs * lhs;
857  }
858 
859  // At this point, if there was a constant, it would be on the right.
860 
861  // Multiplication with a one is a noop, return the other input.
862  if (rhsConst) {
863  if (rhsConst.getValue() == 1)
864  return lhs;
865  // Multiplication with zero.
866  if (rhsConst.getValue() == 0)
867  return rhsConst;
868  }
869 
870  // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
871  auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
872  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
873  if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS()))
874  return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
875  }
876 
877  // When doing successive multiplication, bring constant to the right: turn (d0
878  // * 2) * d1 into (d0 * d1) * 2.
879  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
880  if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
881  return (lBin.getLHS() * rhs) * lrhs;
882  }
883  }
884 
885  return nullptr;
886 }
887 
889  return *this * getAffineConstantExpr(v, getContext());
890 }
892  if (auto simplified = simplifyMul(*this, other))
893  return simplified;
894 
895  auto [lhs, rhs] = orderCommutativeArgs(*this, other);
896 
898  return uniquer.get<AffineBinaryOpExprStorage>(
899  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), lhs, rhs);
900 }
901 
902 // Unary minus, delegate to operator*.
904  return *this * getAffineConstantExpr(-1, getContext());
905 }
906 
907 // Delegate to operator+.
908 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
910  return *this + (-other);
911 }
912 
914  auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
915  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
916 
917  if (!rhsConst || rhsConst.getValue() == 0)
918  return nullptr;
919 
920  if (lhsConst) {
921  if (divideSignedWouldOverflow(lhsConst.getValue(), rhsConst.getValue()))
922  return nullptr;
923  return getAffineConstantExpr(
924  divideFloorSigned(lhsConst.getValue(), rhsConst.getValue()),
925  lhs.getContext());
926  }
927 
928  // Fold floordiv of a multiply with a constant that is a multiple of the
929  // divisor. Eg: (i * 128) floordiv 64 = i * 2.
930  if (rhsConst == 1)
931  return lhs;
932 
933  // Simplify `(expr * lrhs) floordiv rhsConst` when `lrhs` is known to be a
934  // multiple of `rhsConst`.
935  auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
936  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
937  if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
938  // `rhsConst` is known to be a nonzero constant.
939  if (lrhs.getValue() % rhsConst.getValue() == 0)
940  return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
941  }
942  }
943 
944  // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
945  // known to be a multiple of divConst.
946  if (lBin && lBin.getKind() == AffineExprKind::Add) {
947  int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
948  int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
949  // rhsConst is known to be a nonzero constant.
950  if (llhsDiv % rhsConst.getValue() == 0 ||
951  lrhsDiv % rhsConst.getValue() == 0)
952  return lBin.getLHS().floorDiv(rhsConst.getValue()) +
953  lBin.getRHS().floorDiv(rhsConst.getValue());
954  }
955 
956  return nullptr;
957 }
958 
959 AffineExpr AffineExpr::floorDiv(uint64_t v) const {
961 }
963  if (auto simplified = simplifyFloorDiv(*this, other))
964  return simplified;
965 
967  return uniquer.get<AffineBinaryOpExprStorage>(
968  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
969  other);
970 }
971 
973  auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
974  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
975 
976  if (!rhsConst || rhsConst.getValue() == 0)
977  return nullptr;
978 
979  if (lhsConst) {
980  if (divideSignedWouldOverflow(lhsConst.getValue(), rhsConst.getValue()))
981  return nullptr;
982  return getAffineConstantExpr(
983  divideCeilSigned(lhsConst.getValue(), rhsConst.getValue()),
984  lhs.getContext());
985  }
986 
987  // Fold ceildiv of a multiply with a constant that is a multiple of the
988  // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
989  if (rhsConst.getValue() == 1)
990  return lhs;
991 
992  // Simplify `(expr * lrhs) ceildiv rhsConst` when `lrhs` is known to be a
993  // multiple of `rhsConst`.
994  auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
995  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
996  if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
997  // `rhsConst` is known to be a nonzero constant.
998  if (lrhs.getValue() % rhsConst.getValue() == 0)
999  return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
1000  }
1001  }
1002 
1003  return nullptr;
1004 }
1005 
1006 AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
1008 }
1010  if (auto simplified = simplifyCeilDiv(*this, other))
1011  return simplified;
1012 
1013  StorageUniquer &uniquer = getContext()->getAffineUniquer();
1014  return uniquer.get<AffineBinaryOpExprStorage>(
1015  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
1016  other);
1017 }
1018 
1020  auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
1021  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
1022 
1023  // mod w.r.t zero or negative numbers is undefined and preserved as is.
1024  if (!rhsConst || rhsConst.getValue() < 1)
1025  return nullptr;
1026 
1027  if (lhsConst) {
1028  // mod never overflows.
1029  return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
1030  lhs.getContext());
1031  }
1032 
1033  // Fold modulo of an expression that is known to be a multiple of a constant
1034  // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
1035  // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
1036  if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
1037  return getAffineConstantExpr(0, lhs.getContext());
1038 
1039  // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
1040  // known to be a multiple of divConst.
1041  auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
1042  if (lBin && lBin.getKind() == AffineExprKind::Add) {
1043  int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
1044  int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
1045  // rhsConst is known to be a positive constant.
1046  if (llhsDiv % rhsConst.getValue() == 0)
1047  return lBin.getRHS() % rhsConst.getValue();
1048  if (lrhsDiv % rhsConst.getValue() == 0)
1049  return lBin.getLHS() % rhsConst.getValue();
1050  }
1051 
1052  // Simplify (e % a) % b to e % b when b evenly divides a
1053  if (lBin && lBin.getKind() == AffineExprKind::Mod) {
1054  auto intermediate = dyn_cast<AffineConstantExpr>(lBin.getRHS());
1055  if (intermediate && intermediate.getValue() >= 1 &&
1056  mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
1057  return lBin.getLHS() % rhsConst.getValue();
1058  }
1059  }
1060 
1061  return nullptr;
1062 }
1063 
1065  return *this % getAffineConstantExpr(v, getContext());
1066 }
1068  if (auto simplified = simplifyMod(*this, other))
1069  return simplified;
1070 
1071  StorageUniquer &uniquer = getContext()->getAffineUniquer();
1072  return uniquer.get<AffineBinaryOpExprStorage>(
1073  /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
1074 }
1075 
1077  SmallVector<AffineExpr, 8> dimReplacements(map.getResults());
1078  return replaceDimsAndSymbols(dimReplacements, {});
1079 }
1080 raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
1081  expr.print(os);
1082  return os;
1083 }
1084 
1085 /// Constructs an affine expression from a flat ArrayRef. If there are local
1086 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
1087 /// products expression, `localExprs` is expected to have the AffineExpr
1088 /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
1089 /// in the format [dims, symbols, locals, constant term].
1091  unsigned numDims,
1092  unsigned numSymbols,
1093  ArrayRef<AffineExpr> localExprs,
1094  MLIRContext *context) {
1095  // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1096  assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1097  "unexpected number of local expressions");
1098 
1099  auto expr = getAffineConstantExpr(0, context);
1100  // Dimensions and symbols.
1101  for (unsigned j = 0; j < numDims + numSymbols; j++) {
1102  if (flatExprs[j] == 0)
1103  continue;
1104  auto id = j < numDims ? getAffineDimExpr(j, context)
1105  : getAffineSymbolExpr(j - numDims, context);
1106  expr = expr + id * flatExprs[j];
1107  }
1108 
1109  // Local identifiers.
1110  for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1111  j++) {
1112  if (flatExprs[j] == 0)
1113  continue;
1114  auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1115  expr = expr + term;
1116  }
1117 
1118  // Constant term.
1119  int64_t constTerm = flatExprs[flatExprs.size() - 1];
1120  if (constTerm != 0)
1121  expr = expr + constTerm;
1122  return expr;
1123 }
1124 
1125 /// Constructs a semi-affine expression from a flat ArrayRef. If there are
1126 /// local identifiers (neither dimensional nor symbolic) that appear in the sum
1127 /// of products expression, `localExprs` is expected to have the AffineExprs for
1128 /// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in
1129 /// the format [dims, symbols, locals, constant term]. The semi-affine
1130 /// expression is constructed in the sorted order of dimension and symbol
1131 /// position numbers. Note: local expressions/ids are used for mod, div as well
1132 /// as symbolic RHS terms for terms that are not pure affine.
1134  unsigned numDims,
1135  unsigned numSymbols,
1136  ArrayRef<AffineExpr> localExprs,
1137  MLIRContext *context) {
1138  assert(!flatExprs.empty() && "flatExprs cannot be empty");
1139 
1140  // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1141  assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1142  "unexpected number of local expressions");
1143 
1144  AffineExpr expr = getAffineConstantExpr(0, context);
1145 
1146  // We design indices as a pair which help us present the semi-affine map as
1147  // sum of product where terms are sorted based on dimension or symbol
1148  // position: <keyA, keyB> for expressions of the form dimension * symbol,
1149  // where keyA is the position number of the dimension and keyB is the
1150  // position number of the symbol. For dimensional expressions we set the index
1151  // as (position number of the dimension, -1), as we want dimensional
1152  // expressions to appear before symbolic and product of dimensional and
1153  // symbolic expressions having the dimension with the same position number.
1154  // For symbolic expression set the index as (position number of the symbol,
1155  // maximum of last dimension and symbol position) number. For example, we want
1156  // the expression we are constructing to look something like: d0 + d0 * s0 +
1157  // s0 + d1*s1 + s1.
1158 
1159  // Stores the affine expression corresponding to a given index.
1161  // Stores the constant coefficient value corresponding to a given
1162  // dimension, symbol or a non-pure affine expression stored in `localExprs`.
1163  DenseMap<std::pair<unsigned, signed>, int64_t> coefficients;
1164  // Stores the indices as defined above, and later sorted to produce
1165  // the semi-affine expression in the desired form.
1167 
1168  // Example: expression = d0 + d0 * s0 + 2 * s0.
1169  // indices = [{0,-1}, {0, 0}, {0, 1}]
1170  // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
1171  // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
1172 
1173  // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
1174  auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
1175  AffineExpr expr) {
1176  assert(!llvm::is_contained(indices, index) &&
1177  "Key is already present in indices vector and overwriting will "
1178  "happen in `indexToExprMap` and `coefficients`!");
1179 
1180  indices.push_back(index);
1181  coefficients.insert({index, coefficient});
1182  indexToExprMap.insert({index, expr});
1183  };
1184 
1185  // Design indices for dimensional or symbolic terms, and store the indices,
1186  // constant coefficient corresponding to the indices in `coefficients` map,
1187  // and affine expression corresponding to indices in `indexToExprMap` map.
1188 
1189  // Ensure we do not have duplicate keys in `indexToExpr` map.
1190  unsigned offsetSym = 0;
1191  signed offsetDim = -1;
1192  for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
1193  if (flatExprs[j] == 0)
1194  continue;
1195  // For symbolic expression set the index as <position number
1196  // of the symbol, max(dimCount, symCount)> number,
1197  // as we want symbolic expressions with the same positional number to
1198  // appear after dimensional expressions having the same positional number.
1199  std::pair<unsigned, signed> indexEntry(
1200  j - numDims, std::max(numDims, numSymbols) + offsetSym++);
1201  addEntry(indexEntry, flatExprs[j],
1202  getAffineSymbolExpr(j - numDims, context));
1203  }
1204 
1205  // Denotes semi-affine product, modulo or division terms, which has been added
1206  // to the `indexToExpr` map.
1207  SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1,
1208  false);
1209  unsigned lhsPos, rhsPos;
1210  // Construct indices for product terms involving dimension, symbol or constant
1211  // as lhs/rhs, and store the indices, constant coefficient corresponding to
1212  // the indices in `coefficients` map, and affine expression corresponding to
1213  // in indices in `indexToExprMap` map.
1214  for (const auto &it : llvm::enumerate(localExprs)) {
1215  if (flatExprs[numDims + numSymbols + it.index()] == 0)
1216  continue;
1217  AffineExpr expr = it.value();
1218  auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
1219  if (!binaryExpr)
1220  continue;
1221 
1222  AffineExpr lhs = binaryExpr.getLHS();
1223  AffineExpr rhs = binaryExpr.getRHS();
1224  if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
1225  (isa<AffineDimExpr>(rhs) || isa<AffineSymbolExpr>(rhs) ||
1226  isa<AffineConstantExpr>(rhs)))) {
1227  continue;
1228  }
1229  if (isa<AffineConstantExpr>(rhs)) {
1230  // For product/modulo/division expressions, when rhs of modulo/division
1231  // expression is constant, we put 0 in place of keyB, because we want
1232  // them to appear earlier in the semi-affine expression we are
1233  // constructing. When rhs is constant, we place 0 in place of keyB.
1234  if (isa<AffineDimExpr>(lhs)) {
1235  lhsPos = cast<AffineDimExpr>(lhs).getPosition();
1236  std::pair<unsigned, signed> indexEntry(lhsPos, offsetDim--);
1237  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1238  expr);
1239  } else {
1240  lhsPos = cast<AffineSymbolExpr>(lhs).getPosition();
1241  std::pair<unsigned, signed> indexEntry(
1242  lhsPos, std::max(numDims, numSymbols) + offsetSym++);
1243  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1244  expr);
1245  }
1246  } else if (isa<AffineDimExpr>(lhs)) {
1247  // For product/modulo/division expressions having lhs as dimension and rhs
1248  // as symbol, we order the terms in the semi-affine expression based on
1249  // the pair: <keyA, keyB> for expressions of the form dimension * symbol,
1250  // where keyA is the position number of the dimension and keyB is the
1251  // position number of the symbol.
1252  lhsPos = cast<AffineDimExpr>(lhs).getPosition();
1253  rhsPos = cast<AffineSymbolExpr>(rhs).getPosition();
1254  std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1255  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1256  } else {
1257  // For product/modulo/division expressions having both lhs and rhs as
1258  // symbol, we design indices as a pair: <keyA, keyB> for expressions
1259  // of the form dimension * symbol, where keyA is the position number of
1260  // the dimension and keyB is the position number of the symbol.
1261  lhsPos = cast<AffineSymbolExpr>(lhs).getPosition();
1262  rhsPos = cast<AffineSymbolExpr>(rhs).getPosition();
1263  std::pair<unsigned, signed> indexEntry(
1264  lhsPos, std::max(numDims, numSymbols) + offsetSym++);
1265  addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1266  }
1267  addedToMap[it.index()] = true;
1268  }
1269 
1270  for (unsigned j = 0; j < numDims; ++j) {
1271  if (flatExprs[j] == 0)
1272  continue;
1273  // For dimensional expressions we set the index as <position number of the
1274  // dimension, 0>, as we want dimensional expressions to appear before
1275  // symbolic ones and products of dimensional and symbolic expressions
1276  // having the dimension with the same position number.
1277  std::pair<unsigned, signed> indexEntry(j, offsetDim--);
1278  addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context));
1279  }
1280 
1281  // Constructing the simplified semi-affine sum of product/division/mod
1282  // expression from the flattened form in the desired sorted order of indices
1283  // of the various individual product/division/mod expressions.
1284  llvm::sort(indices);
1285  for (const std::pair<unsigned, unsigned> index : indices) {
1286  assert(indexToExprMap.lookup(index) &&
1287  "cannot find key in `indexToExprMap` map");
1288  expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index);
1289  }
1290 
1291  // Local identifiers.
1292  for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1293  j++) {
1294  // If the coefficient of the local expression is 0, continue as we need not
1295  // add it in out final expression.
1296  if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols])
1297  continue;
1298  auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1299  expr = expr + term;
1300  }
1301 
1302  // Constant term.
1303  int64_t constTerm = flatExprs.back();
1304  if (constTerm != 0)
1305  expr = expr + constTerm;
1306  return expr;
1307 }
1308 
1310  unsigned numSymbols)
1311  : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
1312  operandExprStack.reserve(8);
1313 }
1314 
1315 // In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1316 //
1317 // In case of semi affine multiplication expressions, t = expr * symbolic_expr,
1318 // introduce a local variable p (= expr * symbolic_expr), and the affine
1319 // expression expr * symbolic_expr is added to `localExprs`.
1321  assert(operandExprStack.size() >= 2);
1323  operandExprStack.pop_back();
1325 
1326  // Flatten semi-affine multiplication expressions by introducing a local
1327  // variable in place of the product; the affine expression
1328  // corresponding to the quantifier is added to `localExprs`.
1329  if (!isa<AffineConstantExpr>(expr.getRHS())) {
1330  SmallVector<int64_t, 8> mulLhs(lhs);
1331  MLIRContext *context = expr.getContext();
1333  localExprs, context);
1335  localExprs, context);
1336  return addLocalVariableSemiAffine(mulLhs, rhs, a * b, lhs, lhs.size());
1337  }
1338 
1339  // Get the RHS constant.
1340  int64_t rhsConst = rhs[getConstantIndex()];
1341  for (int64_t &lhsElt : lhs)
1342  lhsElt *= rhsConst;
1343 
1344  return success();
1345 }
1346 
1348  assert(operandExprStack.size() >= 2);
1349  const auto &rhs = operandExprStack.back();
1350  auto &lhs = operandExprStack[operandExprStack.size() - 2];
1351  assert(lhs.size() == rhs.size());
1352  // Update the LHS in place.
1353  for (unsigned i = 0, e = rhs.size(); i < e; i++) {
1354  lhs[i] += rhs[i];
1355  }
1356  // Pop off the RHS.
1357  operandExprStack.pop_back();
1358  return success();
1359 }
1360 
1361 //
1362 // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
1363 //
1364 // A mod expression "expr mod c" is thus flattened by introducing a new local
1365 // variable q (= expr floordiv c), such that expr mod c is replaced with
1366 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
1367 //
1368 // In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
1369 // introduce a local variable m (= expr mod symbolic_expr), and the affine
1370 // expression expr mod symbolic_expr is added to `localExprs`.
1372  assert(operandExprStack.size() >= 2);
1373 
1375  operandExprStack.pop_back();
1377  MLIRContext *context = expr.getContext();
1378 
1379  // Flatten semi affine modulo expressions by introducing a local
1380  // variable in place of the modulo value, and the affine expression
1381  // corresponding to the quantifier is added to `localExprs`.
1382  if (!isa<AffineConstantExpr>(expr.getRHS())) {
1383  SmallVector<int64_t, 8> modLhs(lhs);
1384  AffineExpr dividendExpr = getAffineExprFromFlatForm(
1385  lhs, numDims, numSymbols, localExprs, context);
1387  localExprs, context);
1388  AffineExpr modExpr = dividendExpr % divisorExpr;
1389  return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
1390  }
1391 
1392  int64_t rhsConst = rhs[getConstantIndex()];
1393  if (rhsConst <= 0)
1394  return failure();
1395 
1396  // Check if the LHS expression is a multiple of modulo factor.
1397  unsigned i, e;
1398  for (i = 0, e = lhs.size(); i < e; i++)
1399  if (lhs[i] % rhsConst != 0)
1400  break;
1401  // If yes, modulo expression here simplifies to zero.
1402  if (i == lhs.size()) {
1403  llvm::fill(lhs, 0);
1404  return success();
1405  }
1406 
1407  // Add a local variable for the quotient, i.e., expr % c is replaced by
1408  // (expr - q * c) where q = expr floordiv c. Do this while canceling out
1409  // the GCD of expr and c.
1410  SmallVector<int64_t, 8> floorDividend(lhs);
1411  uint64_t gcd = rhsConst;
1412  for (int64_t lhsElt : lhs)
1413  gcd = std::gcd(gcd, (uint64_t)std::abs(lhsElt));
1414  // Simplify the numerator and the denominator.
1415  if (gcd != 1) {
1416  for (int64_t &floorDividendElt : floorDividend)
1417  floorDividendElt = floorDividendElt / static_cast<int64_t>(gcd);
1418  }
1419  int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
1420 
1421  // Construct the AffineExpr form of the floordiv to store in localExprs.
1422 
1423  AffineExpr dividendExpr = getAffineExprFromFlatForm(
1424  floorDividend, numDims, numSymbols, localExprs, context);
1425  AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context);
1426  AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr);
1427  int loc;
1428  if ((loc = findLocalId(floorDivExpr)) == -1) {
1429  addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
1430  // Set result at top of stack to "lhs - rhsConst * q".
1431  lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
1432  } else {
1433  // Reuse the existing local id.
1434  lhs[getLocalVarStartIndex() + loc] -= rhsConst;
1435  }
1436  return success();
1437 }
1438 
1439 LogicalResult
1441  return visitDivExpr(expr, /*isCeil=*/true);
1442 }
1443 LogicalResult
1445  return visitDivExpr(expr, /*isCeil=*/false);
1446 }
1447 
1449  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1450  auto &eq = operandExprStack.back();
1451  assert(expr.getPosition() < numDims && "Inconsistent number of dims");
1452  eq[getDimStartIndex() + expr.getPosition()] = 1;
1453  return success();
1454 }
1455 
1456 LogicalResult
1458  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1459  auto &eq = operandExprStack.back();
1460  assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1461  eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1462  return success();
1463 }
1464 
1465 LogicalResult
1467  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1468  auto &eq = operandExprStack.back();
1469  eq[getConstantIndex()] = expr.getValue();
1470  return success();
1471 }
1472 
1473 LogicalResult SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1474  ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, AffineExpr localExpr,
1475  SmallVectorImpl<int64_t> &result, unsigned long resultSize) {
1476  assert(result.size() == resultSize &&
1477  "`result` vector passed is not of correct size");
1478  int loc;
1479  if ((loc = findLocalId(localExpr)) == -1) {
1480  if (failed(addLocalIdSemiAffine(lhs, rhs, localExpr)))
1481  return failure();
1482  }
1483  llvm::fill(result, 0);
1484  if (loc == -1)
1485  result[getLocalVarStartIndex() + numLocals - 1] = 1;
1486  else
1487  result[getLocalVarStartIndex() + loc] = 1;
1488  return success();
1489 }
1490 
1491 // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
1492 // A floordiv is thus flattened by introducing a new local variable q, and
1493 // replacing that expression with 'q' while adding the constraints
1494 // c * q <= expr <= c * q + c - 1 to localVarCst (done by
1495 // IntegerRelation::addLocalFloorDiv).
1496 //
1497 // A ceildiv is similarly flattened:
1498 // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
1499 //
1500 // In case of semi affine division expressions, t = expr floordiv symbolic_expr
1501 // or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
1502 // floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
1503 // `localExprs`.
1504 LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1505  bool isCeil) {
1506  assert(operandExprStack.size() >= 2);
1507 
1508  MLIRContext *context = expr.getContext();
1510  operandExprStack.pop_back();
1512 
1513  // Flatten semi affine division expressions by introducing a local
1514  // variable in place of the quotient, and the affine expression corresponding
1515  // to the quantifier is added to `localExprs`.
1516  if (!isa<AffineConstantExpr>(expr.getRHS())) {
1517  SmallVector<int64_t, 8> divLhs(lhs);
1519  localExprs, context);
1521  localExprs, context);
1522  AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1523  return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
1524  }
1525 
1526  // This is a pure affine expr; the RHS is a positive constant.
1527  int64_t rhsConst = rhs[getConstantIndex()];
1528  if (rhsConst <= 0)
1529  return failure();
1530 
1531  // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1532  // common divisors of the numerator and denominator.
1533  uint64_t gcd = std::abs(rhsConst);
1534  for (int64_t lhsElt : lhs)
1535  gcd = std::gcd(gcd, (uint64_t)std::abs(lhsElt));
1536  // Simplify the numerator and the denominator.
1537  if (gcd != 1) {
1538  for (int64_t &lhsElt : lhs)
1539  lhsElt = lhsElt / static_cast<int64_t>(gcd);
1540  }
1541  int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1542  // If the divisor becomes 1, the updated LHS is the result. (The
1543  // divisor can't be negative since rhsConst is positive).
1544  if (divisor == 1)
1545  return success();
1546 
1547  // If the divisor cannot be simplified to one, we will have to retain
1548  // the ceil/floor expr (simplified up until here). Add an existential
1549  // quantifier to express its result, i.e., expr1 div expr2 is replaced
1550  // by a new identifier, q.
1551  AffineExpr a =
1553  AffineExpr b = getAffineConstantExpr(divisor, context);
1554 
1555  int loc;
1556  AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1557  if ((loc = findLocalId(divExpr)) == -1) {
1558  if (!isCeil) {
1559  SmallVector<int64_t, 8> dividend(lhs);
1560  addLocalFloorDivId(dividend, divisor, divExpr);
1561  } else {
1562  // lhs ceildiv c <=> (lhs + c - 1) floordiv c
1563  SmallVector<int64_t, 8> dividend(lhs);
1564  dividend.back() += divisor - 1;
1565  addLocalFloorDivId(dividend, divisor, divExpr);
1566  }
1567  }
1568  // Set the expression on stack to the local var introduced to capture the
1569  // result of the division (floor or ceil).
1570  llvm::fill(lhs, 0);
1571  if (loc == -1)
1572  lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1573  else
1574  lhs[getLocalVarStartIndex() + loc] = 1;
1575  return success();
1576 }
1577 
1578 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1579 // The local identifier added is always a floordiv of a pure add/mul affine
1580 // function of other identifiers, coefficients of which are specified in
1581 // dividend and with respect to a positive constant divisor. localExpr is the
1582 // simplified tree expression (AffineExpr) corresponding to the quantifier.
1584  int64_t divisor,
1585  AffineExpr localExpr) {
1586  assert(divisor > 0 && "positive constant divisor expected");
1587  for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1588  subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1589  localExprs.push_back(localExpr);
1590  numLocals++;
1591  // dividend and divisor are not used here; an override of this method uses it.
1592 }
1593 
1595  ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, AffineExpr localExpr) {
1596  for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1597  subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1598  localExprs.push_back(localExpr);
1599  ++numLocals;
1600  // lhs and rhs are not used here; an override of this method uses them.
1601  return success();
1602 }
1603 
1604 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1606  if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
1607  return -1;
1608  return it - localExprs.begin();
1609 }
1610 
1611 /// Simplify the affine expression by flattening it and reconstructing it.
1613  unsigned numSymbols) {
1614  // Simplify semi-affine expressions separately.
1615  if (!expr.isPureAffine())
1616  expr = simplifySemiAffine(expr, numDims, numSymbols);
1617 
1618  SimpleAffineExprFlattener flattener(numDims, numSymbols);
1619  // has poison expression
1620  if (failed(flattener.walkPostOrder(expr)))
1621  return expr;
1622  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1623  if (!expr.isPureAffine() &&
1624  expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1625  flattener.localExprs,
1626  expr.getContext()))
1627  return expr;
1628  AffineExpr simplifiedExpr =
1629  expr.isPureAffine()
1630  ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1631  flattener.localExprs, expr.getContext())
1632  : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1633  flattener.localExprs,
1634  expr.getContext());
1635 
1636  flattener.operandExprStack.pop_back();
1637  assert(flattener.operandExprStack.empty());
1638  return simplifiedExpr;
1639 }
1640 
1641 std::optional<int64_t> mlir::getBoundForAffineExpr(
1642  AffineExpr expr, unsigned numDims, unsigned numSymbols,
1643  ArrayRef<std::optional<int64_t>> constLowerBounds,
1644  ArrayRef<std::optional<int64_t>> constUpperBounds, bool isUpper) {
1645  // Handle divs and mods.
1646  if (auto binOpExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
1647  // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
1648  // can compute an upper bound.
1649  if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
1650  auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
1651  if (!rhsConst || rhsConst.getValue() < 1)
1652  return std::nullopt;
1653  auto bound =
1654  getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1655  constLowerBounds, constUpperBounds, isUpper);
1656  if (!bound)
1657  return std::nullopt;
1658  return divideFloorSigned(*bound, rhsConst.getValue());
1659  }
1660  if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
1661  auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
1662  if (rhsConst && rhsConst.getValue() >= 1) {
1663  auto bound =
1664  getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1665  constLowerBounds, constUpperBounds, isUpper);
1666  if (!bound)
1667  return std::nullopt;
1668  return divideCeilSigned(*bound, rhsConst.getValue());
1669  }
1670  return std::nullopt;
1671  }
1672  if (binOpExpr.getKind() == AffineExprKind::Mod) {
1673  // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
1674  // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
1675  // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
1676  auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
1677  if (rhsConst && rhsConst.getValue() >= 1) {
1678  int64_t rhsConstVal = rhsConst.getValue();
1679  auto lb = getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1680  constLowerBounds, constUpperBounds,
1681  /*isUpper=*/false);
1682  auto ub =
1683  getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1684  constLowerBounds, constUpperBounds, isUpper);
1685  if (ub && lb &&
1686  divideFloorSigned(*lb, rhsConstVal) ==
1687  divideFloorSigned(*ub, rhsConstVal))
1688  return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
1689  return isUpper ? rhsConstVal - 1 : 0;
1690  }
1691  }
1692  }
1693  // Flatten the expression.
1694  SimpleAffineExprFlattener flattener(numDims, numSymbols);
1695  auto simpleResult = flattener.walkPostOrder(expr);
1696  // has poison expression
1697  if (failed(simpleResult))
1698  return std::nullopt;
1699  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1700  // TODO: Handle local variables. We can get hold of flattener.localExprs and
1701  // get bound on the local expr recursively.
1702  if (flattener.numLocals > 0)
1703  return std::nullopt;
1704  int64_t bound = 0;
1705  // Substitute the constant lower or upper bound for the dimensional or
1706  // symbolic input depending on `isUpper` to determine the bound.
1707  for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
1708  if (flattenedExpr[i] > 0) {
1709  auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
1710  if (!constBound)
1711  return std::nullopt;
1712  bound += *constBound * flattenedExpr[i];
1713  } else if (flattenedExpr[i] < 0) {
1714  auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
1715  if (!constBound)
1716  return std::nullopt;
1717  bound += *constBound * flattenedExpr[i];
1718  }
1719  }
1720  // Constant term.
1721  bound += flattenedExpr.back();
1722  return bound;
1723 }
static int64_t product(ArrayRef< int64_t > vals)
static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos, AffineExprKind opKind)
Divides the given expression by the given symbol at position symbolPos.
Definition: AffineExpr.cpp:425
static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs)
Simplify a multiply expression. Return nullptr if it can't be simplified.
Definition: AffineExpr.cpp:836
static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs)
static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs)
Simplify add expression. Return nullptr if it can't be simplified.
Definition: AffineExpr.cpp:661
static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs a semi-affine expression from a flat ArrayRef.
static bool canSimplifyDivisionBySymbol(AffineExpr expr, unsigned symbolPos, AffineExprKind opKind, bool fromMul=false)
Returns true if the expression is divisible by the given symbol with position symbolPos.
Definition: AffineExpr.cpp:357
static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:972
static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:913
static bool isNegatedAffineExpr(AffineExpr candidate, AffineExpr &expr)
Return "true" if candidate is a negated expression, i.e., Mul(-1, expr).
Definition: AffineExpr.cpp:493
static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:608
static bool isModOfModSubtraction(AffineExpr lhs, AffineExpr rhs, unsigned numDims, unsigned numSymbols)
Return "true" if lhs % rhs is guaranteed to evaluate to zero based on the fact that lhs contains anot...
Definition: AffineExpr.cpp:520
static std::pair< AffineExpr, AffineExpr > orderCommutativeArgs(AffineExpr expr1, AffineExpr expr2)
Get the canonical order of the two arguments of a commutative expression.
Definition: AffineExpr.cpp:796
static void getSummandExprs(AffineExpr expr, SmallVector< AffineExpr > &result)
Populate result with all summand operands of given (potentially nested) addition.
Definition: AffineExpr.cpp:481
static AffineExpr simplifySemiAffine(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv operations when the second...
Definition: AffineExpr.cpp:558
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1244::ArityGroupAndKind::Kind kind
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Affine binary operation expression.
Definition: AffineExpr.h:214
AffineExpr getLHS() const
Definition: AffineExpr.cpp:338
AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
Definition: AffineExpr.cpp:336
AffineExpr getRHS() const
Definition: AffineExpr.cpp:341
An integer constant appearing in affine expression.
Definition: AffineExpr.h:239
AffineConstantExpr(AffineExpr::ImplType *ptr=nullptr)
Definition: AffineExpr.cpp:633
int64_t getValue() const
Definition: AffineExpr.cpp:635
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:223
AffineDimExpr(AffineExpr::ImplType *ptr)
Definition: AffineExpr.cpp:345
unsigned getPosition() const
Definition: AffineExpr.cpp:346
See documentation for AffineExprVisitorBase.
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements) const
This method substitutes any uses of dimensions and symbols (e.g.
Definition: AffineExpr.cpp:87
AffineExpr shiftDims(unsigned numDims, unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineExpr.cpp:131
AffineExpr operator+(int64_t v) const
Definition: AffineExpr.cpp:821
bool isSymbolicOrConstant() const
Returns true if this expression is made out of only symbols and constants, i.e., it does not involve ...
Definition: AffineExpr.cpp:186
AffineExpr operator*(int64_t v) const
Definition: AffineExpr.cpp:888
bool operator==(AffineExpr other) const
Definition: AffineExpr.h:76
bool isPureAffine() const
Returns true if this is a pure affine expression, i.e., multiplication, floordiv, ceildiv,...
Definition: AffineExpr.cpp:210
AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
Definition: AffineExpr.cpp:143
AffineExpr operator-() const
Definition: AffineExpr.cpp:903
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:959
ImplType * expr
Definition: AffineExpr.h:196
RetT walk(FnT &&callback) const
Walk all of the AffineExpr's in this expression in postorder.
Definition: AffineExpr.h:117
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:33
bool isMultipleOf(int64_t factor) const
Return true if the affine expression is a multiple of 'factor'.
Definition: AffineExpr.cpp:281
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
Definition: AffineExpr.cpp:241
AffineExpr compose(AffineMap map) const
Compose with an AffineMap.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
Definition: AffineExpr.cpp:314
bool isFunctionOfSymbol(unsigned position) const
Return true if the affine expression involves AffineSymbolExpr position.
Definition: AffineExpr.cpp:325
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
Definition: AffineExpr.cpp:120
AffineExpr operator%(uint64_t v) const
MLIRContext * getContext() const
Definition: AffineExpr.cpp:31
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
Definition: AffineExpr.cpp:179
AffineExpr replaceSymbols(ArrayRef< AffineExpr > symReplacements) const
Symbol-only version of replaceDimsAndSymbols.
Definition: AffineExpr.cpp:125
AffineExpr ceilDiv(uint64_t v) const
void print(raw_ostream &os) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
A symbolic identifier appearing in an affine expression.
Definition: AffineExpr.h:231
AffineSymbolExpr(AffineExpr::ImplType *ptr)
Definition: AffineExpr.cpp:623
unsigned getPosition() const
Definition: AffineExpr.cpp:625
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
StorageUniquer & getAffineUniquer()
Returns the storage uniquer used for creating affine constructs.
virtual void addLocalFloorDivId(ArrayRef< int64_t > dividend, int64_t divisor, AffineExpr localExpr)
LogicalResult visitSymbolExpr(AffineSymbolExpr expr)
std::vector< SmallVector< int64_t, 8 > > operandExprStack
LogicalResult visitDimExpr(AffineDimExpr expr)
LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr)
LogicalResult visitConstantExpr(AffineConstantExpr expr)
virtual LogicalResult addLocalIdSemiAffine(ArrayRef< int64_t > lhs, ArrayRef< int64_t > rhs, AffineExpr localExpr)
Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul expr) when the rhs is a symbo...
LogicalResult visitModExpr(AffineBinaryOpExpr expr)
LogicalResult visitAddExpr(AffineBinaryOpExpr expr)
LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr)
LogicalResult visitMulExpr(AffineBinaryOpExpr expr)
SmallVector< AffineExpr, 4 > localExprs
SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols)
A utility class to get or create instances of "storage classes".
Storage * get(function_ref< void(Storage *)> initFn, TypeID id, Args &&...args)
Gets a uniqued instance of 'Storage'.
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
AttrTypeReplacer.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t >> constLowerBounds, ArrayRef< std::optional< int64_t >> constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
AffineExprKind
Definition: AffineExpr.h:40
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ Constant
Constant integer.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:68
AffineExpr getAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs an affine expression from a flat ArrayRef.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:643
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
SmallVector< AffineExpr > getAffineConstantExprs(ArrayRef< int64_t > constants, MLIRContext *context)
Definition: AffineExpr.cpp:653
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the classes in the `detail` namespace directly.
Definition: AffineExpr.cpp:619
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:629
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
A binary operation appearing in an affine expression.
An integer constant appearing in affine expression.
A dimensional or symbolic identifier appearing in an affine expression.
Base storage class appearing in an affine expression.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.