MLIR 23.0.0git
AffineExpr.cpp
Go to the documentation of this file.
1//===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cmath>
10#include <cstdint>
11#include <utility>
12
13#include "AffineExprDetail.h"
14#include "mlir/IR/AffineExpr.h"
16#include "mlir/IR/AffineMap.h"
17#include "mlir/IR/IntegerSet.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/SmallVectorExtras.h"
20#include "llvm/Support/MathExtras.h"
21#include <numeric>
22#include <optional>
23
24using namespace mlir;
25using namespace mlir::detail;
26
27using llvm::divideCeilSigned;
28using llvm::divideFloorSigned;
29using llvm::divideSignedWouldOverflow;
30using llvm::mod;
31
32MLIRContext *AffineExpr::getContext() const { return expr->context; }
33
34AffineExprKind AffineExpr::getKind() const { return expr->kind; }
35
36/// Walk all of the AffineExprs in `e` in postorder. This is a private factory
37/// method to help handle lambda walk functions. Users should use the regular
38/// (non-static) `walk` method.
///
/// Every per-kind visit hook (binary op, constant, dim, symbol) forwards the
/// visited node to `callback`. `WalkRetTy` is the callback's return type; the
/// explicit instantiations that follow cover `void` and `WalkResult`.
39template <typename WalkRetTy>
41 function_ref<WalkRetTy(AffineExpr)> callback) {
// Local visitor whose hooks all forward to `callback`.
42 struct AffineExprWalker
43 : public AffineExprVisitor<AffineExprWalker, WalkRetTy> {
44 function_ref<WalkRetTy(AffineExpr)> callback;
45
46 AffineExprWalker(function_ref<WalkRetTy(AffineExpr)> callback)
47 : callback(callback) {}
48
49 WalkRetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
50 return callback(expr);
51 }
52 WalkRetTy visitConstantExpr(AffineConstantExpr expr) {
53 return callback(expr);
54 }
55 WalkRetTy visitDimExpr(AffineDimExpr expr) { return callback(expr); }
56 WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) { return callback(expr); }
57 };
58
// The traversal order itself comes from AffineExprVisitor::walkPostOrder.
59 return AffineExprWalker(callback).walkPostOrder(e);
60}
61// Explicitly instantiate the walk template for the two supported callback
// return types, `void` and `WalkResult`, so the template definition can stay
// in this .cpp file.
63 function_ref<void(AffineExpr)> callback);
64template WalkResult
67
68// Dispatch affine expression construction based on kind.
// Each kind is built through the matching operator or method (operator+,
// operator*, floorDiv, ceilDiv, operator%). Construction therefore goes
// through the same simplifying entry points as direct arithmetic on
// AffineExpr (e.g. operator+ first tries simplifyAdd). Any other kind is a
// programming error.
71 if (kind == AffineExprKind::Add)
72 return lhs + rhs;
73 if (kind == AffineExprKind::Mul)
74 return lhs * rhs;
75 if (kind == AffineExprKind::FloorDiv)
76 return lhs.floorDiv(rhs);
77 if (kind == AffineExprKind::CeilDiv)
78 return lhs.ceilDiv(rhs);
79 if (kind == AffineExprKind::Mod)
80 return lhs % rhs;
81
82 llvm_unreachable("unknown binary operation on affine expressions");
83}
84
85/// This method substitutes any uses of dimensions and symbols (e.g.
86/// dim#0 with dimReplacements[0]) and returns the modified expression tree.
/// A dim or symbol whose position is not covered by the corresponding
/// replacement list is left unchanged. A binary node is rebuilt (through
/// getAffineBinaryOpExpr, so simplifications apply) only when one of its
/// operands actually changed; otherwise `*this` is returned as-is.
89 ArrayRef<AffineExpr> symReplacements) const {
90 switch (getKind()) {
92 return *this;
94 unsigned dimId = llvm::cast<AffineDimExpr>(*this).getPosition();
95 if (dimId >= dimReplacements.size())
96 return *this;
97 return dimReplacements[dimId];
98 }
100 unsigned symId = llvm::cast<AffineSymbolExpr>(*this).getPosition();
101 if (symId >= symReplacements.size())
102 return *this;
103 return symReplacements[symId];
104 }
110 auto binOp = llvm::cast<AffineBinaryOpExpr>(*this);
111 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
112 auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
113 auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
// Preserve identity when nothing below this node changed.
114 if (newLHS == lhs && newRHS == rhs)
115 return *this;
116 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
117 }
118 llvm_unreachable("Unknown AffineExpr");
119}
120
// Substitute dimensions only; the empty symbol list leaves symbols as-is.
122 return replaceDimsAndSymbols(dimReplacements, {});
123}
124
// Substitute symbols only; the empty dim list leaves dimensions as-is.
127 return replaceDimsAndSymbols({}, symReplacements);
128}
129
130/// Replace dims[offset ... numDims)
131/// by dims[offset + shift ... shift + numDims).
/// Dims below `offset` map to themselves. Dims at positions >= `numDims` are
/// outside the replacement list and therefore stay unchanged.
132AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
133 unsigned offset) const {
// Identity mapping for positions [0, offset).
135 for (unsigned idx = 0; idx < offset; ++idx)
136 dims.push_back(getAffineDimExpr(idx, getContext()));
// Shifted mapping for positions [offset, numDims).
137 for (unsigned idx = offset; idx < numDims; ++idx)
138 dims.push_back(getAffineDimExpr(idx + shift, getContext()));
139 return replaceDimsAndSymbols(dims, {});
140}
141
142/// Replace symbols[offset ... numSymbols)
143/// by symbols[offset + shift ... shift + numSymbols).
/// Symbols below `offset` map to themselves. Symbols at positions
/// >= `numSymbols` are outside the replacement list and therefore stay
/// unchanged.
144AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
145 unsigned offset) const {
// Identity mapping for positions [0, offset).
147 for (unsigned idx = 0; idx < offset; ++idx)
148 symbols.push_back(getAffineSymbolExpr(idx, getContext()));
// Shifted mapping for positions [offset, numSymbols).
149 for (unsigned idx = offset; idx < numSymbols; ++idx)
150 symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
151 return replaceDimsAndSymbols({}, symbols);
152}
153
154/// Sparse replace method. Return the modified expression tree.
/// An expression found as a key in `map` is replaced by its mapped value,
/// and its subtree is not visited. Otherwise binary nodes are rewritten
/// recursively and rebuilt only if an operand changed. All other kinds are
/// returned unchanged.
157 auto it = map.find(*this);
158 if (it != map.end())
159 return it->second;
160 switch (getKind()) {
161 default:
162 return *this;
168 auto binOp = llvm::cast<AffineBinaryOpExpr>(*this);
169 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
170 auto newLHS = lhs.replace(map);
171 auto newRHS = rhs.replace(map);
172 if (newLHS == lhs && newRHS == rhs)
173 return *this;
174 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
175 }
176 llvm_unreachable("Unknown AffineExpr");
177}
178
179/// Sparse replace method. Return the modified expression tree.
/// Single-substitution convenience form: records `expr` -> `replacement` in a
/// local map and defers to the map-based overload.
182 map.insert(std::make_pair(expr, replacement));
183 return replace(map);
184}
185/// Returns true if this expression is made out of only symbols and
186/// constants (no dimensional identifiers).
/// A binary expression qualifies exactly when both of its operands qualify.
188 switch (getKind()) {
190 return true;
192 return false;
194 return true;
195
200 case AffineExprKind::Mod: {
201 auto expr = llvm::cast<AffineBinaryOpExpr>(*this);
202 return expr.getLHS().isSymbolicOrConstant() &&
203 expr.getRHS().isSymbolicOrConstant();
204 }
205 }
206 llvm_unreachable("Unknown AffineExpr");
207}
208
209/// Returns true if this is a pure affine expression, i.e., multiplication,
210/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
/// Rules by kind:
/// - Add requires both operands to be pure affine.
/// - Mul additionally requires one of its operands (on either side) to be a
///   constant.
/// - The Mod case (and the kinds sharing it) requires a pure affine LHS and a
///   constant RHS.
212 switch (getKind()) {
216 return true;
217 case AffineExprKind::Add: {
218 auto op = llvm::cast<AffineBinaryOpExpr>(*this);
219 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
220 }
221
222 case AffineExprKind::Mul: {
223 // TODO: Canonicalize the constants in binary operators to the RHS when
224 // possible, allowing this to merge into the next case.
225 auto op = llvm::cast<AffineBinaryOpExpr>(*this);
226 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
227 (llvm::isa<AffineConstantExpr>(op.getLHS()) ||
228 llvm::isa<AffineConstantExpr>(op.getRHS()));
229 }
232 case AffineExprKind::Mod: {
233 auto op = llvm::cast<AffineBinaryOpExpr>(*this);
234 return op.getLHS().isPureAffine() &&
235 llvm::isa<AffineConstantExpr>(op.getRHS());
236 }
237 }
238 llvm_unreachable("Unknown AffineExpr");
239}
240
241// Returns the greatest known integral divisor of this affine expression.
// The value is derived structurally from the operands and falls back to 1
// when nothing better is known.
243 AffineBinaryOpExpr binExpr(nullptr);
244 switch (getKind()) {
246 [[fallthrough]];
248 return 1;
250 [[fallthrough]];
252 // If the RHS is a constant and divides the known divisor on the LHS, the
253 // quotient is a known divisor of the expression.
254 binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
255 auto rhs = llvm::dyn_cast<AffineConstantExpr>(binExpr.getRHS());
256 // Leave alone undefined expressions.
257 if (rhs && rhs.getValue() != 0) {
258 int64_t lhsDiv = binExpr.getLHS().getLargestKnownDivisor();
259 if (lhsDiv % rhs.getValue() == 0)
260 return std::abs(lhsDiv / rhs.getValue());
261 }
262 return 1;
263 }
// NOTE(review): std::abs is undefined for INT64_MIN; a constant with that
// value is not guarded against here.
265 return std::abs(llvm::cast<AffineConstantExpr>(*this).getValue());
266 case AffineExprKind::Mul: {
267 binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
// NOTE(review): this product is not overflow-checked (signed overflow is UB),
// unlike simplifyMul's constant fold, which uses llvm::MulOverflow.
268 return binExpr.getLHS().getLargestKnownDivisor() *
269 binExpr.getRHS().getLargestKnownDivisor();
270 }
272 [[fallthrough]];
273 case AffineExprKind::Mod: {
274 binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
// A common divisor of both operands' known divisors.
275 return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
276 (uint64_t)binExpr.getRHS().getLargestKnownDivisor());
277 }
278 }
279 llvm_unreachable("Unknown AffineExpr");
280}
281
// Returns true if this expression is known to be a multiple of `factor`.
// The leaf cases that reach `factor * factor == 1` accept only +/-1. The
// binary cases are decided from getLargestKnownDivisor of their operands.
283 AffineBinaryOpExpr binExpr(nullptr);
284 uint64_t l, u;
285 switch (getKind()) {
287 [[fallthrough]];
289 return factor * factor == 1;
// NOTE(review): `% factor` here and below is undefined for factor == 0;
// callers are presumably expected to pass a nonzero factor -- confirm.
291 return llvm::cast<AffineConstantExpr>(*this).getValue() % factor == 0;
292 case AffineExprKind::Mul: {
293 binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
294 // It's probably not worth optimizing this further (to not traverse the
295 // whole sub-tree under - it that would require a version of isMultipleOf
296 // that on a 'false' return also returns the largest known divisor).
// NOTE(review): `l` and `u` are uint64_t. If `factor` is signed, it is
// converted to unsigned in these `%` operations, so negative factors behave
// unexpectedly. `l * u` may also wrap.
297 return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
298 (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
299 (l * u) % factor == 0;
300 }
304 case AffineExprKind::Mod: {
305 binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
306 return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
307 (uint64_t)binExpr.getRHS().getLargestKnownDivisor()) %
308 factor ==
309 0;
310 }
311 }
312 llvm_unreachable("Unknown AffineExpr");
313}
314
/// Returns true if the dimension at `position` occurs anywhere in this
/// expression tree.
315bool AffineExpr::isFunctionOfDim(unsigned position) const {
317 return *this == mlir::getAffineDimExpr(position, getContext());
318 }
319 if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(*this)) {
320 return expr.getLHS().isFunctionOfDim(position) ||
321 expr.getRHS().isFunctionOfDim(position);
322 }
323 return false;
324}
325
/// Returns true if the symbol at `position` occurs anywhere in this
/// expression tree.
326bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
328 return *this == mlir::getAffineSymbolExpr(position, getContext());
329 }
330 if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(*this)) {
331 return expr.getLHS().isFunctionOfSymbol(position) ||
332 expr.getRHS().isFunctionOfSymbol(position);
333 }
334 return false;
335}
336
// Reads the left operand straight from the uniqued storage behind `expr`.
340 return static_cast<ImplType *>(expr)->lhs;
341}
// Reads the right operand straight from the uniqued storage behind `expr`.
343 return static_cast<ImplType *>(expr)->rhs;
344}
345
// Reads the stored position straight from the storage behind `expr`.
348 return static_cast<ImplType *>(expr)->position;
349}
350
351/// Returns true if the expression is divisible by the given symbol with
352/// position `symbolPos`. The argument `opKind` specifies here what kind of
353/// division or mod operation called this division. It helps in implementing the
354/// commutative property of the floordiv and ceildiv operations. If the argument
355/// `opKind` is floordiv and `expr` is also a binary expression of a floordiv
356/// operation, then the commutative property can be used; otherwise, the floordiv
357/// operation is not divisible. The same argument holds for ceildiv operation.
/// `fromMul` is set when recursing into the operands of a multiplication. A
/// floordiv/ceildiv reached that way is rejected.
358static bool canSimplifyDivisionBySymbol(AffineExpr expr, unsigned symbolPos,
359 AffineExprKind opKind,
360 bool fromMul = false) {
361 // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
362 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
363 opKind == AffineExprKind::CeilDiv) &&
364 "unexpected opKind");
365 switch (expr.getKind()) {
367 return cast<AffineConstantExpr>(expr).getValue() == 0;
369 return false;
371 return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos);
372 // Checks divisibility by the given symbol for both operands.
373 case AffineExprKind::Add: {
374 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
375 return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
376 opKind) &&
377 canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
378 }
379 // Checks divisibility by the given symbol for both operands. Consider the
380 // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
381 // this is a division by s1 and both the operands of modulo are divisible by
382 // s1 but it is not divisible by s1 always. The third argument is
383 // `AffineExprKind::Mod` for this reason.
384 case AffineExprKind::Mod: {
385 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
386 return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
388 canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos,
390 }
391 // Checks if either operand is divisible by the given symbol.
392 case AffineExprKind::Mul: {
393 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
394 return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind,
395 true) ||
396 canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind,
397 true);
398 }
399 // Floordiv and ceildiv are divisible by the given symbol when the first
400 // operand is divisible, and the affine expression kind of the argument expr
401 // is same as the argument `opKind`. This can be inferred from commutative
402 // property of floordiv and ceildiv operations and are as follow:
403 // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
404 // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv exp2
405 // It will fail 1. if the operations are not the same. For example:
406 // (exp1 ceildiv exp2) floordiv exp3 cannot be simplified. 2. if there is a
407 // multiplication operation in the expression. For example:
408 // (exp1 ceildiv exp2) mul exp3 ceildiv exp4 cannot be simplified.
411 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
412 if (opKind != expr.getKind())
413 return false;
414 if (fromMul)
415 return false;
416 return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
417 expr.getKind());
418 }
419 }
420 llvm_unreachable("Unknown AffineExpr");
421}
422
423/// Divides the given expression by the given symbol at position `symbolPos`. It
424/// assumes the divisibility condition has been checked before it is called. A
425/// null expression is returned whenever the divisibility condition fails.
/// NOTE(review): in the binary cases, a null result from a recursive call is
/// used as an operand of the rebuilt expression rather than propagated as
/// null. Callers should therefore only invoke this after
/// canSimplifyDivisionBySymbol has succeeded for the same expression.
426static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
427 AffineExprKind opKind) {
428 // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
429 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
430 opKind == AffineExprKind::CeilDiv) &&
431 "unexpected opKind");
432 switch (expr.getKind()) {
434 if (cast<AffineConstantExpr>(expr).getValue() != 0)
435 return nullptr;
436 return getAffineConstantExpr(0, expr.getContext());
438 return nullptr;
440 return getAffineConstantExpr(1, expr.getContext());
441 // Dividing both operands by the given symbol.
442 case AffineExprKind::Add: {
443 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
445 expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
446 symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
447 }
448 // Dividing both operands by the given symbol.
449 case AffineExprKind::Mod: {
450 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
452 expr.getKind(),
453 symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
454 symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
455 }
456 // Dividing any of the operand by the given symbol.
457 case AffineExprKind::Mul: {
458 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
// Divide whichever operand is divisible; the LHS is preferred.
459 if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
460 return binaryExpr.getLHS() *
461 symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
462 return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
463 binaryExpr.getRHS();
464 }
465 // Dividing first operand only by the given symbol.
468 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
470 expr.getKind(),
471 symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
472 binaryExpr.getRHS());
473 }
474 }
475 llvm_unreachable("Unknown AffineExpr");
476}
477
478/// Populate `result` with all summand operands of given (potentially nested)
479/// addition. If the given expression is not an addition, just populate the
480/// expression itself.
481/// Example: Add(Add(7, 8), Mul(9, 10)) will return [7, 8, Mul(9, 10)].
482static void getSummandExprs(AffineExpr expr, SmallVector<AffineExpr> &result) {
483 auto addExpr = dyn_cast<AffineBinaryOpExpr>(expr);
484 if (!addExpr || addExpr.getKind() != AffineExprKind::Add) {
485 result.push_back(expr);
486 return;
487 }
488 getSummandExprs(addExpr.getLHS(), result);
489 getSummandExprs(addExpr.getRHS(), result);
490}
491
492/// Return "true" if `candidate` is a negated expression, i.e., Mul(-1, expr).
493/// If so, also return the non-negated expression via `expr`.
494static bool isNegatedAffineExpr(AffineExpr candidate, AffineExpr &expr) {
495 auto mulExpr = dyn_cast<AffineBinaryOpExpr>(candidate);
496 if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
497 return false;
498 if (auto lhs = dyn_cast<AffineConstantExpr>(mulExpr.getLHS())) {
499 if (lhs.getValue() == -1) {
500 expr = mulExpr.getRHS();
501 return true;
502 }
503 }
504 if (auto rhs = dyn_cast<AffineConstantExpr>(mulExpr.getRHS())) {
505 if (rhs.getValue() == -1) {
506 expr = mulExpr.getLHS();
507 return true;
508 }
509 }
510 return false;
511}
512
513/// Return "true" if `lhs` % `rhs` is guaranteed to evaluate to zero based on
514/// the fact that `lhs` contains another modulo expression that ensures that
515/// `lhs` is divisible by `rhs`. This is a common pattern in the resulting IR
516/// after loop peeling.
517///
518/// Example: lhs = ub - ub % step
519/// rhs = step
520/// => (ub - ub % step) % step is guaranteed to evaluate to 0.
/// `numDims`/`numSymbols` are forwarded to simplifyAffineExpr when checking
/// whether the remaining summands cancel out.
521static bool isModOfModSubtraction(AffineExpr lhs, AffineExpr rhs,
522 unsigned numDims, unsigned numSymbols) {
523 // TODO: Try to unify this function with `getBoundForAffineExpr`.
524 // Collect all summands in lhs.
526 getSummandExprs(lhs, summands);
527 // Look for Mul(-1, Mod(x, rhs)) among the summands. If x matches the
528 // remaining summands, then lhs % rhs is guaranteed to evaluate to 0.
// NOTE(review): each candidate rebuilds and re-simplifies the sum of all other
// summands, so the work is quadratic in the number of summands.
529 for (int64_t i = 0, e = summands.size(); i < e; ++i) {
530 AffineExpr current = summands[i];
531 AffineExpr beforeNegation;
532 if (!isNegatedAffineExpr(current, beforeNegation))
533 continue;
534 AffineBinaryOpExpr innerMod = dyn_cast<AffineBinaryOpExpr>(beforeNegation);
535 if (!innerMod || innerMod.getKind() != AffineExprKind::Mod)
536 continue;
537 if (innerMod.getRHS() != rhs)
538 continue;
539 // Sum all remaining summands and subtract x. If that expression can be
540 // simplified to zero, then the remaining summands and x are equal.
541 AffineExpr diff = getAffineConstantExpr(0, lhs.getContext());
542 for (int64_t j = 0; j < e; ++j)
543 if (i != j)
544 diff = diff + summands[j];
545 diff = diff - innerMod.getLHS();
546 diff = simplifyAffineExpr(diff, numDims, numSymbols);
547 auto constExpr = dyn_cast<AffineConstantExpr>(diff);
548 if (constExpr && constExpr.getValue() == 0)
549 return true;
550 }
551 return false;
552}
553
554/// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
555/// operations when the second operand simplifies to a symbol and the first
556/// operand is divisible by that symbol. It can be applied to any semi-affine
557/// expression. Returned expression can either be a semi-affine or pure affine
558/// expression.
559static AffineExpr simplifySemiAffine(AffineExpr expr, unsigned numDims,
560 unsigned numSymbols) {
561 switch (expr.getKind()) {
565 return expr;
567 case AffineExprKind::Mul: {
568 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
570 expr.getKind(),
571 simplifySemiAffine(binaryExpr.getLHS(), numDims, numSymbols),
572 simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols));
573 }
574 // Check if the simplification of the second operand is a symbol, and the
575 // first operand is divisible by it. If the operation is a modulo, a constant
576 // zero expression is returned. In the case of floordiv and ceildiv, the
577 // symbol from the simplification of the second operand divides the first
578 // operand. Otherwise, simplification is not possible.
581 case AffineExprKind::Mod: {
582 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
583 AffineExpr sLHS =
584 simplifySemiAffine(binaryExpr.getLHS(), numDims, numSymbols);
585 AffineExpr sRHS =
586 simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols);
587 if (isModOfModSubtraction(sLHS, sRHS, numDims, numSymbols))
588 return getAffineConstantExpr(0, expr.getContext());
// NOTE(review): this simplifies the RHS a second time although the result is
// already available as `sRHS`. The RHS is thus simplified twice per nesting
// level; `dyn_cast<AffineSymbolExpr>(sRHS)` should be equivalent.
589 AffineSymbolExpr symbolExpr = dyn_cast<AffineSymbolExpr>(
590 simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols));
591 if (!symbolExpr)
592 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
593 unsigned symbolPos = symbolExpr.getPosition();
// NOTE(review): divisibility is checked on the unsimplified LHS, while the
// division below operates on `sLHS`. A null quotient falls back to the
// rebuilt binary expression.
594 if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
595 expr.getKind()))
596 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
597 if (expr.getKind() == AffineExprKind::Mod)
598 return getAffineConstantExpr(0, expr.getContext());
599 AffineExpr simplifiedQuotient =
600 symbolicDivide(sLHS, symbolPos, expr.getKind());
601 return simplifiedQuotient
602 ? simplifiedQuotient
603 : getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
604 }
605 }
606 llvm_unreachable("Unknown AffineExpr");
607}
608
609static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
610 MLIRContext *context) {
611 auto assignCtx = [context](AffineDimExprStorage *storage) {
612 storage->context = context;
613 };
614
615 StorageUniquer &uniquer = context->getAffineUniquer();
616 return uniquer.get<AffineDimExprStorage>(
617 assignCtx, static_cast<unsigned>(kind), position);
618}
619
620AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
621 return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
622}
623
625 : AffineExpr(ptr) {} // Forwards the storage pointer to the AffineExpr base.
626unsigned AffineSymbolExpr::getPosition() const {
627 return static_cast<ImplType *>(expr)->position;
628}
629
630AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
631 return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
632}
633
635 : AffineExpr(ptr) {} // Forwards the storage pointer to the AffineExpr base.
636int64_t AffineConstantExpr::getValue() const {
637 return static_cast<ImplType *>(expr)->constant;
638}
639
640bool AffineExpr::operator==(int64_t v) const {
641 return *this == getAffineConstantExpr(v, getContext());
642}
643
644AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
645 auto assignCtx = [context](AffineConstantExprStorage *storage) {
646 storage->context = context;
647 };
648
649 StorageUniquer &uniquer = context->getAffineUniquer();
650 return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
651}
652
655 MLIRContext *context) {
// One uniqued constant expression per input value, in input order.
656 return llvm::map_to_vector(constants, [&](int64_t constant) {
657 return getAffineConstantExpr(constant, context);
658 });
659}
660
661/// Simplify add expression. Return nullptr if it can't be simplified.
/// Callers (e.g. operator+) fall back to building a plain Add node on nullptr.
662static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
663 auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
664 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
665 // Fold if both LHS, RHS are a constant and the sum does not overflow.
666 if (lhsConst && rhsConst) {
667 int64_t sum;
668 if (llvm::AddOverflow(lhsConst.getValue(), rhsConst.getValue(), sum)) {
669 return nullptr;
670 }
671 return getAffineConstantExpr(sum, lhs.getContext());
672 }
673
674 // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
675 // If only one of them is a symbolic expression, make it the RHS.
676 if (isa<AffineConstantExpr>(lhs) ||
677 (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
678 return rhs + lhs;
679 }
680
681 // At this point, if there was a constant, it would be on the right.
682
683 // Addition with a zero is a noop, return the other input.
684 if (rhsConst) {
685 if (rhsConst.getValue() == 0)
686 return lhs;
687 }
688 // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
689 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
690 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
// NOTE(review): unlike the constant fold above, this int64_t addition is not
// overflow-checked (signed overflow is UB).
691 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS()))
692 return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
693 }
694
695 // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
696 // c1 is rLhsConst, c2 is rRhsConst; firstExpr, secondExpr are their
697 // respective multiplicands.
698 std::optional<int64_t> rLhsConst, rRhsConst;
699 AffineExpr firstExpr, secondExpr;
700 AffineConstantExpr rLhsConstExpr;
// NOTE(review): `lBinOpExpr` repeats the `lBin` cast computed above.
701 auto lBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lhs);
702 if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
703 (rLhsConstExpr = dyn_cast<AffineConstantExpr>(lBinOpExpr.getRHS()))) {
704 rLhsConst = rLhsConstExpr.getValue();
705 firstExpr = lBinOpExpr.getLHS();
706 } else {
707 rLhsConst = 1;
708 firstExpr = lhs;
709 }
710
711 auto rBinOpExpr = dyn_cast<AffineBinaryOpExpr>(rhs);
712 AffineConstantExpr rRhsConstExpr;
713 if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
714 (rRhsConstExpr = dyn_cast<AffineConstantExpr>(rBinOpExpr.getRHS()))) {
715 rRhsConst = rRhsConstExpr.getValue();
716 secondExpr = rBinOpExpr.getLHS();
717 } else {
718 rRhsConst = 1;
719 secondExpr = rhs;
720 }
721
// NOTE(review): both optionals are always engaged at this point, and
// `*rLhsConst + *rRhsConst` below is not overflow-checked.
722 if (rLhsConst && rRhsConst && firstExpr == secondExpr)
724 AffineExprKind::Mul, firstExpr,
725 getAffineConstantExpr(*rLhsConst + *rRhsConst, lhs.getContext()));
726
727 // When doing successive additions, bring constant to the right: turn (d0 + 2)
728 // + d1 into (d0 + d1) + 2.
729 if (lBin && lBin.getKind() == AffineExprKind::Add) {
730 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
731 return lBin.getLHS() + rhs + lrhs;
732 }
733 }
734
735 // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
736 // q may be a constant or symbolic expression. This leads to a much more
737 // efficient form when 'q' is a power of two, and in general a more compact
738 // and readable form.
739
740 // Process '(expr floordiv c) * (-c)'.
741 if (!rBinOpExpr)
742 return nullptr;
743
744 auto lrhs = rBinOpExpr.getLHS();
745 auto rrhs = rBinOpExpr.getRHS();
746
747 AffineExpr llrhs, rlrhs;
748
749 // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
750 // symbolic expression.
751 auto lrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lrhs);
752 // Check rrhsConstOpExpr = -1.
753 auto rrhsConstOpExpr = dyn_cast<AffineConstantExpr>(rrhs);
754 if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
755 lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
756 // Check llrhs = expr floordiv q.
757 llrhs = lrhsBinOpExpr.getLHS();
758 // Check rlrhs = q.
759 rlrhs = lrhsBinOpExpr.getRHS();
760 auto llrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(llrhs);
761 if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
762 return nullptr;
763 if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
764 return lhs % rlrhs;
765 }
766
767 // Process lrhs, which is 'expr floordiv c'.
768 // expr + (expr // c * -c) = expr % c
769 AffineBinaryOpExpr lrBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lrhs);
770 if (!lrBinOpExpr || rhs.getKind() != AffineExprKind::Mul ||
771 lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
772 return nullptr;
773
774 llrhs = lrBinOpExpr.getLHS();
775 rlrhs = lrBinOpExpr.getRHS();
776 auto rlrhsConstOpExpr = dyn_cast<AffineConstantExpr>(rlrhs);
777 // We don't support modulo with a negative RHS.
778 bool isPositiveRhs = rlrhsConstOpExpr && rlrhsConstOpExpr.getValue() > 0;
779
780 if (isPositiveRhs && lhs == llrhs && rlrhs == -rrhs) {
781 return lhs % rlrhs;
782 }
783
784 // Try simplify lhs's last operand with rhs. e.g:
785 // (s0 * 64 + s1) + (s1 // c * -c) --->
786 // s0 * 64 + (s1 + s1 // c * -c) -->
787 // s0 * 64 + s1 % c
788 if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Add) {
789 if (auto simplified = simplifyAdd(lBinOpExpr.getRHS(), rhs))
790 return lBinOpExpr.getLHS() + simplified;
791 }
792 return nullptr;
793}
794
795/// Get the canonical order of two commutative exprs arguments.
796static std::pair<AffineExpr, AffineExpr>
797orderCommutativeArgs(AffineExpr expr1, AffineExpr expr2) {
798 auto sym1 = dyn_cast<AffineSymbolExpr>(expr1);
799 auto sym2 = dyn_cast<AffineSymbolExpr>(expr2);
800 // Try to order by symbol/dim position first.
801 if (sym1 && sym2)
802 return sym1.getPosition() < sym2.getPosition() ? std::pair{expr1, expr2}
803 : std::pair{expr2, expr1};
804
805 auto dim1 = dyn_cast<AffineDimExpr>(expr1);
806 auto dim2 = dyn_cast<AffineDimExpr>(expr2);
807 if (dim1 && dim2)
808 return dim1.getPosition() < dim2.getPosition() ? std::pair{expr1, expr2}
809 : std::pair{expr2, expr1};
810
811 // Put dims before symbols.
812 if (dim1 && sym2)
813 return {dim1, sym2};
814
815 if (sym1 && dim2)
816 return {dim2, sym1};
817
818 // Otherwise, keep original order.
819 return {expr1, expr2};
820}
821
822AffineExpr AffineExpr::operator+(int64_t v) const {
823 return *this + getAffineConstantExpr(v, getContext());
824}
// Build `*this + other`: use simplifyAdd's result when it finds one,
// otherwise create a uniqued Add node with operands in canonical order.
826 if (auto simplified = simplifyAdd(*this, other))
827 return simplified;
828
829 auto [lhs, rhs] = orderCommutativeArgs(*this, other);
830
831 StorageUniquer &uniquer = getContext()->getAffineUniquer();
832 return uniquer.get<AffineBinaryOpExprStorage>(
833 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), lhs, rhs);
834}
835
836/// Simplify a multiply expression. Return nullptr if it can't be simplified.
/// Nothing is simplified unless at least one operand is symbolic or constant.
838 auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
839 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
840
841 if (lhsConst && rhsConst) {
843 if (llvm::MulOverflow(lhsConst.getValue(), rhsConst.getValue(), product)) {
844 return nullptr;
845 }
846 return getAffineConstantExpr(product, lhs.getContext());
847 }
848
849 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())
850 return nullptr;
851
852 // Canonicalize the mul expression so that the constant/symbolic term is the
853 // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
854 // constant. (Note that a constant is trivially symbolic).
855 if (!rhs.isSymbolicOrConstant() || isa<AffineConstantExpr>(lhs)) {
856 // At least one of them has to be symbolic.
857 return rhs * lhs;
858 }
859
860 // At this point, if there was a constant, it would be on the right.
861
862 // Multiplication with a one is a noop, return the other input.
863 if (rhsConst) {
864 if (rhsConst.getValue() == 1)
865 return lhs;
866 // Multiplication with zero.
867 if (rhsConst.getValue() == 0)
868 return rhsConst;
869 }
870
871 // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
872 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
873 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
// NOTE(review): unlike the constant fold at the top (llvm::MulOverflow), this
// int64_t product is not overflow-checked (signed overflow is UB).
874 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS()))
875 return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
876 }
877
878 // When doing successive multiplication, bring constant to the right: turn (d0
879 // * 2) * d1 into (d0 * d1) * 2.
880 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
881 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
882 return (lBin.getLHS() * rhs) * lrhs;
883 }
884 }
885
886 return nullptr;
887}
888
// Build `*this * other`: use simplifyMul's result when it finds one,
// otherwise create a uniqued Mul node with operands in canonical order.
893 if (auto simplified = simplifyMul(*this, other))
894 return simplified;
895
896 auto [lhs, rhs] = orderCommutativeArgs(*this, other);
897
899 return uniquer.get<AffineBinaryOpExprStorage>(
900 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), lhs, rhs);
901}
902
903// Unary minus, delegate to operator*.
// Multiplying by the constant -1 routes constant folding through simplifyMul,
// which checks the product for overflow.
905 return *this * getAffineConstantExpr(-1, getContext());
906}
907
908// Delegate to operator+.
909AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
// `a - b` is built as `a + (-b)`, i.e. `a + b * -1`.
911 return *this + (-other);
912}
913
// Simplify `lhs floordiv rhs` for a nonzero constant divisor `rhs`.
// Returns nullptr when no simplification applies.
915 auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
916 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
917
918 if (!rhsConst || rhsConst.getValue() == 0)
919 return nullptr;
920
921 if (lhsConst) {
922 if (divideSignedWouldOverflow(lhsConst.getValue(), rhsConst.getValue()))
923 return nullptr;
925 divideFloorSigned(lhsConst.getValue(), rhsConst.getValue()),
926 lhs.getContext());
927 }
928
929 // Floor division by one is a noop; return the dividend unchanged.
930 // (`rhsConst == 1` goes through AffineExpr::operator==(int64_t).)
931 if (rhsConst == 1)
932 return lhs;
933
934 // Simplify `(expr * lrhs) floordiv rhsConst` when `lrhs` is known to be a
935 // multiple of `rhsConst`.
936 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
937 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
938 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
939 // `rhsConst` is known to be a nonzero constant.
// NOTE(review): `%` and `/` are undefined for INT64_MIN by -1; the
// constant/constant case above guards this with divideSignedWouldOverflow,
// but this path does not.
940 if (lrhs.getValue() % rhsConst.getValue() == 0)
941 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
942 }
943 }
944
945 // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
946 // known to be a multiple of divConst.
947 if (lBin && lBin.getKind() == AffineExprKind::Add) {
948 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
949 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
950 // rhsConst is known to be a nonzero constant.
951 if (llhsDiv % rhsConst.getValue() == 0 ||
952 lrhsDiv % rhsConst.getValue() == 0)
953 return lBin.getLHS().floorDiv(rhsConst.getValue()) +
954 lBin.getRHS().floorDiv(rhsConst.getValue());
955 }
956
957 return nullptr;
958}
959
962}
964 if (auto simplified = simplifyFloorDiv(*this, other))
965 return simplified;
966
968 return uniquer.get<AffineBinaryOpExprStorage>(
969 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
970 other);
971}
972
974 auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
975 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
976
977 if (!rhsConst || rhsConst.getValue() == 0)
978 return nullptr;
979
980 if (lhsConst) {
981 if (divideSignedWouldOverflow(lhsConst.getValue(), rhsConst.getValue()))
982 return nullptr;
984 divideCeilSigned(lhsConst.getValue(), rhsConst.getValue()),
985 lhs.getContext());
986 }
987
988 // Fold ceildiv of a multiply with a constant that is a multiple of the
989 // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
990 if (rhsConst.getValue() == 1)
991 return lhs;
992
993 // Simplify `(expr * lrhs) ceildiv rhsConst` when `lrhs` is known to be a
994 // multiple of `rhsConst`.
995 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
996 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
997 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
998 // `rhsConst` is known to be a nonzero constant.
999 if (lrhs.getValue() % rhsConst.getValue() == 0)
1000 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
1001 }
1002 }
1003
1004 return nullptr;
1005}
1006
1009}
1011 if (auto simplified = simplifyCeilDiv(*this, other))
1012 return simplified;
1013
1015 return uniquer.get<AffineBinaryOpExprStorage>(
1016 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
1017 other);
1018}
1019
1021 auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
1022 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
1023
1024 // mod w.r.t zero or negative numbers is undefined and preserved as is.
1025 if (!rhsConst || rhsConst.getValue() < 1)
1026 return nullptr;
1027
1028 if (lhsConst) {
1029 // mod never overflows.
1030 return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
1031 lhs.getContext());
1032 }
1033
1034 // Fold modulo of an expression that is known to be a multiple of a constant
1035 // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
1036 // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
1037 if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
1038 return getAffineConstantExpr(0, lhs.getContext());
1039
1040 // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
1041 // known to be a multiple of divConst.
1042 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
1043 if (lBin && lBin.getKind() == AffineExprKind::Add) {
1044 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
1045 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
1046 // rhsConst is known to be a positive constant.
1047 if (llhsDiv % rhsConst.getValue() == 0)
1048 return lBin.getRHS() % rhsConst.getValue();
1049 if (lrhsDiv % rhsConst.getValue() == 0)
1050 return lBin.getLHS() % rhsConst.getValue();
1051 }
1052
1053 // Simplify (e % a) % b to e % b when b evenly divides a
1054 if (lBin && lBin.getKind() == AffineExprKind::Mod) {
1055 auto intermediate = dyn_cast<AffineConstantExpr>(lBin.getRHS());
1056 if (intermediate && intermediate.getValue() >= 1 &&
1057 mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
1058 return lBin.getLHS() % rhsConst.getValue();
1059 }
1060 }
1061
1062 return nullptr;
1063}
1064
1066 return *this % getAffineConstantExpr(v, getContext());
1067}
1069 if (auto simplified = simplifyMod(*this, other))
1070 return simplified;
1071
1073 return uniquer.get<AffineBinaryOpExprStorage>(
1074 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
1075}
1076
1078 SmallVector<AffineExpr, 8> dimReplacements(map.getResults());
1079 return replaceDimsAndSymbols(dimReplacements, {});
1080}
1082 expr.print(os);
1083 return os;
1084}
1085
1086/// Constructs an affine expression from a flat ArrayRef. If there are local
1087/// identifiers (neither dimensional nor symbolic) that appear in the sum of
1088/// products expression, `localExprs` is expected to have the AffineExpr
1089/// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
1090/// in the format [dims, symbols, locals, constant term].
1092 unsigned numDims,
1093 unsigned numSymbols,
1094 ArrayRef<AffineExpr> localExprs,
1095 MLIRContext *context) {
1096 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1097 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1098 "unexpected number of local expressions");
1099
1100 auto expr = getAffineConstantExpr(0, context);
1101 // Dimensions and symbols.
1102 for (unsigned j = 0; j < numDims + numSymbols; j++) {
1103 if (flatExprs[j] == 0)
1104 continue;
1105 auto id = j < numDims ? getAffineDimExpr(j, context)
1106 : getAffineSymbolExpr(j - numDims, context);
1107 expr = expr + id * flatExprs[j];
1108 }
1109
1110 // Local identifiers.
1111 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1112 j++) {
1113 if (flatExprs[j] == 0)
1114 continue;
1115 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1116 expr = expr + term;
1117 }
1118
1119 // Constant term.
1120 int64_t constTerm = flatExprs[flatExprs.size() - 1];
1121 if (constTerm != 0)
1122 expr = expr + constTerm;
1123 return expr;
1124}
1125
1126/// Constructs a semi-affine expression from a flat ArrayRef. If there are
1127/// local identifiers (neither dimensional nor symbolic) that appear in the sum
1128/// of products expression, `localExprs` is expected to have the AffineExprs for
1129/// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in
1130/// the format [dims, symbols, locals, constant term]. The semi-affine
1131/// expression is constructed in the sorted order of dimension and symbol
1132/// position numbers. Note: local expressions/ids are used for mod, div as well
1133/// as symbolic RHS terms for terms that are not pure affine.
1135 unsigned numDims,
1136 unsigned numSymbols,
1137 ArrayRef<AffineExpr> localExprs,
1138 MLIRContext *context) {
1139 assert(!flatExprs.empty() && "flatExprs cannot be empty");
1140
1141 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1142 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1143 "unexpected number of local expressions");
1144
1145 AffineExpr expr = getAffineConstantExpr(0, context);
1146
1147 // We design indices as a pair which help us present the semi-affine map as
1148 // sum of product where terms are sorted based on dimension or symbol
1149 // position: <keyA, keyB> for expressions of the form dimension * symbol,
1150 // where keyA is the position number of the dimension and keyB is the
1151 // position number of the symbol. For dimensional expressions we set the index
1152 // as (position number of the dimension, -1), as we want dimensional
1153 // expressions to appear before symbolic and product of dimensional and
1154 // symbolic expressions having the dimension with the same position number.
1155 // For symbolic expression set the index as (position number of the symbol,
1156 // maximum of last dimension and symbol position) number. For example, we want
1157 // the expression we are constructing to look something like: d0 + d0 * s0 +
1158 // s0 + d1*s1 + s1.
1159
1160 // Stores the affine expression corresponding to a given index.
1162 // Stores the constant coefficient value corresponding to a given
1163 // dimension, symbol or a non-pure affine expression stored in `localExprs`.
1165 // Stores the indices as defined above, and later sorted to produce
1166 // the semi-affine expression in the desired form.
1168
1169 // Example: expression = d0 + d0 * s0 + 2 * s0.
1170 // indices = [{0,-1}, {0, 0}, {0, 1}]
1171 // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
1172 // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
1173
1174 // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
1175 auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
1176 AffineExpr expr) {
1177 assert(!llvm::is_contained(indices, index) &&
1178 "Key is already present in indices vector and overwriting will "
1179 "happen in `indexToExprMap` and `coefficients`!");
1180
1181 indices.push_back(index);
1182 coefficients.insert({index, coefficient});
1183 indexToExprMap.insert({index, expr});
1184 };
1185
1186 // Design indices for dimensional or symbolic terms, and store the indices,
1187 // constant coefficient corresponding to the indices in `coefficients` map,
1188 // and affine expression corresponding to indices in `indexToExprMap` map.
1189
1190 // Ensure we do not have duplicate keys in `indexToExpr` map.
1191 unsigned offsetSym = 0;
1192 signed offsetDim = -1;
1193 for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
1194 if (flatExprs[j] == 0)
1195 continue;
1196 // For symbolic expression set the index as <position number
1197 // of the symbol, max(dimCount, symCount)> number,
1198 // as we want symbolic expressions with the same positional number to
1199 // appear after dimensional expressions having the same positional number.
1200 std::pair<unsigned, signed> indexEntry(
1201 j - numDims, std::max(numDims, numSymbols) + offsetSym++);
1202 addEntry(indexEntry, flatExprs[j],
1203 getAffineSymbolExpr(j - numDims, context));
1204 }
1205
1206 // Denotes semi-affine product, modulo or division terms, which has been added
1207 // to the `indexToExpr` map.
1208 SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1,
1209 false);
1210 unsigned lhsPos, rhsPos;
1211 // Construct indices for product terms involving dimension, symbol or constant
1212 // as lhs/rhs, and store the indices, constant coefficient corresponding to
1213 // the indices in `coefficients` map, and affine expression corresponding to
1214 // in indices in `indexToExprMap` map.
1215 for (const auto &it : llvm::enumerate(localExprs)) {
1216 if (flatExprs[numDims + numSymbols + it.index()] == 0)
1217 continue;
1218 AffineExpr expr = it.value();
1219 auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
1220 if (!binaryExpr)
1221 continue;
1222
1223 AffineExpr lhs = binaryExpr.getLHS();
1224 AffineExpr rhs = binaryExpr.getRHS();
1225 if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
1226 (isa<AffineDimExpr>(rhs) || isa<AffineSymbolExpr>(rhs) ||
1227 isa<AffineConstantExpr>(rhs)))) {
1228 continue;
1229 }
1230 if (isa<AffineConstantExpr>(rhs)) {
1231 // For product/modulo/division expressions, when rhs of modulo/division
1232 // expression is constant, we put 0 in place of keyB, because we want
1233 // them to appear earlier in the semi-affine expression we are
1234 // constructing. When rhs is constant, we place 0 in place of keyB.
1235 if (isa<AffineDimExpr>(lhs)) {
1236 lhsPos = cast<AffineDimExpr>(lhs).getPosition();
1237 std::pair<unsigned, signed> indexEntry(lhsPos, offsetDim--);
1238 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1239 expr);
1240 } else {
1241 lhsPos = cast<AffineSymbolExpr>(lhs).getPosition();
1242 std::pair<unsigned, signed> indexEntry(
1243 lhsPos, std::max(numDims, numSymbols) + offsetSym++);
1244 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1245 expr);
1246 }
1247 } else if (isa<AffineDimExpr>(lhs)) {
1248 // For product/modulo/division expressions having lhs as dimension and rhs
1249 // as symbol, we order the terms in the semi-affine expression based on
1250 // the pair: <keyA, keyB> for expressions of the form dimension * symbol,
1251 // where keyA is the position number of the dimension and keyB is the
1252 // position number of the symbol.
1253 lhsPos = cast<AffineDimExpr>(lhs).getPosition();
1254 rhsPos = cast<AffineSymbolExpr>(rhs).getPosition();
1255 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1256 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1257 } else {
1258 // For product/modulo/division expressions having both lhs and rhs as
1259 // symbol, we design indices as a pair: <keyA, keyB> for expressions
1260 // of the form dimension * symbol, where keyA is the position number of
1261 // the dimension and keyB is the position number of the symbol.
1262 lhsPos = cast<AffineSymbolExpr>(lhs).getPosition();
1263 rhsPos = cast<AffineSymbolExpr>(rhs).getPosition();
1264 std::pair<unsigned, signed> indexEntry(
1265 lhsPos, std::max(numDims, numSymbols) + offsetSym++);
1266 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1267 }
1268 addedToMap[it.index()] = true;
1269 }
1270
1271 for (unsigned j = 0; j < numDims; ++j) {
1272 if (flatExprs[j] == 0)
1273 continue;
1274 // For dimensional expressions we set the index as <position number of the
1275 // dimension, 0>, as we want dimensional expressions to appear before
1276 // symbolic ones and products of dimensional and symbolic expressions
1277 // having the dimension with the same position number.
1278 std::pair<unsigned, signed> indexEntry(j, offsetDim--);
1279 addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context));
1280 }
1281
1282 // Constructing the simplified semi-affine sum of product/division/mod
1283 // expression from the flattened form in the desired sorted order of indices
1284 // of the various individual product/division/mod expressions.
1285 llvm::sort(indices);
1286 for (const std::pair<unsigned, unsigned> index : indices) {
1287 assert(indexToExprMap.lookup(index) &&
1288 "cannot find key in `indexToExprMap` map");
1289 expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index);
1290 }
1291
1292 // Local identifiers.
1293 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1294 j++) {
1295 // If the coefficient of the local expression is 0, continue as we need not
1296 // add it in out final expression.
1297 if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols])
1298 continue;
1299 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1300 expr = expr + term;
1301 }
1302
1303 // Constant term.
1304 int64_t constTerm = flatExprs.back();
1305 if (constTerm != 0)
1306 expr = expr + constTerm;
1307 return expr;
1308}
1309
1315
1316// In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1317//
1318// In case of semi affine multiplication expressions, t = expr * symbolic_expr,
1319// introduce a local variable p (= expr * symbolic_expr), and the affine
1320// expression expr * symbolic_expr is added to `localExprs`.
1322 assert(operandExprStack.size() >= 2);
1324 operandExprStack.pop_back();
1326
1327 // Flatten semi-affine multiplication expressions by introducing a local
1328 // variable in place of the product; the affine expression
1329 // corresponding to the quantifier is added to `localExprs`.
1330 if (!isa<AffineConstantExpr>(expr.getRHS())) {
1332 MLIRContext *context = expr.getContext();
1334 localExprs, context);
1336 localExprs, context);
1337 return addLocalVariableSemiAffine(mulLhs, rhs, a * b, lhs, lhs.size());
1338 }
1339
1340 // Get the RHS constant.
1341 int64_t rhsConst = rhs[getConstantIndex()];
1342 for (int64_t &lhsElt : lhs)
1343 lhsElt *= rhsConst;
1344
1345 return success();
1346}
1347
1349 assert(operandExprStack.size() >= 2);
1350 const auto &rhs = operandExprStack.back();
1351 auto &lhs = operandExprStack[operandExprStack.size() - 2];
1352 assert(lhs.size() == rhs.size());
1353 // Update the LHS in place.
1354 for (unsigned i = 0, e = rhs.size(); i < e; i++) {
1355 lhs[i] += rhs[i];
1356 }
1357 // Pop off the RHS.
1358 operandExprStack.pop_back();
1359 return success();
1360}
1361
1362//
1363// t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
1364//
1365// A mod expression "expr mod c" is thus flattened by introducing a new local
1366// variable q (= expr floordiv c), such that expr mod c is replaced with
1367// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
1368//
1369// In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
1370// introduce a local variable m (= expr mod symbolic_expr), and the affine
1371// expression expr mod symbolic_expr is added to `localExprs`.
1373 assert(operandExprStack.size() >= 2);
1374
1376 operandExprStack.pop_back();
1378 MLIRContext *context = expr.getContext();
1379
1380 // Flatten semi affine modulo expressions by introducing a local
1381 // variable in place of the modulo value, and the affine expression
1382 // corresponding to the quantifier is added to `localExprs`.
1383 if (!isa<AffineConstantExpr>(expr.getRHS())) {
1386 lhs, numDims, numSymbols, localExprs, context);
1388 localExprs, context);
1389 AffineExpr modExpr = dividendExpr % divisorExpr;
1390 return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
1391 }
1392
1393 int64_t rhsConst = rhs[getConstantIndex()];
1394 if (rhsConst <= 0)
1395 return failure();
1396
1397 // Check if the LHS expression is a multiple of modulo factor.
1398 unsigned i, e;
1399 for (i = 0, e = lhs.size(); i < e; i++)
1400 if (lhs[i] % rhsConst != 0)
1401 break;
1402 // If yes, modulo expression here simplifies to zero.
1403 if (i == lhs.size()) {
1404 llvm::fill(lhs, 0);
1405 return success();
1406 }
1407
1408 // Add a local variable for the quotient, i.e., expr % c is replaced by
1409 // (expr - q * c) where q = expr floordiv c. Do this while canceling out
1410 // the GCD of expr and c.
1411 SmallVector<int64_t, 8> floorDividend(lhs);
1412 uint64_t gcd = rhsConst;
1413 for (int64_t lhsElt : lhs)
1414 gcd = std::gcd(gcd, (uint64_t)std::abs(lhsElt));
1415 // Simplify the numerator and the denominator.
1416 if (gcd != 1) {
1417 for (int64_t &floorDividendElt : floorDividend)
1418 floorDividendElt = floorDividendElt / static_cast<int64_t>(gcd);
1419 }
1420 int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
1421
1422 // Construct the AffineExpr form of the floordiv to store in localExprs.
1423
1425 floorDividend, numDims, numSymbols, localExprs, context);
1426 AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context);
1427 AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr);
1428 int loc;
1429 if ((loc = findLocalId(floorDivExpr)) == -1) {
1430 addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
1431 // Set result at top of stack to "lhs - rhsConst * q".
1432 lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
1433 } else {
1434 // Reuse the existing local id.
1435 lhs[getLocalVarStartIndex() + loc] -= rhsConst;
1436 }
1437 return success();
1438}
1439
1440LogicalResult
1442 return visitDivExpr(expr, /*isCeil=*/true);
1443}
1444LogicalResult
1446 return visitDivExpr(expr, /*isCeil=*/false);
1447}
1448
1450 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1451 auto &eq = operandExprStack.back();
1452 assert(expr.getPosition() < numDims && "Inconsistent number of dims");
1453 eq[getDimStartIndex() + expr.getPosition()] = 1;
1454 return success();
1455}
1456
1457LogicalResult
1459 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1460 auto &eq = operandExprStack.back();
1461 assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1462 eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1463 return success();
1464}
1465
1466LogicalResult
1468 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1469 auto &eq = operandExprStack.back();
1470 eq[getConstantIndex()] = expr.getValue();
1471 return success();
1472}
1473
1474LogicalResult SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1476 SmallVectorImpl<int64_t> &result, unsigned long resultSize) {
1477 assert(result.size() == resultSize &&
1478 "`result` vector passed is not of correct size");
1479 int loc;
1480 if ((loc = findLocalId(localExpr)) == -1) {
1481 if (failed(addLocalIdSemiAffine(lhs, rhs, localExpr)))
1482 return failure();
1483 }
1484 llvm::fill(result, 0);
1485 if (loc == -1)
1486 result[getLocalVarStartIndex() + numLocals - 1] = 1;
1487 else
1488 result[getLocalVarStartIndex() + loc] = 1;
1489 return success();
1490}
1491
1492// t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
1493// A floordiv is thus flattened by introducing a new local variable q, and
1494// replacing that expression with 'q' while adding the constraints
1495// c * q <= expr <= c * q + c - 1 to localVarCst (done by
1496// IntegerRelation::addLocalFloorDiv).
1497//
1498// A ceildiv is similarly flattened:
1499// t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
1500//
1501// In case of semi affine division expressions, t = expr floordiv symbolic_expr
1502// or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
1503// floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
1504// `localExprs`.
1505LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1506 bool isCeil) {
1507 assert(operandExprStack.size() >= 2);
1508
1509 MLIRContext *context = expr.getContext();
1510 SmallVector<int64_t, 8> rhs = operandExprStack.back();
1511 operandExprStack.pop_back();
1512 SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1513
1514 // Flatten semi affine division expressions by introducing a local
1515 // variable in place of the quotient, and the affine expression corresponding
1516 // to the quantifier is added to `localExprs`.
1517 if (!isa<AffineConstantExpr>(expr.getRHS())) {
1518 SmallVector<int64_t, 8> divLhs(lhs);
1520 localExprs, context);
1522 localExprs, context);
1523 AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1524 return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
1525 }
1526
1527 // This is a pure affine expr; the RHS is a positive constant.
1528 int64_t rhsConst = rhs[getConstantIndex()];
1529 if (rhsConst <= 0)
1530 return failure();
1531
1532 // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1533 // common divisors of the numerator and denominator.
1534 uint64_t gcd = std::abs(rhsConst);
1535 for (int64_t lhsElt : lhs)
1536 gcd = std::gcd(gcd, (uint64_t)std::abs(lhsElt));
1537 // Simplify the numerator and the denominator.
1538 if (gcd != 1) {
1539 for (int64_t &lhsElt : lhs)
1540 lhsElt = lhsElt / static_cast<int64_t>(gcd);
1541 }
1542 int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1543 // If the divisor becomes 1, the updated LHS is the result. (The
1544 // divisor can't be negative since rhsConst is positive).
1545 if (divisor == 1)
1546 return success();
1547
1548 // If the divisor cannot be simplified to one, we will have to retain
1549 // the ceil/floor expr (simplified up until here). Add an existential
1550 // quantifier to express its result, i.e., expr1 div expr2 is replaced
1551 // by a new identifier, q.
1552 AffineExpr a =
1554 AffineExpr b = getAffineConstantExpr(divisor, context);
1555
1556 int loc;
1557 AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1558 if ((loc = findLocalId(divExpr)) == -1) {
1559 if (!isCeil) {
1560 SmallVector<int64_t, 8> dividend(lhs);
1561 addLocalFloorDivId(dividend, divisor, divExpr);
1562 } else {
1563 // lhs ceildiv c <=> (lhs + c - 1) floordiv c
1564 SmallVector<int64_t, 8> dividend(lhs);
1565 dividend.back() += divisor - 1;
1566 addLocalFloorDivId(dividend, divisor, divExpr);
1567 }
1568 }
1569 // Set the expression on stack to the local var introduced to capture the
1570 // result of the division (floor or ceil).
1571 llvm::fill(lhs, 0);
1572 if (loc == -1)
1573 lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1574 else
1575 lhs[getLocalVarStartIndex() + loc] = 1;
1576 return success();
1577}
1578
1579// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1580// The local identifier added is always a floordiv of a pure add/mul affine
1581// function of other identifiers, coefficients of which are specified in
1582// dividend and with respect to a positive constant divisor. localExpr is the
1583// simplified tree expression (AffineExpr) corresponding to the quantifier.
1585 int64_t divisor,
1586 AffineExpr localExpr) {
1587 assert(divisor > 0 && "positive constant divisor expected");
1589 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1590 localExprs.push_back(localExpr);
1591 numLocals++;
1592 // dividend and divisor are not used here; an override of this method uses it.
1593}
1594
1598 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1599 localExprs.push_back(localExpr);
1600 ++numLocals;
1601 // lhs and rhs are not used here; an override of this method uses them.
1602 return success();
1603}
1604
1605int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1607 if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
1608 return -1;
1609 return it - localExprs.begin();
1610}
1611
1612/// Simplify the affine expression by flattening it and reconstructing it.
1614 unsigned numSymbols) {
1615 // Simplify semi-affine expressions separately.
1616 if (!expr.isPureAffine())
1617 expr = simplifySemiAffine(expr, numDims, numSymbols);
1618
1619 SimpleAffineExprFlattener flattener(numDims, numSymbols);
1620 // has poison expression
1621 if (failed(flattener.walkPostOrder(expr)))
1622 return expr;
1623 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1624 if (!expr.isPureAffine() &&
1625 expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1626 flattener.localExprs,
1627 expr.getContext()))
1628 return expr;
1629 AffineExpr simplifiedExpr =
1630 expr.isPureAffine()
1631 ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1632 flattener.localExprs, expr.getContext())
1633 : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1634 flattener.localExprs,
1635 expr.getContext());
1636
1637 flattener.operandExprStack.pop_back();
1638 assert(flattener.operandExprStack.empty());
1639 return simplifiedExpr;
1640}
1641
1642std::optional<int64_t> mlir::getBoundForAffineExpr(
1643 AffineExpr expr, unsigned numDims, unsigned numSymbols,
1644 ArrayRef<std::optional<int64_t>> constLowerBounds,
1645 ArrayRef<std::optional<int64_t>> constUpperBounds, bool isUpper) {
1646 // Handle divs and mods.
1647 if (auto binOpExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
1648 // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
1649 // can compute an upper bound.
1650 if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
1651 auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
1652 if (!rhsConst || rhsConst.getValue() < 1)
1653 return std::nullopt;
1654 auto bound =
1655 getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1656 constLowerBounds, constUpperBounds, isUpper);
1657 if (!bound)
1658 return std::nullopt;
1659 return divideFloorSigned(*bound, rhsConst.getValue());
1660 }
1661 if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
1662 auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
1663 if (rhsConst && rhsConst.getValue() >= 1) {
1664 auto bound =
1665 getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1666 constLowerBounds, constUpperBounds, isUpper);
1667 if (!bound)
1668 return std::nullopt;
1669 return divideCeilSigned(*bound, rhsConst.getValue());
1670 }
1671 return std::nullopt;
1672 }
1673 if (binOpExpr.getKind() == AffineExprKind::Mod) {
1674 // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
1675 // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
1676 // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
1677 auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
1678 if (rhsConst && rhsConst.getValue() >= 1) {
1679 int64_t rhsConstVal = rhsConst.getValue();
1680 auto lb = getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1681 constLowerBounds, constUpperBounds,
1682 /*isUpper=*/false);
1683 auto ub =
1684 getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1685 constLowerBounds, constUpperBounds, isUpper);
1686 if (ub && lb &&
1687 divideFloorSigned(*lb, rhsConstVal) ==
1688 divideFloorSigned(*ub, rhsConstVal))
1689 return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
1690 return isUpper ? rhsConstVal - 1 : 0;
1691 }
1692 }
1693 }
1694 // Flatten the expression.
1695 SimpleAffineExprFlattener flattener(numDims, numSymbols);
1696 auto simpleResult = flattener.walkPostOrder(expr);
1697 // has poison expression
1698 if (failed(simpleResult))
1699 return std::nullopt;
1700 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1701 // TODO: Handle local variables. We can get hold of flattener.localExprs and
1702 // get bound on the local expr recursively.
1703 if (flattener.numLocals > 0)
1704 return std::nullopt;
1705 int64_t bound = 0;
1706 // Substitute the constant lower or upper bound for the dimensional or
1707 // symbolic input depending on `isUpper` to determine the bound.
1708 for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
1709 if (flattenedExpr[i] > 0) {
1710 auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
1711 if (!constBound)
1712 return std::nullopt;
1713 bound += *constBound * flattenedExpr[i];
1714 } else if (flattenedExpr[i] < 0) {
1715 auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
1716 if (!constBound)
1717 return std::nullopt;
1718 bound += *constBound * flattenedExpr[i];
1719 }
1720 }
1721 // Constant term.
1722 bound += flattenedExpr.back();
1723 return bound;
1724}
return success()
static int64_t product(ArrayRef< int64_t > vals)
lhs
static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs)
Simplify a multiply expression. Return nullptr if it can't be simplified.
static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs)
static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs a semi-affine expression from a flat ArrayRef. If there are local identifiers (neither dim...
static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs)
static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
Affine binary operation expression.
Definition AffineExpr.h:214
AffineExpr getLHS() const
AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
detail::AffineBinaryOpExprStorage ImplType
Definition AffineExpr.h:216
AffineExpr getRHS() const
An integer constant appearing in affine expression.
Definition AffineExpr.h:239
AffineConstantExpr(AffineExpr::ImplType *ptr=nullptr)
detail::AffineConstantExprStorage ImplType
Definition AffineExpr.h:241
int64_t getValue() const
A dimensional identifier appearing in an affine expression.
Definition AffineExpr.h:223
AffineDimExpr(AffineExpr::ImplType *ptr)
detail::AffineDimExprStorage ImplType
Definition AffineExpr.h:225
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition AffineExpr.h:68
AffineExpr replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements) const
This method substitutes any uses of dimensions and symbols (e.g.
AffineExpr shiftDims(unsigned numDims, unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
bool isSymbolicOrConstant() const
Returns true if this expression is made out of only symbols and constants, i.e., it does not involve ...
AffineExpr operator+(int64_t v) const
AffineExpr operator*(int64_t v) const
constexpr AffineExpr()
Definition AffineExpr.h:72
bool operator==(AffineExpr other) const
Definition AffineExpr.h:76
bool isPureAffine() const
Returns true if this is a pure affine expression, i.e., multiplication, floordiv, ceildiv,...
AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift, unsigned offset=0) const
Replace symbols[offset ... numSymbols) by symbols[offset + shift ... shift + numSymbols).
AffineExpr operator-() const
AffineExpr floorDiv(uint64_t v) const
ImplType * expr
Definition AffineExpr.h:196
RetT walk(FnT &&callback) const
Walk all of the AffineExpr's in this expression in postorder.
Definition AffineExpr.h:117
AffineExprKind getKind() const
Return the classification for this type.
bool isMultipleOf(int64_t factor) const
Return true if the affine expression is a multiple of 'factor'.
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
AffineExpr compose(AffineMap map) const
Compose with an AffineMap.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
bool isFunctionOfSymbol(unsigned position) const
Return true if the affine expression involves AffineSymbolExpr position.
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
AffineExpr operator%(uint64_t v) const
MLIRContext * getContext() const
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
AffineExpr replaceSymbols(ArrayRef< AffineExpr > symReplacements) const
Symbol-only version of replaceDimsAndSymbols.
detail::AffineExprStorage ImplType
Definition AffineExpr.h:70
AffineExpr ceilDiv(uint64_t v) const
void print(raw_ostream &os) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
ArrayRef< AffineExpr > getResults() const
A symbolic identifier appearing in an affine expression.
Definition AffineExpr.h:231
unsigned getPosition() const
detail::AffineDimExprStorage ImplType
Definition AffineExpr.h:233
AffineSymbolExpr(AffineExpr::ImplType *ptr)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
StorageUniquer & getAffineUniquer()
Returns the storage uniquer used for creating affine constructs.
virtual void addLocalFloorDivId(ArrayRef< int64_t > dividend, int64_t divisor, AffineExpr localExpr)
LogicalResult visitSymbolExpr(AffineSymbolExpr expr)
std::vector< SmallVector< int64_t, 8 > > operandExprStack
LogicalResult visitDimExpr(AffineDimExpr expr)
LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr)
LogicalResult visitConstantExpr(AffineConstantExpr expr)
virtual LogicalResult addLocalIdSemiAffine(ArrayRef< int64_t > lhs, ArrayRef< int64_t > rhs, AffineExpr localExpr)
Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul expr) when the rhs is a symbo...
LogicalResult visitModExpr(AffineBinaryOpExpr expr)
LogicalResult visitAddExpr(AffineBinaryOpExpr expr)
LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr)
LogicalResult visitMulExpr(AffineBinaryOpExpr expr)
SmallVector< AffineExpr, 4 > localExprs
SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols)
A utility class to get or create instances of "storage classes".
Storage * get(function_ref< void(Storage *)> initFn, TypeID id, Args &&...args)
Gets a uniqued instance of 'Storage'.
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
AttrTypeReplacer.
Include the generated interface declarations.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t > > constLowerBounds, ArrayRef< std::optional< int64_t > > constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
AffineExprKind
Definition AffineExpr.h:40
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
Definition AffineExpr.h:50
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
Definition AffineExpr.h:46
@ DimId
Dimensional identifier.
Definition AffineExpr.h:59
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
Definition AffineExpr.h:48
@ Constant
Constant integer.
Definition AffineExpr.h:57
@ SymbolId
Symbolic identifier.
Definition AffineExpr.h:61
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
AffineExpr getAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs an affine expression from a flat ArrayRef.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
SmallVector< AffineExpr > getAffineConstantExprs(ArrayRef< int64_t > constants, MLIRContext *context)
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
A binary operation appearing in an affine expression.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.