MLIR 23.0.0git
AffineExpr.cpp
Go to the documentation of this file.
1//===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cmath>
10#include <cstdint>
11#include <utility>
12
13#include "AffineExprDetail.h"
14#include "mlir/IR/AffineExpr.h"
16#include "mlir/IR/AffineMap.h"
17#include "mlir/IR/IntegerSet.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/SmallVectorExtras.h"
20#include "llvm/Support/MathExtras.h"
21#include <numeric>
22#include <optional>
23
24using namespace mlir;
25using namespace mlir::detail;
26
27using llvm::divideCeilSigned;
28using llvm::divideFloorSigned;
29using llvm::divideSignedWouldOverflow;
30using llvm::mod;
31
32MLIRContext *AffineExpr::getContext() const { return expr->context; }
33
34AffineExprKind AffineExpr::getKind() const { return expr->kind; }
35
36/// Walk all of the AffineExprs in `e` in postorder. This is a private factory
37/// method to help handle lambda walk functions. Users should use the regular
38/// (non-static) `walk` method.
39template <typename WalkRetTy>
41 function_ref<WalkRetTy(AffineExpr)> callback) {
42 struct AffineExprWalker
43 : public AffineExprVisitor<AffineExprWalker, WalkRetTy> {
44 function_ref<WalkRetTy(AffineExpr)> callback;
45
46 AffineExprWalker(function_ref<WalkRetTy(AffineExpr)> callback)
47 : callback(callback) {}
48
49 WalkRetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
50 return callback(expr);
51 }
52 WalkRetTy visitConstantExpr(AffineConstantExpr expr) {
53 return callback(expr);
54 }
55 WalkRetTy visitDimExpr(AffineDimExpr expr) { return callback(expr); }
56 WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) { return callback(expr); }
57 };
58
59 return AffineExprWalker(callback).walkPostOrder(e);
60}
61// Explicitly instantiate for the two supported return types.
63 function_ref<void(AffineExpr)> callback);
64template WalkResult
67
68// Dispatch affine expression construction based on kind.
71 if (kind == AffineExprKind::Add)
72 return lhs + rhs;
73 if (kind == AffineExprKind::Mul)
74 return lhs * rhs;
75 if (kind == AffineExprKind::FloorDiv)
76 return lhs.floorDiv(rhs);
77 if (kind == AffineExprKind::CeilDiv)
78 return lhs.ceilDiv(rhs);
79 if (kind == AffineExprKind::Mod)
80 return lhs % rhs;
81
82 llvm_unreachable("unknown binary operation on affine expressions");
83}
84
85/// This method substitutes any uses of dimensions and symbols (e.g.
86/// dim#0 with dimReplacements[0]) and returns the modified expression tree.
89 ArrayRef<AffineExpr> symReplacements) const {
90 switch (getKind()) {
92 return *this;
94 unsigned dimId = llvm::cast<AffineDimExpr>(*this).getPosition();
95 if (dimId >= dimReplacements.size())
96 return *this;
97 return dimReplacements[dimId];
98 }
100 unsigned symId = llvm::cast<AffineSymbolExpr>(*this).getPosition();
101 if (symId >= symReplacements.size())
102 return *this;
103 return symReplacements[symId];
104 }
110 auto binOp = llvm::cast<AffineBinaryOpExpr>(*this);
111 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
112 auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
113 auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
114 if (newLHS == lhs && newRHS == rhs)
115 return *this;
116 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
117 }
118 llvm_unreachable("Unknown AffineExpr");
119}
120
122 return replaceDimsAndSymbols(dimReplacements, {});
123}
124
127 return replaceDimsAndSymbols({}, symReplacements);
128}
129
130/// Replace dims[offset ... numDims)
131/// by dims[offset + shift ... shift + numDims).
132AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
133 unsigned offset) const {
135 for (unsigned idx = 0; idx < offset; ++idx)
136 dims.push_back(getAffineDimExpr(idx, getContext()));
137 for (unsigned idx = offset; idx < numDims; ++idx)
138 dims.push_back(getAffineDimExpr(idx + shift, getContext()));
139 return replaceDimsAndSymbols(dims, {});
140}
141
142/// Replace symbols[offset ... numSymbols)
143/// by symbols[offset + shift ... shift + numSymbols).
144AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
145 unsigned offset) const {
147 for (unsigned idx = 0; idx < offset; ++idx)
148 symbols.push_back(getAffineSymbolExpr(idx, getContext()));
149 for (unsigned idx = offset; idx < numSymbols; ++idx)
150 symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
151 return replaceDimsAndSymbols({}, symbols);
152}
153
154/// Sparse replace method. Return the modified expression tree.
157 auto it = map.find(*this);
158 if (it != map.end())
159 return it->second;
160 switch (getKind()) {
161 default:
162 return *this;
168 auto binOp = llvm::cast<AffineBinaryOpExpr>(*this);
169 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
170 auto newLHS = lhs.replace(map);
171 auto newRHS = rhs.replace(map);
172 if (newLHS == lhs && newRHS == rhs)
173 return *this;
174 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
175 }
176 llvm_unreachable("Unknown AffineExpr");
177}
178
179/// Sparse replace method. Return the modified expression tree.
182 map.insert(std::make_pair(expr, replacement));
183 return replace(map);
184}
185/// Returns true if this expression is made out of only symbols and
186/// constants (no dimensional identifiers).
188 switch (getKind()) {
190 return true;
192 return false;
194 return true;
195
200 case AffineExprKind::Mod: {
201 auto expr = llvm::cast<AffineBinaryOpExpr>(*this);
202 return expr.getLHS().isSymbolicOrConstant() &&
203 expr.getRHS().isSymbolicOrConstant();
204 }
205 }
206 llvm_unreachable("Unknown AffineExpr");
207}
208
209/// Returns true if this is a pure affine expression, i.e., multiplication,
210/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
212 switch (getKind()) {
216 return true;
217 case AffineExprKind::Add: {
218 auto op = llvm::cast<AffineBinaryOpExpr>(*this);
219 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
220 }
221
222 case AffineExprKind::Mul: {
223 // TODO: Canonicalize the constants in binary operators to the RHS when
224 // possible, allowing this to merge into the next case.
225 auto op = llvm::cast<AffineBinaryOpExpr>(*this);
226 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
227 (llvm::isa<AffineConstantExpr>(op.getLHS()) ||
228 llvm::isa<AffineConstantExpr>(op.getRHS()));
229 }
232 case AffineExprKind::Mod: {
233 auto op = llvm::cast<AffineBinaryOpExpr>(*this);
234 return op.getLHS().isPureAffine() &&
235 llvm::isa<AffineConstantExpr>(op.getRHS());
236 }
237 }
238 llvm_unreachable("Unknown AffineExpr");
239}
240
241// Returns the greatest known integral divisor of this affine expression.
243 AffineBinaryOpExpr binExpr(nullptr);
244 switch (getKind()) {
246 [[fallthrough]];
248 return 1;
250 [[fallthrough]];
252 // If the RHS is a constant and divides the known divisor on the LHS, the
253 // quotient is a known divisor of the expression.
254 binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
255 auto rhs = llvm::dyn_cast<AffineConstantExpr>(binExpr.getRHS());
256 // Leave alone undefined expressions.
257 if (rhs && rhs.getValue() != 0) {
258 int64_t lhsDiv = binExpr.getLHS().getLargestKnownDivisor();
259 if (lhsDiv % rhs.getValue() == 0)
260 return std::abs(lhsDiv / rhs.getValue());
261 }
262 return 1;
263 }
265 return std::abs(llvm::cast<AffineConstantExpr>(*this).getValue());
266 case AffineExprKind::Mul: {
267 binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
268 return binExpr.getLHS().getLargestKnownDivisor() *
269 binExpr.getRHS().getLargestKnownDivisor();
270 }
272 [[fallthrough]];
273 case AffineExprKind::Mod: {
274 binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
275 return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
276 (uint64_t)binExpr.getRHS().getLargestKnownDivisor());
277 }
278 }
279 llvm_unreachable("Unknown AffineExpr");
280}
281
283 AffineBinaryOpExpr binExpr(nullptr);
284 uint64_t l, u;
285 switch (getKind()) {
287 [[fallthrough]];
289 return factor * factor == 1;
291 return llvm::cast<AffineConstantExpr>(*this).getValue() % factor == 0;
292 case AffineExprKind::Mul: {
293 binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
294 // It's probably not worth optimizing this further (to not traverse the
295 // whole sub-tree under - it that would require a version of isMultipleOf
296 // that on a 'false' return also returns the largest known divisor).
297 return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
298 (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
299 (l * u) % factor == 0;
300 }
304 case AffineExprKind::Mod: {
305 binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
306 return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
307 (uint64_t)binExpr.getRHS().getLargestKnownDivisor()) %
308 factor ==
309 0;
310 }
311 }
312 llvm_unreachable("Unknown AffineExpr");
313}
314
315bool AffineExpr::isFunctionOfDim(unsigned position) const {
317 return *this == mlir::getAffineDimExpr(position, getContext());
318 }
319 if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(*this)) {
320 return expr.getLHS().isFunctionOfDim(position) ||
321 expr.getRHS().isFunctionOfDim(position);
322 }
323 return false;
324}
325
326bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
328 return *this == mlir::getAffineSymbolExpr(position, getContext());
329 }
330 if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(*this)) {
331 return expr.getLHS().isFunctionOfSymbol(position) ||
332 expr.getRHS().isFunctionOfSymbol(position);
333 }
334 return false;
335}
336
340 return static_cast<ImplType *>(expr)->lhs;
341}
343 return static_cast<ImplType *>(expr)->rhs;
344}
345
348 return static_cast<ImplType *>(expr)->position;
349}
350
351/// Returns true if the expression is divisible by the given symbol with
352/// position `symbolPos`. The argument `opKind` specifies which kind of
353/// division or mod operation requested this division. It lets successive
354/// floordiv (or ceildiv) operations be commuted: if `opKind` is floordiv and
355/// `expr` is also a floordiv binary expression, the two divisions can be
356/// swapped; otherwise, the floordiv operation is not divisible. The same holds
357/// for ceildiv. `fromMul` is set when `expr` is an operand of a multiplication.
358static bool canSimplifyDivisionBySymbol(AffineExpr expr, unsigned symbolPos,
359 AffineExprKind opKind,
360 bool fromMul = false) {
361 // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
362 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
363 opKind == AffineExprKind::CeilDiv) &&
364 "unexpected opKind");
365 switch (expr.getKind()) {
367 return cast<AffineConstantExpr>(expr).getValue() == 0;
369 return false;
371 return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos);
372 // Checks divisibility by the given symbol for both operands.
373 case AffineExprKind::Add: {
374 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
375 return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
376 opKind) &&
377 canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
378 }
379 // Checks divisibility by the given symbol for both operands. Consider the
380 // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
381 // this is a division by s1 and both the operands of modulo are divisible by
382 // s1 but it is not divisible by s1 always. The third argument is
383 // `AffineExprKind::Mod` for this reason.
384 case AffineExprKind::Mod: {
385 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
386 return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
388 canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos,
390 }
391 // Checks if any of the operand divisible by the given symbol.
392 case AffineExprKind::Mul: {
393 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
394 return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind,
395 true) ||
396 canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind,
397 true);
398 }
399 // Floordiv and ceildiv are divisible by the given symbol when the first
400 // operand is divisible, and the affine expression kind of the argument expr
401 // is same as the argument `opKind`. This can be inferred from commutative
402 // property of floordiv and ceildiv operations and are as follow:
403 // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
404 // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv exp2
405 // It fails 1. if the operations are not the same, for example:
406 // (exp1 ceildiv exp2) floordiv exp3 can not be simplified; 2. if there is a
407 // multiplication operation in the expression, for example:
408 // (exp1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified.
411 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
412 if (opKind != expr.getKind())
413 return false;
414 if (fromMul)
415 return false;
416 return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
417 expr.getKind());
418 }
419 }
420 llvm_unreachable("Unknown AffineExpr");
421}
422
423/// Divides the given expression by the given symbol at position `symbolPos`. It
424/// assumes the divisibility condition has been checked before it is called. A
425/// null expression is returned whenever the divisibility condition fails.
426static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
427 AffineExprKind opKind) {
428 // The argument `opKind` can only be Mod, FloorDiv, or CeilDiv.
429 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
430 opKind == AffineExprKind::CeilDiv) &&
431 "unexpected opKind");
432 switch (expr.getKind()) {
434 if (cast<AffineConstantExpr>(expr).getValue() != 0)
435 return nullptr;
436 return getAffineConstantExpr(0, expr.getContext());
438 return nullptr;
440 return getAffineConstantExpr(1, expr.getContext());
441 // Dividing both operands by the given symbol.
442 case AffineExprKind::Add: {
443 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
445 expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
446 symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
447 }
448 // Dividing both operands by the given symbol.
449 case AffineExprKind::Mod: {
450 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
452 expr.getKind(),
453 symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
454 symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
455 }
456 // Dividing any of the operand by the given symbol.
457 case AffineExprKind::Mul: {
458 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
459 if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
460 return binaryExpr.getLHS() *
461 symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
462 return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
463 binaryExpr.getRHS();
464 }
465 // Dividing first operand only by the given symbol.
468 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
470 expr.getKind(),
471 symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
472 binaryExpr.getRHS());
473 }
474 }
475 llvm_unreachable("Unknown AffineExpr");
476}
477
478/// Populate `result` with all summand operands of given (potentially nested)
479/// addition. If the given expression is not an addition, just populate the
480/// expression itself.
481/// Example: Add(Add(7, 8), Mul(9, 10)) will return [7, 8, Mul(9, 10)].
482static void getSummandExprs(AffineExpr expr, SmallVector<AffineExpr> &result) {
483 auto addExpr = dyn_cast<AffineBinaryOpExpr>(expr);
484 if (!addExpr || addExpr.getKind() != AffineExprKind::Add) {
485 result.push_back(expr);
486 return;
487 }
488 getSummandExprs(addExpr.getLHS(), result);
489 getSummandExprs(addExpr.getRHS(), result);
490}
491
492/// Return "true" if `candidate` is a negated expression, i.e., Mul(-1, expr).
493/// If so, also return the non-negated expression via `expr`.
494static bool isNegatedAffineExpr(AffineExpr candidate, AffineExpr &expr) {
495 auto mulExpr = dyn_cast<AffineBinaryOpExpr>(candidate);
496 if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
497 return false;
498 if (auto lhs = dyn_cast<AffineConstantExpr>(mulExpr.getLHS())) {
499 if (lhs.getValue() == -1) {
500 expr = mulExpr.getRHS();
501 return true;
502 }
503 }
504 if (auto rhs = dyn_cast<AffineConstantExpr>(mulExpr.getRHS())) {
505 if (rhs.getValue() == -1) {
506 expr = mulExpr.getLHS();
507 return true;
508 }
509 }
510 return false;
511}
512
513/// Return "true" if `lhs` % `rhs` is guaranteed to evaluate to zero based on
514/// the fact that `lhs` contains another modulo expression that ensures that
515/// `lhs` is divisible by `rhs`. This is a common pattern in the resulting IR
516/// after loop peeling.
517///
518/// Example: lhs = ub - ub % step
519/// rhs = step
520/// => (ub - ub % step) % step is guaranteed to evaluate to 0.
521static bool isModOfModSubtraction(AffineExpr lhs, AffineExpr rhs,
522 unsigned numDims, unsigned numSymbols) {
523 // TODO: Try to unify this function with `getBoundForAffineExpr`.
524 // Collect all summands in lhs.
526 getSummandExprs(lhs, summands);
527 // Look for Mul(-1, Mod(x, rhs)) among the summands. If x matches the
528 // remaining summands, then lhs % rhs is guaranteed to evaluate to 0.
529 for (int64_t i = 0, e = summands.size(); i < e; ++i) {
530 AffineExpr current = summands[i];
531 AffineExpr beforeNegation;
532 if (!isNegatedAffineExpr(current, beforeNegation))
533 continue;
534 AffineBinaryOpExpr innerMod = dyn_cast<AffineBinaryOpExpr>(beforeNegation);
535 if (!innerMod || innerMod.getKind() != AffineExprKind::Mod)
536 continue;
537 if (innerMod.getRHS() != rhs)
538 continue;
539 // Sum all remaining summands and subtract x. If that expression can be
540 // simplified to zero, then the remaining summands and x are equal.
541 AffineExpr diff = getAffineConstantExpr(0, lhs.getContext());
542 for (int64_t j = 0; j < e; ++j)
543 if (i != j)
544 diff = diff + summands[j];
545 diff = diff - innerMod.getLHS();
546 diff = simplifyAffineExpr(diff, numDims, numSymbols);
547 auto constExpr = dyn_cast<AffineConstantExpr>(diff);
548 if (constExpr && constExpr.getValue() == 0)
549 return true;
550 }
551 return false;
552}
553
554/// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
555/// operations when the second operand simplifies to a symbol and the first
556/// operand is divisible by that symbol. It can be applied to any semi-affine
557/// expression. Returned expression can either be a semi-affine or pure affine
558/// expression.
559static AffineExpr simplifySemiAffine(AffineExpr expr, unsigned numDims,
560 unsigned numSymbols) {
561 switch (expr.getKind()) {
565 return expr;
567 case AffineExprKind::Mul: {
568 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
570 expr.getKind(),
571 simplifySemiAffine(binaryExpr.getLHS(), numDims, numSymbols),
572 simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols));
573 }
574 // Check if the simplification of the second operand is a symbol, and the
575 // first operand is divisible by it. If the operation is a modulo, a constant
576 // zero expression is returned. In the case of floordiv and ceildiv, the
577 // symbol from the simplification of the second operand divides the first
578 // operand. Otherwise, simplification is not possible.
581 case AffineExprKind::Mod: {
582 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
583 AffineExpr sLHS =
584 simplifySemiAffine(binaryExpr.getLHS(), numDims, numSymbols);
585 AffineExpr sRHS =
586 simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols);
587 if (isModOfModSubtraction(sLHS, sRHS, numDims, numSymbols))
588 return getAffineConstantExpr(0, expr.getContext());
589 AffineSymbolExpr symbolExpr = dyn_cast<AffineSymbolExpr>(
590 simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols));
591 if (!symbolExpr)
592 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
593 unsigned symbolPos = symbolExpr.getPosition();
594 if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
595 expr.getKind()))
596 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
597 if (expr.getKind() == AffineExprKind::Mod)
598 return getAffineConstantExpr(0, expr.getContext());
599 AffineExpr simplifiedQuotient =
600 symbolicDivide(sLHS, symbolPos, expr.getKind());
601 return simplifiedQuotient
602 ? simplifiedQuotient
603 : getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
604 }
605 }
606 llvm_unreachable("Unknown AffineExpr");
607}
608
609static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
610 MLIRContext *context) {
611 auto assignCtx = [context](AffineDimExprStorage *storage) {
612 storage->context = context;
613 };
614
615 StorageUniquer &uniquer = context->getAffineUniquer();
616 return uniquer.get<AffineDimExprStorage>(
617 assignCtx, static_cast<unsigned>(kind), position);
618}
619
620AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
621 return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
622}
623
625 : AffineExpr(ptr) {}
626unsigned AffineSymbolExpr::getPosition() const {
627 return static_cast<ImplType *>(expr)->position;
628}
629
630AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
631 return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
632}
633
635 : AffineExpr(ptr) {}
636int64_t AffineConstantExpr::getValue() const {
637 return static_cast<ImplType *>(expr)->constant;
638}
639
640bool AffineExpr::operator==(int64_t v) const {
641 return *this == getAffineConstantExpr(v, getContext());
642}
643
644AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
645 auto assignCtx = [context](AffineConstantExprStorage *storage) {
646 storage->context = context;
647 };
648
649 StorageUniquer &uniquer = context->getAffineUniquer();
650 return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
651}
652
655 MLIRContext *context) {
656 return llvm::map_to_vector(constants, [&](int64_t constant) {
657 return getAffineConstantExpr(constant, context);
658 });
659}
660
661/// Simplify add expression. Return nullptr if it can't be simplified.
662static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
663 auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
664 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
665 // Fold if both LHS, RHS are a constant and the sum does not overflow.
666 if (lhsConst && rhsConst) {
667 int64_t sum;
668 if (llvm::AddOverflow(lhsConst.getValue(), rhsConst.getValue(), sum)) {
669 return nullptr;
670 }
671 return getAffineConstantExpr(sum, lhs.getContext());
672 }
673
674 // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
675 // If only one of them is a symbolic expressions, make it the RHS.
676 if (isa<AffineConstantExpr>(lhs) ||
677 (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
678 return rhs + lhs;
679 }
680
681 // At this point, if there was a constant, it would be on the right.
682
683 // Addition with a zero is a noop, return the other input.
684 if (rhsConst) {
685 if (rhsConst.getValue() == 0)
686 return lhs;
687 }
688 // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
689 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
690 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
691 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS()))
692 return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
693 }
694
695 // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
696 // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
697 // respective multiplicands.
698 std::optional<int64_t> rLhsConst, rRhsConst;
699 AffineExpr firstExpr, secondExpr;
700 AffineConstantExpr rLhsConstExpr;
701 auto lBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lhs);
702 if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
703 (rLhsConstExpr = dyn_cast<AffineConstantExpr>(lBinOpExpr.getRHS()))) {
704 rLhsConst = rLhsConstExpr.getValue();
705 firstExpr = lBinOpExpr.getLHS();
706 } else {
707 rLhsConst = 1;
708 firstExpr = lhs;
709 }
710
711 auto rBinOpExpr = dyn_cast<AffineBinaryOpExpr>(rhs);
712 AffineConstantExpr rRhsConstExpr;
713 if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
714 (rRhsConstExpr = dyn_cast<AffineConstantExpr>(rBinOpExpr.getRHS()))) {
715 rRhsConst = rRhsConstExpr.getValue();
716 secondExpr = rBinOpExpr.getLHS();
717 } else {
718 rRhsConst = 1;
719 secondExpr = rhs;
720 }
721
722 if (rLhsConst && rRhsConst && firstExpr == secondExpr)
724 AffineExprKind::Mul, firstExpr,
725 getAffineConstantExpr(*rLhsConst + *rRhsConst, lhs.getContext()));
726
727 // When doing successive additions, bring constant to the right: turn (d0 + 2)
728 // + d1 into (d0 + d1) + 2.
729 if (lBin && lBin.getKind() == AffineExprKind::Add) {
730 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
731 return lBin.getLHS() + rhs + lrhs;
732 }
733 }
734
735 // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
736 // q may be a constant or symbolic expression. This leads to a much more
737 // efficient form when 'c' is a power of two, and in general a more compact
738 // and readable form.
739
740 // Process '(expr floordiv c) * (-c)'.
741 if (!rBinOpExpr || rBinOpExpr.getKind() != AffineExprKind::Mul)
742 return nullptr;
743
744 auto lrhs = rBinOpExpr.getLHS();
745 auto rrhs = rBinOpExpr.getRHS();
746
747 AffineExpr llrhs, rlrhs;
748
749 // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q,
750 // where q is a symbolic expression.
751 auto lrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lrhs);
752 // Check rrhsConstOpExpr = -1 as part of ((expr floordiv q) * q)) * (-1).
753 auto rrhsConstOpExpr = dyn_cast<AffineConstantExpr>(rrhs);
754 if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
755 lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
756 // Check llrhs = expr floordiv q.
757 llrhs = lrhsBinOpExpr.getLHS();
758 // Check rlrhs = q.
759 rlrhs = lrhsBinOpExpr.getRHS();
760 auto llrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(llrhs);
761 if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
762 return nullptr;
763 if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
764 return lhs % rlrhs;
765 }
766
767 // Process lrhs, which is 'expr floordiv c'.
768 // expr + (expr // c * -c) = expr % c
769 AffineBinaryOpExpr lrBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lrhs);
770 if (!lrBinOpExpr || rhs.getKind() != AffineExprKind::Mul ||
771 lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
772 return nullptr;
773
774 llrhs = lrBinOpExpr.getLHS();
775 rlrhs = lrBinOpExpr.getRHS();
776 auto rlrhsConstOpExpr = dyn_cast<AffineConstantExpr>(rlrhs);
777 // We don't support modulo with a negative RHS.
778 bool isPositiveRhs = rlrhsConstOpExpr && rlrhsConstOpExpr.getValue() > 0;
779
780 if (isPositiveRhs && lhs == llrhs && rlrhs == -rrhs) {
781 return lhs % rlrhs;
782 }
783
784 // Try simplify lhs's last operand with rhs. e.g:
785 // (s0 * 64 + s1) + (s1 // c * -c) --->
786 // s0 * 64 + (s1 + s1 // c * -c) -->
787 // s0 * 64 + s1 % c
788 if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Add) {
789 if (auto simplified = simplifyAdd(lBinOpExpr.getRHS(), rhs))
790 return lBinOpExpr.getLHS() + simplified;
791 }
792 return nullptr;
793}
794
795/// Get the canonical order of two commutative exprs arguments.
796static std::pair<AffineExpr, AffineExpr>
797orderCommutativeArgs(AffineExpr expr1, AffineExpr expr2) {
798 auto sym1 = dyn_cast<AffineSymbolExpr>(expr1);
799 auto sym2 = dyn_cast<AffineSymbolExpr>(expr2);
800 // Try to order by symbol/dim position first.
801 if (sym1 && sym2)
802 return sym1.getPosition() < sym2.getPosition() ? std::pair{expr1, expr2}
803 : std::pair{expr2, expr1};
804
805 auto dim1 = dyn_cast<AffineDimExpr>(expr1);
806 auto dim2 = dyn_cast<AffineDimExpr>(expr2);
807 if (dim1 && dim2)
808 return dim1.getPosition() < dim2.getPosition() ? std::pair{expr1, expr2}
809 : std::pair{expr2, expr1};
810
811 // Put dims before symbols.
812 if (dim1 && sym2)
813 return {dim1, sym2};
814
815 if (sym1 && dim2)
816 return {dim2, sym1};
817
818 // Otherwise, keep original order.
819 return {expr1, expr2};
820}
821
822AffineExpr AffineExpr::operator+(int64_t v) const {
823 return *this + getAffineConstantExpr(v, getContext());
824}
826 if (auto simplified = simplifyAdd(*this, other))
827 return simplified;
828
829 auto [lhs, rhs] = orderCommutativeArgs(*this, other);
830
831 StorageUniquer &uniquer = getContext()->getAffineUniquer();
832 return uniquer.get<AffineBinaryOpExprStorage>(
833 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), lhs, rhs);
834}
835
836/// Simplify a multiply expression. Return nullptr if it can't be simplified.
838 auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
839 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
840
841 if (lhsConst && rhsConst) {
843 if (llvm::MulOverflow(lhsConst.getValue(), rhsConst.getValue(), product)) {
844 return nullptr;
845 }
846 return getAffineConstantExpr(product, lhs.getContext());
847 }
848
849 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())
850 return nullptr;
851
852 // Canonicalize the mul expression so that the constant/symbolic term is the
853 // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
854 // constant. (Note that a constant is trivially symbolic).
855 if (!rhs.isSymbolicOrConstant() || isa<AffineConstantExpr>(lhs)) {
856 // At least one of them has to be symbolic.
857 return rhs * lhs;
858 }
859
860 // At this point, if there was a constant, it would be on the right.
861
862 // Multiplication with a one is a noop, return the other input.
863 if (rhsConst) {
864 if (rhsConst.getValue() == 1)
865 return lhs;
866 // Multiplication with zero.
867 if (rhsConst.getValue() == 0)
868 return rhsConst;
869 }
870
871 // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
872 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
873 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
874 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS()))
875 return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
876 }
877
878 // When doing successive multiplication, bring constant to the right: turn (d0
879 // * 2) * d1 into (d0 * d1) * 2.
880 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
881 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
882 return (lBin.getLHS() * rhs) * lrhs;
883 }
884 }
885
886 return nullptr;
887}
888
893 if (auto simplified = simplifyMul(*this, other))
894 return simplified;
895
896 auto [lhs, rhs] = orderCommutativeArgs(*this, other);
897
899 return uniquer.get<AffineBinaryOpExprStorage>(
900 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), lhs, rhs);
901}
902
903// Unary minus, delegate to operator*.
905 return *this * getAffineConstantExpr(-1, getContext());
906}
907
908// Delegate to operator+.
909AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
911 return *this + (-other);
912}
913
915 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
916
917 // For the defined cases, simplify x floordiv x is 1.
918 if (lhs == rhs && (!rhsConst || rhsConst.getValue() >= 1))
919 return getAffineConstantExpr(1, lhs.getContext());
920
921 // All other simplifications further below are for the RHS constant case.
922 if (!rhsConst || rhsConst.getValue() == 0)
923 return nullptr;
924
925 if (auto lhsConst = dyn_cast<AffineConstantExpr>(lhs)) {
926 if (divideSignedWouldOverflow(lhsConst.getValue(), rhsConst.getValue()))
927 return nullptr;
929 divideFloorSigned(lhsConst.getValue(), rhsConst.getValue()),
930 lhs.getContext());
931 }
932
933 // Fold floordiv of a multiply with a constant that is a multiple of the
934 // divisor. Eg: (i * 128) floordiv 64 = i * 2.
935 if (rhsConst == 1)
936 return lhs;
937
938 // Simplify `(expr * lrhs) floordiv rhsConst` when `lrhs` is known to be a
939 // multiple of `rhsConst`.
940 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
941 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
942 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
943 // `rhsConst` is known to be a nonzero constant.
944 if (lrhs.getValue() % rhsConst.getValue() == 0)
945 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
946 }
947 }
948
949 // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
950 // known to be a multiple of divConst.
951 if (lBin && lBin.getKind() == AffineExprKind::Add) {
952 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
953 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
954 // rhsConst is known to be a nonzero constant.
955 if (llhsDiv % rhsConst.getValue() == 0 ||
956 lrhsDiv % rhsConst.getValue() == 0)
957 return lBin.getLHS().floorDiv(rhsConst.getValue()) +
958 lBin.getRHS().floorDiv(rhsConst.getValue());
959 }
960
961 return nullptr;
962}
963
966}
968 if (auto simplified = simplifyFloorDiv(*this, other))
969 return simplified;
970
972 return uniquer.get<AffineBinaryOpExprStorage>(
973 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
974 other);
975}
976
978 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
979
980 // For the defined cases, simplify x ceildiv x is 1.
981 if (lhs == rhs && (!rhsConst || rhsConst.getValue() >= 1))
982 return getAffineConstantExpr(1, lhs.getContext());
983
984 // All other simplifications further below are for the RHS constant case.
985 if (!rhsConst || rhsConst.getValue() == 0)
986 return nullptr;
987
988 if (auto lhsConst = dyn_cast<AffineConstantExpr>(lhs)) {
989 if (divideSignedWouldOverflow(lhsConst.getValue(), rhsConst.getValue()))
990 return nullptr;
992 divideCeilSigned(lhsConst.getValue(), rhsConst.getValue()),
993 lhs.getContext());
994 }
995
996 // Fold ceildiv of a multiply with a constant that is a multiple of the
997 // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
998 if (rhsConst.getValue() == 1)
999 return lhs;
1000
1001 // Simplify `(expr * lrhs) ceildiv rhsConst` when `lrhs` is known to be a
1002 // multiple of `rhsConst`.
1003 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
1004 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
1005 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
1006 // `rhsConst` is known to be a nonzero constant.
1007 if (lrhs.getValue() % rhsConst.getValue() == 0)
1008 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
1009 }
1010 }
1011
1012 return nullptr;
1013}
1014
1017}
1019 if (auto simplified = simplifyCeilDiv(*this, other))
1020 return simplified;
1021
1023 return uniquer.get<AffineBinaryOpExprStorage>(
1024 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
1025 other);
1026}
1027
1029 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
1030
1031 // For the defined cases, simplify x % x to 0.
1032 if (lhs == rhs && (!rhsConst || rhsConst.getValue() >= 1))
1033 return getAffineConstantExpr(0, lhs.getContext());
1034
1035 // mod w.r.t zero or negative numbers is undefined and preserved as is.
1036 // All other simplifications further below are for the RHS constant case.
1037 if (!rhsConst || rhsConst.getValue() < 1)
1038 return nullptr;
1039
1040 if (auto lhsConst = dyn_cast<AffineConstantExpr>(lhs)) {
1041 // mod never overflows.
1042 return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
1043 lhs.getContext());
1044 }
1045
1046 // Fold modulo of an expression that is known to be a multiple of a constant
1047 // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
1048 // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
1049 if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
1050 return getAffineConstantExpr(0, lhs.getContext());
1051
1052 // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
1053 // known to be a multiple of divConst.
1054 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
1055 if (lBin && lBin.getKind() == AffineExprKind::Add) {
1056 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
1057 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
1058 // rhsConst is known to be a positive constant.
1059 if (llhsDiv % rhsConst.getValue() == 0)
1060 return lBin.getRHS() % rhsConst.getValue();
1061 if (lrhsDiv % rhsConst.getValue() == 0)
1062 return lBin.getLHS() % rhsConst.getValue();
1063 }
1064
1065 // Simplify (e % a) % b to e % b when b evenly divides a
1066 if (lBin && lBin.getKind() == AffineExprKind::Mod) {
1067 auto intermediate = dyn_cast<AffineConstantExpr>(lBin.getRHS());
1068 if (intermediate && intermediate.getValue() >= 1 &&
1069 mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
1070 return lBin.getLHS() % rhsConst.getValue();
1071 }
1072 }
1073
1074 return nullptr;
1075}
1076
1078 return *this % getAffineConstantExpr(v, getContext());
1079}
1081 if (auto simplified = simplifyMod(*this, other))
1082 return simplified;
1083
1085 return uniquer.get<AffineBinaryOpExprStorage>(
1086 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
1087}
1088
1090 SmallVector<AffineExpr, 8> dimReplacements(map.getResults());
1091 return replaceDimsAndSymbols(dimReplacements, {});
1092}
1094 expr.print(os);
1095 return os;
1096}
1097
1098/// Constructs an affine expression from a flat ArrayRef. If there are local
1099/// identifiers (neither dimensional nor symbolic) that appear in the sum of
1100/// products expression, `localExprs` is expected to have the AffineExpr
1101/// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
1102/// in the format [dims, symbols, locals, constant term].
1104 unsigned numDims,
1105 unsigned numSymbols,
1106 ArrayRef<AffineExpr> localExprs,
1107 MLIRContext *context) {
1108 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1109 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1110 "unexpected number of local expressions");
1111
1112 auto expr = getAffineConstantExpr(0, context);
1113 // Dimensions and symbols.
1114 for (unsigned j = 0; j < numDims + numSymbols; j++) {
1115 if (flatExprs[j] == 0)
1116 continue;
1117 auto id = j < numDims ? getAffineDimExpr(j, context)
1118 : getAffineSymbolExpr(j - numDims, context);
1119 expr = expr + id * flatExprs[j];
1120 }
1121
1122 // Local identifiers.
1123 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1124 j++) {
1125 if (flatExprs[j] == 0)
1126 continue;
1127 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1128 expr = expr + term;
1129 }
1130
1131 // Constant term.
1132 int64_t constTerm = flatExprs[flatExprs.size() - 1];
1133 if (constTerm != 0)
1134 expr = expr + constTerm;
1135 return expr;
1136}
1137
1138/// Constructs a semi-affine expression from a flat ArrayRef. If there are
1139/// local identifiers (neither dimensional nor symbolic) that appear in the sum
1140/// of products expression, `localExprs` is expected to have the AffineExprs for
1141/// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in
1142/// the format [dims, symbols, locals, constant term]. The semi-affine
1143/// expression is constructed in the sorted order of dimension and symbol
1144/// position numbers. Note: local expressions/ids are used for mod, div as well
1145/// as symbolic RHS terms for terms that are not pure affine.
1147 unsigned numDims,
1148 unsigned numSymbols,
1149 ArrayRef<AffineExpr> localExprs,
1150 MLIRContext *context) {
1151 assert(!flatExprs.empty() && "flatExprs cannot be empty");
1152
1153 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1154 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1155 "unexpected number of local expressions");
1156
1157 AffineExpr expr = getAffineConstantExpr(0, context);
1158
1159 // We design indices as a pair which help us present the semi-affine map as
1160 // sum of product where terms are sorted based on dimension or symbol
1161 // position: <keyA, keyB> for expressions of the form dimension * symbol,
1162 // where keyA is the position number of the dimension and keyB is the
1163 // position number of the symbol. For dimensional expressions we set the index
1164 // as (position number of the dimension, -1), as we want dimensional
1165 // expressions to appear before symbolic and product of dimensional and
1166 // symbolic expressions having the dimension with the same position number.
1167 // For symbolic expression set the index as (position number of the symbol,
1168 // maximum of last dimension and symbol position) number. For example, we want
1169 // the expression we are constructing to look something like: d0 + d0 * s0 +
1170 // s0 + d1*s1 + s1.
1171
1172 // Stores the affine expression corresponding to a given index.
1174 // Stores the constant coefficient value corresponding to a given
1175 // dimension, symbol or a non-pure affine expression stored in `localExprs`.
1177 // Stores the indices as defined above, and later sorted to produce
1178 // the semi-affine expression in the desired form.
1180
1181 // Example: expression = d0 + d0 * s0 + 2 * s0.
1182 // indices = [{0,-1}, {0, 0}, {0, 1}]
1183 // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
1184 // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
1185
1186 // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
1187 auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
1188 AffineExpr expr) {
1189 assert(!llvm::is_contained(indices, index) &&
1190 "Key is already present in indices vector and overwriting will "
1191 "happen in `indexToExprMap` and `coefficients`!");
1192
1193 indices.push_back(index);
1194 coefficients.insert({index, coefficient});
1195 indexToExprMap.insert({index, expr});
1196 };
1197
1198 // Design indices for dimensional or symbolic terms, and store the indices,
1199 // constant coefficient corresponding to the indices in `coefficients` map,
1200 // and affine expression corresponding to indices in `indexToExprMap` map.
1201
1202 // Ensure we do not have duplicate keys in `indexToExpr` map.
1203 unsigned offsetSym = 0;
1204 signed offsetDim = -1;
1205 for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
1206 if (flatExprs[j] == 0)
1207 continue;
1208 // For symbolic expression set the index as <position number
1209 // of the symbol, max(dimCount, symCount)> number,
1210 // as we want symbolic expressions with the same positional number to
1211 // appear after dimensional expressions having the same positional number.
1212 std::pair<unsigned, signed> indexEntry(
1213 j - numDims, std::max(numDims, numSymbols) + offsetSym++);
1214 addEntry(indexEntry, flatExprs[j],
1215 getAffineSymbolExpr(j - numDims, context));
1216 }
1217
1218 // Denotes semi-affine product, modulo or division terms, which has been added
1219 // to the `indexToExpr` map.
1220 SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1,
1221 false);
1222 unsigned lhsPos, rhsPos;
1223 // Construct indices for product terms involving dimension, symbol or constant
1224 // as lhs/rhs, and store the indices, constant coefficient corresponding to
1225 // the indices in `coefficients` map, and affine expression corresponding to
1226 // in indices in `indexToExprMap` map.
1227 for (const auto &it : llvm::enumerate(localExprs)) {
1228 if (flatExprs[numDims + numSymbols + it.index()] == 0)
1229 continue;
1230 AffineExpr expr = it.value();
1231 auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
1232 if (!binaryExpr)
1233 continue;
1234
1235 AffineExpr lhs = binaryExpr.getLHS();
1236 AffineExpr rhs = binaryExpr.getRHS();
1237 if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
1238 (isa<AffineDimExpr>(rhs) || isa<AffineSymbolExpr>(rhs) ||
1239 isa<AffineConstantExpr>(rhs)))) {
1240 continue;
1241 }
1242 if (isa<AffineConstantExpr>(rhs)) {
1243 // For product/modulo/division expressions, when rhs of modulo/division
1244 // expression is constant, we put 0 in place of keyB, because we want
1245 // them to appear earlier in the semi-affine expression we are
1246 // constructing. When rhs is constant, we place 0 in place of keyB.
1247 if (isa<AffineDimExpr>(lhs)) {
1248 lhsPos = cast<AffineDimExpr>(lhs).getPosition();
1249 std::pair<unsigned, signed> indexEntry(lhsPos, offsetDim--);
1250 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1251 expr);
1252 } else {
1253 lhsPos = cast<AffineSymbolExpr>(lhs).getPosition();
1254 std::pair<unsigned, signed> indexEntry(
1255 lhsPos, std::max(numDims, numSymbols) + offsetSym++);
1256 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1257 expr);
1258 }
1259 } else if (isa<AffineDimExpr>(lhs)) {
1260 // For product/modulo/division expressions having lhs as dimension and rhs
1261 // as symbol, we order the terms in the semi-affine expression based on
1262 // the pair: <keyA, keyB> for expressions of the form dimension * symbol,
1263 // where keyA is the position number of the dimension and keyB is the
1264 // position number of the symbol.
1265 lhsPos = cast<AffineDimExpr>(lhs).getPosition();
1266 rhsPos = cast<AffineSymbolExpr>(rhs).getPosition();
1267 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1268 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1269 } else {
1270 // For product/modulo/division expressions having both lhs and rhs as
1271 // symbol, we design indices as a pair: <keyA, keyB> for expressions
1272 // of the form dimension * symbol, where keyA is the position number of
1273 // the dimension and keyB is the position number of the symbol.
1274 lhsPos = cast<AffineSymbolExpr>(lhs).getPosition();
1275 rhsPos = cast<AffineSymbolExpr>(rhs).getPosition();
1276 std::pair<unsigned, signed> indexEntry(
1277 lhsPos, std::max(numDims, numSymbols) + offsetSym++);
1278 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1279 }
1280 addedToMap[it.index()] = true;
1281 }
1282
1283 for (unsigned j = 0; j < numDims; ++j) {
1284 if (flatExprs[j] == 0)
1285 continue;
1286 // For dimensional expressions we set the index as <position number of the
1287 // dimension, 0>, as we want dimensional expressions to appear before
1288 // symbolic ones and products of dimensional and symbolic expressions
1289 // having the dimension with the same position number.
1290 std::pair<unsigned, signed> indexEntry(j, offsetDim--);
1291 addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context));
1292 }
1293
1294 // Constructing the simplified semi-affine sum of product/division/mod
1295 // expression from the flattened form in the desired sorted order of indices
1296 // of the various individual product/division/mod expressions.
1297 llvm::sort(indices);
1298 for (const std::pair<unsigned, unsigned> index : indices) {
1299 assert(indexToExprMap.lookup(index) &&
1300 "cannot find key in `indexToExprMap` map");
1301 expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index);
1302 }
1303
1304 // Local identifiers.
1305 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1306 j++) {
1307 // If the coefficient of the local expression is 0, continue as we need not
1308 // add it in out final expression.
1309 if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols])
1310 continue;
1311 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1312 expr = expr + term;
1313 }
1314
1315 // Constant term.
1316 int64_t constTerm = flatExprs.back();
1317 if (constTerm != 0)
1318 expr = expr + constTerm;
1319 return expr;
1320}
1321
1327
1328// In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1329//
1330// In case of semi affine multiplication expressions, t = expr * symbolic_expr,
1331// introduce a local variable p (= expr * symbolic_expr), and the affine
1332// expression expr * symbolic_expr is added to `localExprs`.
1334 assert(operandExprStack.size() >= 2);
1336 operandExprStack.pop_back();
1338
1339 // Flatten semi-affine multiplication expressions by introducing a local
1340 // variable in place of the product; the affine expression
1341 // corresponding to the quantifier is added to `localExprs`.
1342 if (!isa<AffineConstantExpr>(expr.getRHS())) {
1344 MLIRContext *context = expr.getContext();
1346 localExprs, context);
1348 localExprs, context);
1349 return addLocalVariableSemiAffine(mulLhs, rhs, a * b, lhs, lhs.size());
1350 }
1351
1352 // Get the RHS constant.
1353 int64_t rhsConst = rhs[getConstantIndex()];
1354 for (int64_t &lhsElt : lhs)
1355 lhsElt *= rhsConst;
1356
1357 return success();
1358}
1359
1361 assert(operandExprStack.size() >= 2);
1362 const auto &rhs = operandExprStack.back();
1363 auto &lhs = operandExprStack[operandExprStack.size() - 2];
1364 assert(lhs.size() == rhs.size());
1365 // Update the LHS in place.
1366 for (unsigned i = 0, e = rhs.size(); i < e; i++) {
1367 lhs[i] += rhs[i];
1368 }
1369 // Pop off the RHS.
1370 operandExprStack.pop_back();
1371 return success();
1372}
1373
1374//
1375// t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
1376//
1377// A mod expression "expr mod c" is thus flattened by introducing a new local
1378// variable q (= expr floordiv c), such that expr mod c is replaced with
1379// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
1380//
1381// In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
1382// introduce a local variable m (= expr mod symbolic_expr), and the affine
1383// expression expr mod symbolic_expr is added to `localExprs`.
1385 assert(operandExprStack.size() >= 2);
1386
1388 operandExprStack.pop_back();
1390 MLIRContext *context = expr.getContext();
1391
1392 // Flatten semi affine modulo expressions by introducing a local
1393 // variable in place of the modulo value, and the affine expression
1394 // corresponding to the quantifier is added to `localExprs`.
1395 if (!isa<AffineConstantExpr>(expr.getRHS())) {
1398 lhs, numDims, numSymbols, localExprs, context);
1400 localExprs, context);
1401 AffineExpr modExpr = dividendExpr % divisorExpr;
1402 return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
1403 }
1404
1405 int64_t rhsConst = rhs[getConstantIndex()];
1406 if (rhsConst <= 0)
1407 return failure();
1408
1409 // Check if the LHS expression is a multiple of modulo factor.
1410 unsigned i, e;
1411 for (i = 0, e = lhs.size(); i < e; i++)
1412 if (lhs[i] % rhsConst != 0)
1413 break;
1414 // If yes, modulo expression here simplifies to zero.
1415 if (i == lhs.size()) {
1416 llvm::fill(lhs, 0);
1417 return success();
1418 }
1419
1420 // Add a local variable for the quotient, i.e., expr % c is replaced by
1421 // (expr - q * c) where q = expr floordiv c. Do this while canceling out
1422 // the GCD of expr and c.
1423 SmallVector<int64_t, 8> floorDividend(lhs);
1424 uint64_t gcd = rhsConst;
1425 for (int64_t lhsElt : lhs)
1426 gcd = std::gcd(gcd, (uint64_t)std::abs(lhsElt));
1427 // Simplify the numerator and the denominator.
1428 if (gcd != 1) {
1429 for (int64_t &floorDividendElt : floorDividend)
1430 floorDividendElt = floorDividendElt / static_cast<int64_t>(gcd);
1431 }
1432 int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
1433
1434 // Construct the AffineExpr form of the floordiv to store in localExprs.
1435
1437 floorDividend, numDims, numSymbols, localExprs, context);
1438 AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context);
1439 AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr);
1440 int loc;
1441 if ((loc = findLocalId(floorDivExpr)) == -1) {
1442 addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
1443 // Set result at top of stack to "lhs - rhsConst * q".
1444 lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
1445 } else {
1446 // Reuse the existing local id.
1447 lhs[getLocalVarStartIndex() + loc] -= rhsConst;
1448 }
1449 return success();
1450}
1451
1452LogicalResult
1454 return visitDivExpr(expr, /*isCeil=*/true);
1455}
1456LogicalResult
1458 return visitDivExpr(expr, /*isCeil=*/false);
1459}
1460
1462 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1463 auto &eq = operandExprStack.back();
1464 assert(expr.getPosition() < numDims && "Inconsistent number of dims");
1465 eq[getDimStartIndex() + expr.getPosition()] = 1;
1466 return success();
1467}
1468
1469LogicalResult
1471 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1472 auto &eq = operandExprStack.back();
1473 assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1474 eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1475 return success();
1476}
1477
1478LogicalResult
1480 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1481 auto &eq = operandExprStack.back();
1482 eq[getConstantIndex()] = expr.getValue();
1483 return success();
1484}
1485
1486LogicalResult SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1488 SmallVectorImpl<int64_t> &result, unsigned long resultSize) {
1489 assert(result.size() == resultSize &&
1490 "`result` vector passed is not of correct size");
1491 int loc;
1492 if ((loc = findLocalId(localExpr)) == -1) {
1493 if (failed(addLocalIdSemiAffine(lhs, rhs, localExpr)))
1494 return failure();
1495 }
1496 llvm::fill(result, 0);
1497 if (loc == -1)
1498 result[getLocalVarStartIndex() + numLocals - 1] = 1;
1499 else
1500 result[getLocalVarStartIndex() + loc] = 1;
1501 return success();
1502}
1503
1504// t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
1505// A floordiv is thus flattened by introducing a new local variable q, and
1506// replacing that expression with 'q' while adding the constraints
1507// c * q <= expr <= c * q + c - 1 to localVarCst (done by
1508// IntegerRelation::addLocalFloorDiv).
1509//
1510// A ceildiv is similarly flattened:
1511// t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
1512//
1513// In case of semi affine division expressions, t = expr floordiv symbolic_expr
1514// or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
1515// floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
1516// `localExprs`.
1517LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1518 bool isCeil) {
1519 assert(operandExprStack.size() >= 2);
1520
1521 MLIRContext *context = expr.getContext();
1522 SmallVector<int64_t, 8> rhs = operandExprStack.back();
1523 operandExprStack.pop_back();
1524 SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1525
1526 // Flatten semi affine division expressions by introducing a local
1527 // variable in place of the quotient, and the affine expression corresponding
1528 // to the quantifier is added to `localExprs`.
1529 if (!isa<AffineConstantExpr>(expr.getRHS())) {
1530 SmallVector<int64_t, 8> divLhs(lhs);
1532 localExprs, context);
1534 localExprs, context);
1535 AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1536 return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
1537 }
1538
1539 // This is a pure affine expr; the RHS is a positive constant.
1540 int64_t rhsConst = rhs[getConstantIndex()];
1541 if (rhsConst <= 0)
1542 return failure();
1543
1544 // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1545 // common divisors of the numerator and denominator.
1546 uint64_t gcd = std::abs(rhsConst);
1547 for (int64_t lhsElt : lhs)
1548 gcd = std::gcd(gcd, (uint64_t)std::abs(lhsElt));
1549 // Simplify the numerator and the denominator.
1550 if (gcd != 1) {
1551 for (int64_t &lhsElt : lhs)
1552 lhsElt = lhsElt / static_cast<int64_t>(gcd);
1553 }
1554 int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1555 // If the divisor becomes 1, the updated LHS is the result. (The
1556 // divisor can't be negative since rhsConst is positive).
1557 if (divisor == 1)
1558 return success();
1559
1560 // If the divisor cannot be simplified to one, we will have to retain
1561 // the ceil/floor expr (simplified up until here). Add an existential
1562 // quantifier to express its result, i.e., expr1 div expr2 is replaced
1563 // by a new identifier, q.
1564 AffineExpr a =
1566 AffineExpr b = getAffineConstantExpr(divisor, context);
1567
1568 int loc;
1569 AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1570 if ((loc = findLocalId(divExpr)) == -1) {
1571 if (!isCeil) {
1572 SmallVector<int64_t, 8> dividend(lhs);
1573 addLocalFloorDivId(dividend, divisor, divExpr);
1574 } else {
1575 // lhs ceildiv c <=> (lhs + c - 1) floordiv c
1576 SmallVector<int64_t, 8> dividend(lhs);
1577 dividend.back() += divisor - 1;
1578 addLocalFloorDivId(dividend, divisor, divExpr);
1579 }
1580 }
1581 // Set the expression on stack to the local var introduced to capture the
1582 // result of the division (floor or ceil).
1583 llvm::fill(lhs, 0);
1584 if (loc == -1)
1585 lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1586 else
1587 lhs[getLocalVarStartIndex() + loc] = 1;
1588 return success();
1589}
1590
1591// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1592// The local identifier added is always a floordiv of a pure add/mul affine
1593// function of other identifiers, coefficients of which are specified in
1594// dividend and with respect to a positive constant divisor. localExpr is the
1595// simplified tree expression (AffineExpr) corresponding to the quantifier.
1597 int64_t divisor,
1598 AffineExpr localExpr) {
1599 assert(divisor > 0 && "positive constant divisor expected");
1601 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1602 localExprs.push_back(localExpr);
1603 numLocals++;
1604 // dividend and divisor are not used here; an override of this method uses it.
1605}
1606
1610 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1611 localExprs.push_back(localExpr);
1612 ++numLocals;
1613 // lhs and rhs are not used here; an override of this method uses them.
1614 return success();
1615}
1616
1617int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1619 if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
1620 return -1;
1621 return it - localExprs.begin();
1622}
1623
1624/// Simplify the affine expression by flattening it and reconstructing it.
1626 unsigned numSymbols) {
1627 // Simplify semi-affine expressions separately.
1628 if (!expr.isPureAffine())
1629 expr = simplifySemiAffine(expr, numDims, numSymbols);
1630
1631 SimpleAffineExprFlattener flattener(numDims, numSymbols);
1632 // has poison expression
1633 if (failed(flattener.walkPostOrder(expr)))
1634 return expr;
1635 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1636 if (!expr.isPureAffine() &&
1637 expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1638 flattener.localExprs,
1639 expr.getContext()))
1640 return expr;
1641 AffineExpr simplifiedExpr =
1642 expr.isPureAffine()
1643 ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1644 flattener.localExprs, expr.getContext())
1645 : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1646 flattener.localExprs,
1647 expr.getContext());
1648
1649 flattener.operandExprStack.pop_back();
1650 assert(flattener.operandExprStack.empty());
1651 return simplifiedExpr;
1652}
1653
1654std::optional<int64_t> mlir::getBoundForAffineExpr(
1655 AffineExpr expr, unsigned numDims, unsigned numSymbols,
1656 ArrayRef<std::optional<int64_t>> constLowerBounds,
1657 ArrayRef<std::optional<int64_t>> constUpperBounds, bool isUpper) {
1658 // Handle divs and mods.
1659 if (auto binOpExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
1660 // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
1661 // can compute an upper bound.
1662 if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
1663 auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
1664 if (!rhsConst || rhsConst.getValue() < 1)
1665 return std::nullopt;
1666 auto bound =
1667 getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1668 constLowerBounds, constUpperBounds, isUpper);
1669 if (!bound)
1670 return std::nullopt;
1671 return divideFloorSigned(*bound, rhsConst.getValue());
1672 }
1673 if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
1674 auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
1675 if (rhsConst && rhsConst.getValue() >= 1) {
1676 auto bound =
1677 getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1678 constLowerBounds, constUpperBounds, isUpper);
1679 if (!bound)
1680 return std::nullopt;
1681 return divideCeilSigned(*bound, rhsConst.getValue());
1682 }
1683 return std::nullopt;
1684 }
1685 if (binOpExpr.getKind() == AffineExprKind::Mod) {
1686 // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
1687 // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
1688 // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
1689 auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
1690 if (rhsConst && rhsConst.getValue() >= 1) {
1691 int64_t rhsConstVal = rhsConst.getValue();
1692 auto lb = getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1693 constLowerBounds, constUpperBounds,
1694 /*isUpper=*/false);
1695 auto ub =
1696 getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1697 constLowerBounds, constUpperBounds, isUpper);
1698 if (ub && lb &&
1699 divideFloorSigned(*lb, rhsConstVal) ==
1700 divideFloorSigned(*ub, rhsConstVal))
1701 return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
1702 return isUpper ? rhsConstVal - 1 : 0;
1703 }
1704 }
1705 }
1706 // Flatten the expression.
1707 SimpleAffineExprFlattener flattener(numDims, numSymbols);
1708 auto simpleResult = flattener.walkPostOrder(expr);
1709 // has poison expression
1710 if (failed(simpleResult))
1711 return std::nullopt;
1712 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1713 // TODO: Handle local variables. We can get hold of flattener.localExprs and
1714 // get bound on the local expr recursively.
1715 if (flattener.numLocals > 0)
1716 return std::nullopt;
1717 int64_t bound = 0;
1718 // Substitute the constant lower or upper bound for the dimensional or
1719 // symbolic input depending on `isUpper` to determine the bound.
1720 for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
1721 if (flattenedExpr[i] > 0) {
1722 auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
1723 if (!constBound)
1724 return std::nullopt;
1725 bound += *constBound * flattenedExpr[i];
1726 } else if (flattenedExpr[i] < 0) {
1727 auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
1728 if (!constBound)
1729 return std::nullopt;
1730 bound += *constBound * flattenedExpr[i];
1731 }
1732 }
1733 // Constant term.
1734 bound += flattenedExpr.back();
1735 return bound;
1736}
return success()
static int64_t product(ArrayRef< int64_t > vals)
lhs
static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs)
Simplify a multiply expression. Return nullptr if it can't be simplified.
static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs)
static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs a semi-affine expression from a flat ArrayRef. If there are local identifiers (neither dim...
static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs)
static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
If copies could not be generated due to yet-unimplemented cases, … `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming copies and outgoing copies should be inserted. The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`…
Affine binary operation expression.
Definition AffineExpr.h:214
AffineExpr getLHS() const
AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
detail::AffineBinaryOpExprStorage ImplType
Definition AffineExpr.h:216
AffineExpr getRHS() const
An integer constant appearing in affine expression.
Definition AffineExpr.h:239
AffineConstantExpr(AffineExpr::ImplType *ptr=nullptr)
detail::AffineConstantExprStorage ImplType
Definition AffineExpr.h:241
int64_t getValue() const
A dimensional identifier appearing in an affine expression.
Definition AffineExpr.h:223
AffineDimExpr(AffineExpr::ImplType *ptr)
detail::AffineDimExprStorage ImplType
Definition AffineExpr.h:225
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition AffineExpr.h:68
AffineExpr replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements) const
This method substitutes any uses of dimensions and symbols (e.g.
AffineExpr shiftDims(unsigned numDims, unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
bool isSymbolicOrConstant() const
Returns true if this expression is made out of only symbols and constants, i.e., it does not involve ...
AffineExpr operator+(int64_t v) const
AffineExpr operator*(int64_t v) const
constexpr AffineExpr()
Definition AffineExpr.h:72
bool operator==(AffineExpr other) const
Definition AffineExpr.h:76
bool isPureAffine() const
Returns true if this is a pure affine expression, i.e., multiplication, floordiv, ceildiv,...
AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift, unsigned offset=0) const
Replace symbols[offset ... numSymbols) by symbols[offset + shift ... shift + numSymbols).
AffineExpr operator-() const
AffineExpr floorDiv(uint64_t v) const
ImplType * expr
Definition AffineExpr.h:196
RetT walk(FnT &&callback) const
Walk all of the AffineExpr's in this expression in postorder.
Definition AffineExpr.h:117
AffineExprKind getKind() const
Return the classification for this type.
bool isMultipleOf(int64_t factor) const
Return true if the affine expression is a multiple of 'factor'.
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
AffineExpr compose(AffineMap map) const
Compose with an AffineMap.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
bool isFunctionOfSymbol(unsigned position) const
Return true if the affine expression involves AffineSymbolExpr position.
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
AffineExpr operator%(uint64_t v) const
MLIRContext * getContext() const
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
AffineExpr replaceSymbols(ArrayRef< AffineExpr > symReplacements) const
Symbol-only version of replaceDimsAndSymbols.
detail::AffineExprStorage ImplType
Definition AffineExpr.h:70
AffineExpr ceilDiv(uint64_t v) const
void print(raw_ostream &os) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
ArrayRef< AffineExpr > getResults() const
A symbolic identifier appearing in an affine expression.
Definition AffineExpr.h:231
unsigned getPosition() const
detail::AffineDimExprStorage ImplType
Definition AffineExpr.h:233
AffineSymbolExpr(AffineExpr::ImplType *ptr)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
StorageUniquer & getAffineUniquer()
Returns the storage uniquer used for creating affine constructs.
virtual void addLocalFloorDivId(ArrayRef< int64_t > dividend, int64_t divisor, AffineExpr localExpr)
LogicalResult visitSymbolExpr(AffineSymbolExpr expr)
std::vector< SmallVector< int64_t, 8 > > operandExprStack
LogicalResult visitDimExpr(AffineDimExpr expr)
LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr)
LogicalResult visitConstantExpr(AffineConstantExpr expr)
virtual LogicalResult addLocalIdSemiAffine(ArrayRef< int64_t > lhs, ArrayRef< int64_t > rhs, AffineExpr localExpr)
Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul expr) when the rhs is a symbo...
LogicalResult visitModExpr(AffineBinaryOpExpr expr)
LogicalResult visitAddExpr(AffineBinaryOpExpr expr)
LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr)
LogicalResult visitMulExpr(AffineBinaryOpExpr expr)
SmallVector< AffineExpr, 4 > localExprs
SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols)
A utility class to get or create instances of "storage classes".
Storage * get(function_ref< void(Storage *)> initFn, TypeID id, Args &&...args)
Gets a uniqued instance of 'Storage'.
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
AttrTypeReplacer.
Include the generated interface declarations.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t > > constLowerBounds, ArrayRef< std::optional< int64_t > > constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
AffineExprKind
Definition AffineExpr.h:40
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
Definition AffineExpr.h:50
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
Definition AffineExpr.h:46
@ DimId
Dimensional identifier.
Definition AffineExpr.h:59
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
Definition AffineExpr.h:48
@ Constant
Constant integer.
Definition AffineExpr.h:57
@ SymbolId
Symbolic identifier.
Definition AffineExpr.h:61
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
AffineExpr getAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs an affine expression from a flat ArrayRef.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
SmallVector< AffineExpr > getAffineConstantExprs(ArrayRef< int64_t > constants, MLIRContext *context)
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
A binary operation appearing in an affine expression.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.