MLIR 22.0.0git
AffineExpr.cpp
Go to the documentation of this file.
1//===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cmath>
10#include <cstdint>
11#include <utility>
12
13#include "AffineExprDetail.h"
14#include "mlir/IR/AffineExpr.h"
16#include "mlir/IR/AffineMap.h"
17#include "mlir/IR/IntegerSet.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/Support/MathExtras.h"
20#include <numeric>
21#include <optional>
22
23using namespace mlir;
24using namespace mlir::detail;
25
26using llvm::divideCeilSigned;
27using llvm::divideFloorSigned;
28using llvm::divideSignedWouldOverflow;
29using llvm::mod;
30
31MLIRContext *AffineExpr::getContext() const { return expr->context; }
32
33AffineExprKind AffineExpr::getKind() const { return expr->kind; }
34
35/// Walk all of the AffineExprs in `e` in postorder. This is a private factory
36/// method to help handle lambda walk functions. Users should use the regular
37/// (non-static) `walk` method.
38template <typename WalkRetTy>
40 function_ref<WalkRetTy(AffineExpr)> callback) {
41 struct AffineExprWalker
42 : public AffineExprVisitor<AffineExprWalker, WalkRetTy> {
43 function_ref<WalkRetTy(AffineExpr)> callback;
44
45 AffineExprWalker(function_ref<WalkRetTy(AffineExpr)> callback)
46 : callback(callback) {}
47
48 WalkRetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
49 return callback(expr);
50 }
51 WalkRetTy visitConstantExpr(AffineConstantExpr expr) {
52 return callback(expr);
53 }
54 WalkRetTy visitDimExpr(AffineDimExpr expr) { return callback(expr); }
55 WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) { return callback(expr); }
56 };
57
58 return AffineExprWalker(callback).walkPostOrder(e);
59}
60// Explicitly instantiate for the two supported return types.
62 function_ref<void(AffineExpr)> callback);
63template WalkResult
66
67// Dispatch affine expression construction based on kind.
70 if (kind == AffineExprKind::Add)
71 return lhs + rhs;
72 if (kind == AffineExprKind::Mul)
73 return lhs * rhs;
74 if (kind == AffineExprKind::FloorDiv)
75 return lhs.floorDiv(rhs);
76 if (kind == AffineExprKind::CeilDiv)
77 return lhs.ceilDiv(rhs);
78 if (kind == AffineExprKind::Mod)
79 return lhs % rhs;
80
81 llvm_unreachable("unknown binary operation on affine expressions");
82}
83
84/// This method substitutes any uses of dimensions and symbols (e.g.
85/// dim#0 with dimReplacements[0]) and returns the modified expression tree.
88 ArrayRef<AffineExpr> symReplacements) const {
89 switch (getKind()) {
91 return *this;
93 unsigned dimId = llvm::cast<AffineDimExpr>(*this).getPosition();
94 if (dimId >= dimReplacements.size())
95 return *this;
96 return dimReplacements[dimId];
97 }
99 unsigned symId = llvm::cast<AffineSymbolExpr>(*this).getPosition();
100 if (symId >= symReplacements.size())
101 return *this;
102 return symReplacements[symId];
103 }
109 auto binOp = llvm::cast<AffineBinaryOpExpr>(*this);
110 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
111 auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
112 auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
113 if (newLHS == lhs && newRHS == rhs)
114 return *this;
115 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
116 }
117 llvm_unreachable("Unknown AffineExpr");
118}
119
121 return replaceDimsAndSymbols(dimReplacements, {});
122}
123
126 return replaceDimsAndSymbols({}, symReplacements);
127}
128
129/// Replace dims[offset ... numDims)
130/// by dims[offset + shift ... shift + numDims).
131AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
132 unsigned offset) const {
134 for (unsigned idx = 0; idx < offset; ++idx)
135 dims.push_back(getAffineDimExpr(idx, getContext()));
136 for (unsigned idx = offset; idx < numDims; ++idx)
137 dims.push_back(getAffineDimExpr(idx + shift, getContext()));
138 return replaceDimsAndSymbols(dims, {});
139}
140
141/// Replace symbols[offset ... numSymbols)
142/// by symbols[offset + shift ... shift + numSymbols).
143AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
144 unsigned offset) const {
146 for (unsigned idx = 0; idx < offset; ++idx)
147 symbols.push_back(getAffineSymbolExpr(idx, getContext()));
148 for (unsigned idx = offset; idx < numSymbols; ++idx)
149 symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
150 return replaceDimsAndSymbols({}, symbols);
151}
152
153/// Sparse replace method. Return the modified expression tree.
156 auto it = map.find(*this);
157 if (it != map.end())
158 return it->second;
159 switch (getKind()) {
160 default:
161 return *this;
167 auto binOp = llvm::cast<AffineBinaryOpExpr>(*this);
168 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
169 auto newLHS = lhs.replace(map);
170 auto newRHS = rhs.replace(map);
171 if (newLHS == lhs && newRHS == rhs)
172 return *this;
173 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
174 }
175 llvm_unreachable("Unknown AffineExpr");
176}
177
178/// Sparse replace method. Return the modified expression tree.
181 map.insert(std::make_pair(expr, replacement));
182 return replace(map);
183}
184/// Returns true if this expression is made out of only symbols and
185/// constants (no dimensional identifiers).
187 switch (getKind()) {
189 return true;
191 return false;
193 return true;
194
199 case AffineExprKind::Mod: {
200 auto expr = llvm::cast<AffineBinaryOpExpr>(*this);
201 return expr.getLHS().isSymbolicOrConstant() &&
202 expr.getRHS().isSymbolicOrConstant();
203 }
204 }
205 llvm_unreachable("Unknown AffineExpr");
206}
207
208/// Returns true if this is a pure affine expression, i.e., multiplication,
209/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
211 switch (getKind()) {
215 return true;
216 case AffineExprKind::Add: {
217 auto op = llvm::cast<AffineBinaryOpExpr>(*this);
218 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
219 }
220
221 case AffineExprKind::Mul: {
222 // TODO: Canonicalize the constants in binary operators to the RHS when
223 // possible, allowing this to merge into the next case.
224 auto op = llvm::cast<AffineBinaryOpExpr>(*this);
225 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
226 (llvm::isa<AffineConstantExpr>(op.getLHS()) ||
227 llvm::isa<AffineConstantExpr>(op.getRHS()));
228 }
231 case AffineExprKind::Mod: {
232 auto op = llvm::cast<AffineBinaryOpExpr>(*this);
233 return op.getLHS().isPureAffine() &&
234 llvm::isa<AffineConstantExpr>(op.getRHS());
235 }
236 }
237 llvm_unreachable("Unknown AffineExpr");
238}
239
240// Returns the greatest known integral divisor of this affine expression.
242 AffineBinaryOpExpr binExpr(nullptr);
243 switch (getKind()) {
245 [[fallthrough]];
247 return 1;
249 [[fallthrough]];
251 // If the RHS is a constant and divides the known divisor on the LHS, the
252 // quotient is a known divisor of the expression.
253 binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
254 auto rhs = llvm::dyn_cast<AffineConstantExpr>(binExpr.getRHS());
255 // Leave alone undefined expressions.
256 if (rhs && rhs.getValue() != 0) {
257 int64_t lhsDiv = binExpr.getLHS().getLargestKnownDivisor();
258 if (lhsDiv % rhs.getValue() == 0)
259 return std::abs(lhsDiv / rhs.getValue());
260 }
261 return 1;
262 }
264 return std::abs(llvm::cast<AffineConstantExpr>(*this).getValue());
265 case AffineExprKind::Mul: {
266 binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
267 return binExpr.getLHS().getLargestKnownDivisor() *
268 binExpr.getRHS().getLargestKnownDivisor();
269 }
271 [[fallthrough]];
272 case AffineExprKind::Mod: {
273 binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
274 return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
275 (uint64_t)binExpr.getRHS().getLargestKnownDivisor());
276 }
277 }
278 llvm_unreachable("Unknown AffineExpr");
279}
280
282 AffineBinaryOpExpr binExpr(nullptr);
283 uint64_t l, u;
284 switch (getKind()) {
286 [[fallthrough]];
288 return factor * factor == 1;
290 return llvm::cast<AffineConstantExpr>(*this).getValue() % factor == 0;
291 case AffineExprKind::Mul: {
292 binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
293 // It's probably not worth optimizing this further (to not traverse the
294 // whole sub-tree under - it that would require a version of isMultipleOf
295 // that on a 'false' return also returns the largest known divisor).
296 return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
297 (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
298 (l * u) % factor == 0;
299 }
303 case AffineExprKind::Mod: {
304 binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
305 return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
306 (uint64_t)binExpr.getRHS().getLargestKnownDivisor()) %
307 factor ==
308 0;
309 }
310 }
311 llvm_unreachable("Unknown AffineExpr");
312}
313
314bool AffineExpr::isFunctionOfDim(unsigned position) const {
316 return *this == mlir::getAffineDimExpr(position, getContext());
317 }
318 if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(*this)) {
319 return expr.getLHS().isFunctionOfDim(position) ||
320 expr.getRHS().isFunctionOfDim(position);
321 }
322 return false;
323}
324
325bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
327 return *this == mlir::getAffineSymbolExpr(position, getContext());
328 }
329 if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(*this)) {
330 return expr.getLHS().isFunctionOfSymbol(position) ||
331 expr.getRHS().isFunctionOfSymbol(position);
332 }
333 return false;
334}
335
339 return static_cast<ImplType *>(expr)->lhs;
340}
342 return static_cast<ImplType *>(expr)->rhs;
343}
344
347 return static_cast<ImplType *>(expr)->position;
348}
349
350/// Returns true if the expression is divisible by the given symbol with
351/// position `symbolPos`. The argument `opKind` specifies here what kind of
352/// division or mod operation called this division. It helps in implementing the
353/// commutative property of the floordiv and ceildiv operations. If the argument
354/// `opKind` is floordiv and `expr` is also a binary expression of a floordiv
355/// operation, then the commutative property can be used otherwise, the floordiv
356/// operation is not divisible. The same argument holds for ceildiv operation.
357static bool canSimplifyDivisionBySymbol(AffineExpr expr, unsigned symbolPos,
358 AffineExprKind opKind,
359 bool fromMul = false) {
360 // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
361 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
362 opKind == AffineExprKind::CeilDiv) &&
363 "unexpected opKind");
364 switch (expr.getKind()) {
366 return cast<AffineConstantExpr>(expr).getValue() == 0;
368 return false;
370 return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos);
371 // Checks divisibility by the given symbol for both operands.
372 case AffineExprKind::Add: {
373 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
374 return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
375 opKind) &&
376 canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
377 }
378 // Checks divisibility by the given symbol for both operands. Consider the
379 // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
380 // this is a division by s1 and both the operands of modulo are divisible by
381 // s1 but it is not divisible by s1 always. The third argument is
382 // `AffineExprKind::Mod` for this reason.
383 case AffineExprKind::Mod: {
384 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
385 return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
387 canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos,
389 }
390 // Checks if any of the operand divisible by the given symbol.
391 case AffineExprKind::Mul: {
392 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
393 return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind,
394 true) ||
395 canSimplifyDivisionBySymbol(binaryExpr.getRHS(), symbolPos, opKind,
396 true);
397 }
398 // Floordiv and ceildiv are divisible by the given symbol when the first
399 // operand is divisible, and the affine expression kind of the argument expr
400 // is same as the argument `opKind`. This can be inferred from commutative
401 // property of floordiv and ceildiv operations and are as follow:
402 // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
403 // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
404 // It will fail 1.if operations are not same. For example:
405 // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 2.if there is a
406 // multiplication operation in the expression. For example:
407 // (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified.
410 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
411 if (opKind != expr.getKind())
412 return false;
413 if (fromMul)
414 return false;
415 return canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
416 expr.getKind());
417 }
418 }
419 llvm_unreachable("Unknown AffineExpr");
420}
421
422/// Divides the given expression by the given symbol at position `symbolPos`. It
423/// considers the divisibility condition is checked before calling itself. A
424/// null expression is returned whenever the divisibility condition fails.
425static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
426 AffineExprKind opKind) {
427 // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
428 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
429 opKind == AffineExprKind::CeilDiv) &&
430 "unexpected opKind");
431 switch (expr.getKind()) {
433 if (cast<AffineConstantExpr>(expr).getValue() != 0)
434 return nullptr;
435 return getAffineConstantExpr(0, expr.getContext());
437 return nullptr;
439 return getAffineConstantExpr(1, expr.getContext());
440 // Dividing both operands by the given symbol.
441 case AffineExprKind::Add: {
442 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
444 expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
445 symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
446 }
447 // Dividing both operands by the given symbol.
448 case AffineExprKind::Mod: {
449 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
451 expr.getKind(),
452 symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
453 symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
454 }
455 // Dividing any of the operand by the given symbol.
456 case AffineExprKind::Mul: {
457 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
458 if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
459 return binaryExpr.getLHS() *
460 symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
461 return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
462 binaryExpr.getRHS();
463 }
464 // Dividing first operand only by the given symbol.
467 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
469 expr.getKind(),
470 symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
471 binaryExpr.getRHS());
472 }
473 }
474 llvm_unreachable("Unknown AffineExpr");
475}
476
477/// Populate `result` with all summand operands of given (potentially nested)
478/// addition. If the given expression is not an addition, just populate the
479/// expression itself.
480/// Example: Add(Add(7, 8), Mul(9, 10)) will return [7, 8, Mul(9, 10)].
481static void getSummandExprs(AffineExpr expr, SmallVector<AffineExpr> &result) {
482 auto addExpr = dyn_cast<AffineBinaryOpExpr>(expr);
483 if (!addExpr || addExpr.getKind() != AffineExprKind::Add) {
484 result.push_back(expr);
485 return;
486 }
487 getSummandExprs(addExpr.getLHS(), result);
488 getSummandExprs(addExpr.getRHS(), result);
489}
490
491/// Return "true" if `candidate` is a negated expression, i.e., Mul(-1, expr).
492/// If so, also return the non-negated expression via `expr`.
493static bool isNegatedAffineExpr(AffineExpr candidate, AffineExpr &expr) {
494 auto mulExpr = dyn_cast<AffineBinaryOpExpr>(candidate);
495 if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
496 return false;
497 if (auto lhs = dyn_cast<AffineConstantExpr>(mulExpr.getLHS())) {
498 if (lhs.getValue() == -1) {
499 expr = mulExpr.getRHS();
500 return true;
501 }
502 }
503 if (auto rhs = dyn_cast<AffineConstantExpr>(mulExpr.getRHS())) {
504 if (rhs.getValue() == -1) {
505 expr = mulExpr.getLHS();
506 return true;
507 }
508 }
509 return false;
510}
511
512/// Return "true" if `lhs` % `rhs` is guaranteed to evaluate to zero based on
513/// the fact that `lhs` contains another modulo expression that ensures that
514/// `lhs` is divisible by `rhs`. This is a common pattern in the resulting IR
515/// after loop peeling.
516///
517/// Example: lhs = ub - ub % step
518/// rhs = step
519/// => (ub - ub % step) % step is guaranteed to evaluate to 0.
520static bool isModOfModSubtraction(AffineExpr lhs, AffineExpr rhs,
521 unsigned numDims, unsigned numSymbols) {
522 // TODO: Try to unify this function with `getBoundForAffineExpr`.
523 // Collect all summands in lhs.
525 getSummandExprs(lhs, summands);
526 // Look for Mul(-1, Mod(x, rhs)) among the summands. If x matches the
527 // remaining summands, then lhs % rhs is guaranteed to evaluate to 0.
528 for (int64_t i = 0, e = summands.size(); i < e; ++i) {
529 AffineExpr current = summands[i];
530 AffineExpr beforeNegation;
531 if (!isNegatedAffineExpr(current, beforeNegation))
532 continue;
533 AffineBinaryOpExpr innerMod = dyn_cast<AffineBinaryOpExpr>(beforeNegation);
534 if (!innerMod || innerMod.getKind() != AffineExprKind::Mod)
535 continue;
536 if (innerMod.getRHS() != rhs)
537 continue;
538 // Sum all remaining summands and subtract x. If that expression can be
539 // simplified to zero, then the remaining summands and x are equal.
540 AffineExpr diff = getAffineConstantExpr(0, lhs.getContext());
541 for (int64_t j = 0; j < e; ++j)
542 if (i != j)
543 diff = diff + summands[j];
544 diff = diff - innerMod.getLHS();
545 diff = simplifyAffineExpr(diff, numDims, numSymbols);
546 auto constExpr = dyn_cast<AffineConstantExpr>(diff);
547 if (constExpr && constExpr.getValue() == 0)
548 return true;
549 }
550 return false;
551}
552
553/// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
554/// operations when the second operand simplifies to a symbol and the first
555/// operand is divisible by that symbol. It can be applied to any semi-affine
556/// expression. Returned expression can either be a semi-affine or pure affine
557/// expression.
558static AffineExpr simplifySemiAffine(AffineExpr expr, unsigned numDims,
559 unsigned numSymbols) {
560 switch (expr.getKind()) {
564 return expr;
566 case AffineExprKind::Mul: {
567 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
569 expr.getKind(),
570 simplifySemiAffine(binaryExpr.getLHS(), numDims, numSymbols),
571 simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols));
572 }
573 // Check if the simplification of the second operand is a symbol, and the
574 // first operand is divisible by it. If the operation is a modulo, a constant
575 // zero expression is returned. In the case of floordiv and ceildiv, the
576 // symbol from the simplification of the second operand divides the first
577 // operand. Otherwise, simplification is not possible.
580 case AffineExprKind::Mod: {
581 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
582 AffineExpr sLHS =
583 simplifySemiAffine(binaryExpr.getLHS(), numDims, numSymbols);
584 AffineExpr sRHS =
585 simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols);
586 if (isModOfModSubtraction(sLHS, sRHS, numDims, numSymbols))
587 return getAffineConstantExpr(0, expr.getContext());
588 AffineSymbolExpr symbolExpr = dyn_cast<AffineSymbolExpr>(
589 simplifySemiAffine(binaryExpr.getRHS(), numDims, numSymbols));
590 if (!symbolExpr)
591 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
592 unsigned symbolPos = symbolExpr.getPosition();
593 if (!canSimplifyDivisionBySymbol(binaryExpr.getLHS(), symbolPos,
594 expr.getKind()))
595 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
596 if (expr.getKind() == AffineExprKind::Mod)
597 return getAffineConstantExpr(0, expr.getContext());
598 AffineExpr simplifiedQuotient =
599 symbolicDivide(sLHS, symbolPos, expr.getKind());
600 return simplifiedQuotient
601 ? simplifiedQuotient
602 : getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
603 }
604 }
605 llvm_unreachable("Unknown AffineExpr");
606}
607
608static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
609 MLIRContext *context) {
610 auto assignCtx = [context](AffineDimExprStorage *storage) {
611 storage->context = context;
612 };
613
614 StorageUniquer &uniquer = context->getAffineUniquer();
615 return uniquer.get<AffineDimExprStorage>(
616 assignCtx, static_cast<unsigned>(kind), position);
617}
618
619AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
620 return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
621}
622
624 : AffineExpr(ptr) {}
625unsigned AffineSymbolExpr::getPosition() const {
626 return static_cast<ImplType *>(expr)->position;
627}
628
629AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
630 return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
631}
632
634 : AffineExpr(ptr) {}
635int64_t AffineConstantExpr::getValue() const {
636 return static_cast<ImplType *>(expr)->constant;
637}
638
639bool AffineExpr::operator==(int64_t v) const {
640 return *this == getAffineConstantExpr(v, getContext());
641}
642
643AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
644 auto assignCtx = [context](AffineConstantExprStorage *storage) {
645 storage->context = context;
646 };
647
648 StorageUniquer &uniquer = context->getAffineUniquer();
649 return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
650}
651
654 MLIRContext *context) {
655 return llvm::to_vector(llvm::map_range(constants, [&](int64_t constant) {
656 return getAffineConstantExpr(constant, context);
657 }));
658}
659
660/// Simplify add expression. Return nullptr if it can't be simplified.
661static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
662 auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
663 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
664 // Fold if both LHS, RHS are a constant and the sum does not overflow.
665 if (lhsConst && rhsConst) {
666 int64_t sum;
667 if (llvm::AddOverflow(lhsConst.getValue(), rhsConst.getValue(), sum)) {
668 return nullptr;
669 }
670 return getAffineConstantExpr(sum, lhs.getContext());
671 }
672
673 // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
674 // If only one of them is a symbolic expressions, make it the RHS.
675 if (isa<AffineConstantExpr>(lhs) ||
676 (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
677 return rhs + lhs;
678 }
679
680 // At this point, if there was a constant, it would be on the right.
681
682 // Addition with a zero is a noop, return the other input.
683 if (rhsConst) {
684 if (rhsConst.getValue() == 0)
685 return lhs;
686 }
687 // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
688 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
689 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
690 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS()))
691 return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
692 }
693
694 // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
695 // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
696 // respective multiplicands.
697 std::optional<int64_t> rLhsConst, rRhsConst;
698 AffineExpr firstExpr, secondExpr;
699 AffineConstantExpr rLhsConstExpr;
700 auto lBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lhs);
701 if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
702 (rLhsConstExpr = dyn_cast<AffineConstantExpr>(lBinOpExpr.getRHS()))) {
703 rLhsConst = rLhsConstExpr.getValue();
704 firstExpr = lBinOpExpr.getLHS();
705 } else {
706 rLhsConst = 1;
707 firstExpr = lhs;
708 }
709
710 auto rBinOpExpr = dyn_cast<AffineBinaryOpExpr>(rhs);
711 AffineConstantExpr rRhsConstExpr;
712 if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
713 (rRhsConstExpr = dyn_cast<AffineConstantExpr>(rBinOpExpr.getRHS()))) {
714 rRhsConst = rRhsConstExpr.getValue();
715 secondExpr = rBinOpExpr.getLHS();
716 } else {
717 rRhsConst = 1;
718 secondExpr = rhs;
719 }
720
721 if (rLhsConst && rRhsConst && firstExpr == secondExpr)
723 AffineExprKind::Mul, firstExpr,
724 getAffineConstantExpr(*rLhsConst + *rRhsConst, lhs.getContext()));
725
726 // When doing successive additions, bring constant to the right: turn (d0 + 2)
727 // + d1 into (d0 + d1) + 2.
728 if (lBin && lBin.getKind() == AffineExprKind::Add) {
729 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
730 return lBin.getLHS() + rhs + lrhs;
731 }
732 }
733
734 // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
735 // q may be a constant or symbolic expression. This leads to a much more
736 // efficient form when 'c' is a power of two, and in general a more compact
737 // and readable form.
738
739 // Process '(expr floordiv c) * (-c)'.
740 if (!rBinOpExpr)
741 return nullptr;
742
743 auto lrhs = rBinOpExpr.getLHS();
744 auto rrhs = rBinOpExpr.getRHS();
745
746 AffineExpr llrhs, rlrhs;
747
748 // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
749 // symbolic expression.
750 auto lrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lrhs);
751 // Check rrhsConstOpExpr = -1.
752 auto rrhsConstOpExpr = dyn_cast<AffineConstantExpr>(rrhs);
753 if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
754 lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
755 // Check llrhs = expr floordiv q.
756 llrhs = lrhsBinOpExpr.getLHS();
757 // Check rlrhs = q.
758 rlrhs = lrhsBinOpExpr.getRHS();
759 auto llrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(llrhs);
760 if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
761 return nullptr;
762 if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
763 return lhs % rlrhs;
764 }
765
766 // Process lrhs, which is 'expr floordiv c'.
767 // expr + (expr // c * -c) = expr % c
768 AffineBinaryOpExpr lrBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lrhs);
769 if (!lrBinOpExpr || rhs.getKind() != AffineExprKind::Mul ||
770 lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
771 return nullptr;
772
773 llrhs = lrBinOpExpr.getLHS();
774 rlrhs = lrBinOpExpr.getRHS();
775 auto rlrhsConstOpExpr = dyn_cast<AffineConstantExpr>(rlrhs);
776 // We don't support modulo with a negative RHS.
777 bool isPositiveRhs = rlrhsConstOpExpr && rlrhsConstOpExpr.getValue() > 0;
778
779 if (isPositiveRhs && lhs == llrhs && rlrhs == -rrhs) {
780 return lhs % rlrhs;
781 }
782
783 // Try simplify lhs's last operand with rhs. e.g:
784 // (s0 * 64 + s1) + (s1 // c * -c) --->
785 // s0 * 64 + (s1 + s1 // c * -c) -->
786 // s0 * 64 + s1 % c
787 if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Add) {
788 if (auto simplified = simplifyAdd(lBinOpExpr.getRHS(), rhs))
789 return lBinOpExpr.getLHS() + simplified;
790 }
791 return nullptr;
792}
793
794/// Get the canonical order of two commutative exprs arguments.
795static std::pair<AffineExpr, AffineExpr>
796orderCommutativeArgs(AffineExpr expr1, AffineExpr expr2) {
797 auto sym1 = dyn_cast<AffineSymbolExpr>(expr1);
798 auto sym2 = dyn_cast<AffineSymbolExpr>(expr2);
799 // Try to order by symbol/dim position first.
800 if (sym1 && sym2)
801 return sym1.getPosition() < sym2.getPosition() ? std::pair{expr1, expr2}
802 : std::pair{expr2, expr1};
803
804 auto dim1 = dyn_cast<AffineDimExpr>(expr1);
805 auto dim2 = dyn_cast<AffineDimExpr>(expr2);
806 if (dim1 && dim2)
807 return dim1.getPosition() < dim2.getPosition() ? std::pair{expr1, expr2}
808 : std::pair{expr2, expr1};
809
810 // Put dims before symbols.
811 if (dim1 && sym2)
812 return {dim1, sym2};
813
814 if (sym1 && dim2)
815 return {dim2, sym1};
816
817 // Otherwise, keep original order.
818 return {expr1, expr2};
819}
820
821AffineExpr AffineExpr::operator+(int64_t v) const {
822 return *this + getAffineConstantExpr(v, getContext());
823}
825 if (auto simplified = simplifyAdd(*this, other))
826 return simplified;
827
828 auto [lhs, rhs] = orderCommutativeArgs(*this, other);
829
830 StorageUniquer &uniquer = getContext()->getAffineUniquer();
831 return uniquer.get<AffineBinaryOpExprStorage>(
832 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), lhs, rhs);
833}
834
835/// Simplify a multiply expression. Return nullptr if it can't be simplified.
837 auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
838 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
839
840 if (lhsConst && rhsConst) {
842 if (llvm::MulOverflow(lhsConst.getValue(), rhsConst.getValue(), product)) {
843 return nullptr;
844 }
845 return getAffineConstantExpr(product, lhs.getContext());
846 }
847
848 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())
849 return nullptr;
850
851 // Canonicalize the mul expression so that the constant/symbolic term is the
852 // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
853 // constant. (Note that a constant is trivially symbolic).
854 if (!rhs.isSymbolicOrConstant() || isa<AffineConstantExpr>(lhs)) {
855 // At least one of them has to be symbolic.
856 return rhs * lhs;
857 }
858
859 // At this point, if there was a constant, it would be on the right.
860
861 // Multiplication with a one is a noop, return the other input.
862 if (rhsConst) {
863 if (rhsConst.getValue() == 1)
864 return lhs;
865 // Multiplication with zero.
866 if (rhsConst.getValue() == 0)
867 return rhsConst;
868 }
869
870 // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
871 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
872 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
873 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS()))
874 return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
875 }
876
877 // When doing successive multiplication, bring constant to the right: turn (d0
878 // * 2) * d1 into (d0 * d1) * 2.
879 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
880 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
881 return (lBin.getLHS() * rhs) * lrhs;
882 }
883 }
884
885 return nullptr;
886}
887
892 if (auto simplified = simplifyMul(*this, other))
893 return simplified;
894
895 auto [lhs, rhs] = orderCommutativeArgs(*this, other);
896
898 return uniquer.get<AffineBinaryOpExprStorage>(
899 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), lhs, rhs);
900}
901
902// Unary minus, delegate to operator*.
904 return *this * getAffineConstantExpr(-1, getContext());
905}
906
907// Delegate to operator+.
908AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
910 return *this + (-other);
911}
912
914 auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
915 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
916
917 if (!rhsConst || rhsConst.getValue() == 0)
918 return nullptr;
919
920 if (lhsConst) {
921 if (divideSignedWouldOverflow(lhsConst.getValue(), rhsConst.getValue()))
922 return nullptr;
924 divideFloorSigned(lhsConst.getValue(), rhsConst.getValue()),
925 lhs.getContext());
926 }
927
928 // Fold floordiv of a multiply with a constant that is a multiple of the
929 // divisor. Eg: (i * 128) floordiv 64 = i * 2.
930 if (rhsConst == 1)
931 return lhs;
932
933 // Simplify `(expr * lrhs) floordiv rhsConst` when `lrhs` is known to be a
934 // multiple of `rhsConst`.
935 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
936 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
937 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
938 // `rhsConst` is known to be a nonzero constant.
939 if (lrhs.getValue() % rhsConst.getValue() == 0)
940 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
941 }
942 }
943
944 // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
945 // known to be a multiple of divConst.
946 if (lBin && lBin.getKind() == AffineExprKind::Add) {
947 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
948 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
949 // rhsConst is known to be a nonzero constant.
950 if (llhsDiv % rhsConst.getValue() == 0 ||
951 lrhsDiv % rhsConst.getValue() == 0)
952 return lBin.getLHS().floorDiv(rhsConst.getValue()) +
953 lBin.getRHS().floorDiv(rhsConst.getValue());
954 }
955
956 return nullptr;
957}
958
961}
963 if (auto simplified = simplifyFloorDiv(*this, other))
964 return simplified;
965
967 return uniquer.get<AffineBinaryOpExprStorage>(
968 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
969 other);
970}
971
973 auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
974 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
975
976 if (!rhsConst || rhsConst.getValue() == 0)
977 return nullptr;
978
979 if (lhsConst) {
980 if (divideSignedWouldOverflow(lhsConst.getValue(), rhsConst.getValue()))
981 return nullptr;
983 divideCeilSigned(lhsConst.getValue(), rhsConst.getValue()),
984 lhs.getContext());
985 }
986
987 // Fold ceildiv of a multiply with a constant that is a multiple of the
988 // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
989 if (rhsConst.getValue() == 1)
990 return lhs;
991
992 // Simplify `(expr * lrhs) ceildiv rhsConst` when `lrhs` is known to be a
993 // multiple of `rhsConst`.
994 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
995 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
996 if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
997 // `rhsConst` is known to be a nonzero constant.
998 if (lrhs.getValue() % rhsConst.getValue() == 0)
999 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
1000 }
1001 }
1002
1003 return nullptr;
1004}
1005
1008}
1010 if (auto simplified = simplifyCeilDiv(*this, other))
1011 return simplified;
1012
1014 return uniquer.get<AffineBinaryOpExprStorage>(
1015 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
1016 other);
1017}
1018
1020 auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
1021 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
1022
1023 // mod w.r.t zero or negative numbers is undefined and preserved as is.
1024 if (!rhsConst || rhsConst.getValue() < 1)
1025 return nullptr;
1026
1027 if (lhsConst) {
1028 // mod never overflows.
1029 return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
1030 lhs.getContext());
1031 }
1032
1033 // Fold modulo of an expression that is known to be a multiple of a constant
1034 // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
1035 // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
1036 if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
1037 return getAffineConstantExpr(0, lhs.getContext());
1038
1039 // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
1040 // known to be a multiple of divConst.
1041 auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
1042 if (lBin && lBin.getKind() == AffineExprKind::Add) {
1043 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
1044 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
1045 // rhsConst is known to be a positive constant.
1046 if (llhsDiv % rhsConst.getValue() == 0)
1047 return lBin.getRHS() % rhsConst.getValue();
1048 if (lrhsDiv % rhsConst.getValue() == 0)
1049 return lBin.getLHS() % rhsConst.getValue();
1050 }
1051
1052 // Simplify (e % a) % b to e % b when b evenly divides a
1053 if (lBin && lBin.getKind() == AffineExprKind::Mod) {
1054 auto intermediate = dyn_cast<AffineConstantExpr>(lBin.getRHS());
1055 if (intermediate && intermediate.getValue() >= 1 &&
1056 mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
1057 return lBin.getLHS() % rhsConst.getValue();
1058 }
1059 }
1060
1061 return nullptr;
1062}
1063
1065 return *this % getAffineConstantExpr(v, getContext());
1066}
1068 if (auto simplified = simplifyMod(*this, other))
1069 return simplified;
1070
1072 return uniquer.get<AffineBinaryOpExprStorage>(
1073 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
1074}
1075
1077 SmallVector<AffineExpr, 8> dimReplacements(map.getResults());
1078 return replaceDimsAndSymbols(dimReplacements, {});
1079}
1081 expr.print(os);
1082 return os;
1083}
1084
1085/// Constructs an affine expression from a flat ArrayRef. If there are local
1086/// identifiers (neither dimensional nor symbolic) that appear in the sum of
1087/// products expression, `localExprs` is expected to have the AffineExpr
1088/// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
1089/// in the format [dims, symbols, locals, constant term].
1091 unsigned numDims,
1092 unsigned numSymbols,
1093 ArrayRef<AffineExpr> localExprs,
1094 MLIRContext *context) {
1095 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1096 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1097 "unexpected number of local expressions");
1098
1099 auto expr = getAffineConstantExpr(0, context);
1100 // Dimensions and symbols.
1101 for (unsigned j = 0; j < numDims + numSymbols; j++) {
1102 if (flatExprs[j] == 0)
1103 continue;
1104 auto id = j < numDims ? getAffineDimExpr(j, context)
1105 : getAffineSymbolExpr(j - numDims, context);
1106 expr = expr + id * flatExprs[j];
1107 }
1108
1109 // Local identifiers.
1110 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1111 j++) {
1112 if (flatExprs[j] == 0)
1113 continue;
1114 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1115 expr = expr + term;
1116 }
1117
1118 // Constant term.
1119 int64_t constTerm = flatExprs[flatExprs.size() - 1];
1120 if (constTerm != 0)
1121 expr = expr + constTerm;
1122 return expr;
1123}
1124
1125/// Constructs a semi-affine expression from a flat ArrayRef. If there are
1126/// local identifiers (neither dimensional nor symbolic) that appear in the sum
1127/// of products expression, `localExprs` is expected to have the AffineExprs for
1128/// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in
1129/// the format [dims, symbols, locals, constant term]. The semi-affine
1130/// expression is constructed in the sorted order of dimension and symbol
1131/// position numbers. Note: local expressions/ids are used for mod, div as well
1132/// as symbolic RHS terms for terms that are not pure affine.
1134 unsigned numDims,
1135 unsigned numSymbols,
1136 ArrayRef<AffineExpr> localExprs,
1137 MLIRContext *context) {
1138 assert(!flatExprs.empty() && "flatExprs cannot be empty");
1139
1140 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1141 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1142 "unexpected number of local expressions");
1143
1144 AffineExpr expr = getAffineConstantExpr(0, context);
1145
1146 // We design indices as a pair which help us present the semi-affine map as
1147 // sum of product where terms are sorted based on dimension or symbol
1148 // position: <keyA, keyB> for expressions of the form dimension * symbol,
1149 // where keyA is the position number of the dimension and keyB is the
1150 // position number of the symbol. For dimensional expressions we set the index
1151 // as (position number of the dimension, -1), as we want dimensional
1152 // expressions to appear before symbolic and product of dimensional and
1153 // symbolic expressions having the dimension with the same position number.
1154 // For symbolic expression set the index as (position number of the symbol,
1155 // maximum of last dimension and symbol position) number. For example, we want
1156 // the expression we are constructing to look something like: d0 + d0 * s0 +
1157 // s0 + d1*s1 + s1.
1158
1159 // Stores the affine expression corresponding to a given index.
1161 // Stores the constant coefficient value corresponding to a given
1162 // dimension, symbol or a non-pure affine expression stored in `localExprs`.
1164 // Stores the indices as defined above, and later sorted to produce
1165 // the semi-affine expression in the desired form.
1167
1168 // Example: expression = d0 + d0 * s0 + 2 * s0.
1169 // indices = [{0,-1}, {0, 0}, {0, 1}]
1170 // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
1171 // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
1172
1173 // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
1174 auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
1175 AffineExpr expr) {
1176 assert(!llvm::is_contained(indices, index) &&
1177 "Key is already present in indices vector and overwriting will "
1178 "happen in `indexToExprMap` and `coefficients`!");
1179
1180 indices.push_back(index);
1181 coefficients.insert({index, coefficient});
1182 indexToExprMap.insert({index, expr});
1183 };
1184
1185 // Design indices for dimensional or symbolic terms, and store the indices,
1186 // constant coefficient corresponding to the indices in `coefficients` map,
1187 // and affine expression corresponding to indices in `indexToExprMap` map.
1188
1189 // Ensure we do not have duplicate keys in `indexToExpr` map.
1190 unsigned offsetSym = 0;
1191 signed offsetDim = -1;
1192 for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
1193 if (flatExprs[j] == 0)
1194 continue;
1195 // For symbolic expression set the index as <position number
1196 // of the symbol, max(dimCount, symCount)> number,
1197 // as we want symbolic expressions with the same positional number to
1198 // appear after dimensional expressions having the same positional number.
1199 std::pair<unsigned, signed> indexEntry(
1200 j - numDims, std::max(numDims, numSymbols) + offsetSym++);
1201 addEntry(indexEntry, flatExprs[j],
1202 getAffineSymbolExpr(j - numDims, context));
1203 }
1204
1205 // Denotes semi-affine product, modulo or division terms, which has been added
1206 // to the `indexToExpr` map.
1207 SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1,
1208 false);
1209 unsigned lhsPos, rhsPos;
1210 // Construct indices for product terms involving dimension, symbol or constant
1211 // as lhs/rhs, and store the indices, constant coefficient corresponding to
1212 // the indices in `coefficients` map, and affine expression corresponding to
1213 // in indices in `indexToExprMap` map.
1214 for (const auto &it : llvm::enumerate(localExprs)) {
1215 if (flatExprs[numDims + numSymbols + it.index()] == 0)
1216 continue;
1217 AffineExpr expr = it.value();
1218 auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
1219 if (!binaryExpr)
1220 continue;
1221
1222 AffineExpr lhs = binaryExpr.getLHS();
1223 AffineExpr rhs = binaryExpr.getRHS();
1224 if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
1225 (isa<AffineDimExpr>(rhs) || isa<AffineSymbolExpr>(rhs) ||
1226 isa<AffineConstantExpr>(rhs)))) {
1227 continue;
1228 }
1229 if (isa<AffineConstantExpr>(rhs)) {
1230 // For product/modulo/division expressions, when rhs of modulo/division
1231 // expression is constant, we put 0 in place of keyB, because we want
1232 // them to appear earlier in the semi-affine expression we are
1233 // constructing. When rhs is constant, we place 0 in place of keyB.
1234 if (isa<AffineDimExpr>(lhs)) {
1235 lhsPos = cast<AffineDimExpr>(lhs).getPosition();
1236 std::pair<unsigned, signed> indexEntry(lhsPos, offsetDim--);
1237 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1238 expr);
1239 } else {
1240 lhsPos = cast<AffineSymbolExpr>(lhs).getPosition();
1241 std::pair<unsigned, signed> indexEntry(
1242 lhsPos, std::max(numDims, numSymbols) + offsetSym++);
1243 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1244 expr);
1245 }
1246 } else if (isa<AffineDimExpr>(lhs)) {
1247 // For product/modulo/division expressions having lhs as dimension and rhs
1248 // as symbol, we order the terms in the semi-affine expression based on
1249 // the pair: <keyA, keyB> for expressions of the form dimension * symbol,
1250 // where keyA is the position number of the dimension and keyB is the
1251 // position number of the symbol.
1252 lhsPos = cast<AffineDimExpr>(lhs).getPosition();
1253 rhsPos = cast<AffineSymbolExpr>(rhs).getPosition();
1254 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1255 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1256 } else {
1257 // For product/modulo/division expressions having both lhs and rhs as
1258 // symbol, we design indices as a pair: <keyA, keyB> for expressions
1259 // of the form dimension * symbol, where keyA is the position number of
1260 // the dimension and keyB is the position number of the symbol.
1261 lhsPos = cast<AffineSymbolExpr>(lhs).getPosition();
1262 rhsPos = cast<AffineSymbolExpr>(rhs).getPosition();
1263 std::pair<unsigned, signed> indexEntry(
1264 lhsPos, std::max(numDims, numSymbols) + offsetSym++);
1265 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1266 }
1267 addedToMap[it.index()] = true;
1268 }
1269
1270 for (unsigned j = 0; j < numDims; ++j) {
1271 if (flatExprs[j] == 0)
1272 continue;
1273 // For dimensional expressions we set the index as <position number of the
1274 // dimension, 0>, as we want dimensional expressions to appear before
1275 // symbolic ones and products of dimensional and symbolic expressions
1276 // having the dimension with the same position number.
1277 std::pair<unsigned, signed> indexEntry(j, offsetDim--);
1278 addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context));
1279 }
1280
1281 // Constructing the simplified semi-affine sum of product/division/mod
1282 // expression from the flattened form in the desired sorted order of indices
1283 // of the various individual product/division/mod expressions.
1284 llvm::sort(indices);
1285 for (const std::pair<unsigned, unsigned> index : indices) {
1286 assert(indexToExprMap.lookup(index) &&
1287 "cannot find key in `indexToExprMap` map");
1288 expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index);
1289 }
1290
1291 // Local identifiers.
1292 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1293 j++) {
1294 // If the coefficient of the local expression is 0, continue as we need not
1295 // add it in out final expression.
1296 if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols])
1297 continue;
1298 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1299 expr = expr + term;
1300 }
1301
1302 // Constant term.
1303 int64_t constTerm = flatExprs.back();
1304 if (constTerm != 0)
1305 expr = expr + constTerm;
1306 return expr;
1307}
1308
1314
1315// In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1316//
1317// In case of semi affine multiplication expressions, t = expr * symbolic_expr,
1318// introduce a local variable p (= expr * symbolic_expr), and the affine
1319// expression expr * symbolic_expr is added to `localExprs`.
1321 assert(operandExprStack.size() >= 2);
1323 operandExprStack.pop_back();
1325
1326 // Flatten semi-affine multiplication expressions by introducing a local
1327 // variable in place of the product; the affine expression
1328 // corresponding to the quantifier is added to `localExprs`.
1329 if (!isa<AffineConstantExpr>(expr.getRHS())) {
1331 MLIRContext *context = expr.getContext();
1333 localExprs, context);
1335 localExprs, context);
1336 return addLocalVariableSemiAffine(mulLhs, rhs, a * b, lhs, lhs.size());
1337 }
1338
1339 // Get the RHS constant.
1340 int64_t rhsConst = rhs[getConstantIndex()];
1341 for (int64_t &lhsElt : lhs)
1342 lhsElt *= rhsConst;
1343
1344 return success();
1345}
1346
1348 assert(operandExprStack.size() >= 2);
1349 const auto &rhs = operandExprStack.back();
1350 auto &lhs = operandExprStack[operandExprStack.size() - 2];
1351 assert(lhs.size() == rhs.size());
1352 // Update the LHS in place.
1353 for (unsigned i = 0, e = rhs.size(); i < e; i++) {
1354 lhs[i] += rhs[i];
1355 }
1356 // Pop off the RHS.
1357 operandExprStack.pop_back();
1358 return success();
1359}
1360
1361//
1362// t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
1363//
1364// A mod expression "expr mod c" is thus flattened by introducing a new local
1365// variable q (= expr floordiv c), such that expr mod c is replaced with
1366// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
1367//
1368// In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
1369// introduce a local variable m (= expr mod symbolic_expr), and the affine
1370// expression expr mod symbolic_expr is added to `localExprs`.
1372 assert(operandExprStack.size() >= 2);
1373
1375 operandExprStack.pop_back();
1377 MLIRContext *context = expr.getContext();
1378
1379 // Flatten semi affine modulo expressions by introducing a local
1380 // variable in place of the modulo value, and the affine expression
1381 // corresponding to the quantifier is added to `localExprs`.
1382 if (!isa<AffineConstantExpr>(expr.getRHS())) {
1385 lhs, numDims, numSymbols, localExprs, context);
1387 localExprs, context);
1388 AffineExpr modExpr = dividendExpr % divisorExpr;
1389 return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
1390 }
1391
1392 int64_t rhsConst = rhs[getConstantIndex()];
1393 if (rhsConst <= 0)
1394 return failure();
1395
1396 // Check if the LHS expression is a multiple of modulo factor.
1397 unsigned i, e;
1398 for (i = 0, e = lhs.size(); i < e; i++)
1399 if (lhs[i] % rhsConst != 0)
1400 break;
1401 // If yes, modulo expression here simplifies to zero.
1402 if (i == lhs.size()) {
1403 llvm::fill(lhs, 0);
1404 return success();
1405 }
1406
1407 // Add a local variable for the quotient, i.e., expr % c is replaced by
1408 // (expr - q * c) where q = expr floordiv c. Do this while canceling out
1409 // the GCD of expr and c.
1410 SmallVector<int64_t, 8> floorDividend(lhs);
1411 uint64_t gcd = rhsConst;
1412 for (int64_t lhsElt : lhs)
1413 gcd = std::gcd(gcd, (uint64_t)std::abs(lhsElt));
1414 // Simplify the numerator and the denominator.
1415 if (gcd != 1) {
1416 for (int64_t &floorDividendElt : floorDividend)
1417 floorDividendElt = floorDividendElt / static_cast<int64_t>(gcd);
1418 }
1419 int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
1420
1421 // Construct the AffineExpr form of the floordiv to store in localExprs.
1422
1424 floorDividend, numDims, numSymbols, localExprs, context);
1425 AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context);
1426 AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr);
1427 int loc;
1428 if ((loc = findLocalId(floorDivExpr)) == -1) {
1429 addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
1430 // Set result at top of stack to "lhs - rhsConst * q".
1431 lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
1432 } else {
1433 // Reuse the existing local id.
1434 lhs[getLocalVarStartIndex() + loc] -= rhsConst;
1435 }
1436 return success();
1437}
1438
1439LogicalResult
1441 return visitDivExpr(expr, /*isCeil=*/true);
1442}
1443LogicalResult
1445 return visitDivExpr(expr, /*isCeil=*/false);
1446}
1447
1449 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1450 auto &eq = operandExprStack.back();
1451 assert(expr.getPosition() < numDims && "Inconsistent number of dims");
1452 eq[getDimStartIndex() + expr.getPosition()] = 1;
1453 return success();
1454}
1455
1456LogicalResult
1458 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1459 auto &eq = operandExprStack.back();
1460 assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1461 eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1462 return success();
1463}
1464
1465LogicalResult
1467 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1468 auto &eq = operandExprStack.back();
1469 eq[getConstantIndex()] = expr.getValue();
1470 return success();
1471}
1472
1473LogicalResult SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1475 SmallVectorImpl<int64_t> &result, unsigned long resultSize) {
1476 assert(result.size() == resultSize &&
1477 "`result` vector passed is not of correct size");
1478 int loc;
1479 if ((loc = findLocalId(localExpr)) == -1) {
1480 if (failed(addLocalIdSemiAffine(lhs, rhs, localExpr)))
1481 return failure();
1482 }
1483 llvm::fill(result, 0);
1484 if (loc == -1)
1485 result[getLocalVarStartIndex() + numLocals - 1] = 1;
1486 else
1487 result[getLocalVarStartIndex() + loc] = 1;
1488 return success();
1489}
1490
1491// t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
1492// A floordiv is thus flattened by introducing a new local variable q, and
1493// replacing that expression with 'q' while adding the constraints
1494// c * q <= expr <= c * q + c - 1 to localVarCst (done by
1495// IntegerRelation::addLocalFloorDiv).
1496//
1497// A ceildiv is similarly flattened:
1498// t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
1499//
1500// In case of semi affine division expressions, t = expr floordiv symbolic_expr
1501// or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
1502// floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
1503// `localExprs`.
1504LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1505 bool isCeil) {
1506 assert(operandExprStack.size() >= 2);
1507
1508 MLIRContext *context = expr.getContext();
1509 SmallVector<int64_t, 8> rhs = operandExprStack.back();
1510 operandExprStack.pop_back();
1511 SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1512
1513 // Flatten semi affine division expressions by introducing a local
1514 // variable in place of the quotient, and the affine expression corresponding
1515 // to the quantifier is added to `localExprs`.
1516 if (!isa<AffineConstantExpr>(expr.getRHS())) {
1517 SmallVector<int64_t, 8> divLhs(lhs);
1519 localExprs, context);
1521 localExprs, context);
1522 AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1523 return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
1524 }
1525
1526 // This is a pure affine expr; the RHS is a positive constant.
1527 int64_t rhsConst = rhs[getConstantIndex()];
1528 if (rhsConst <= 0)
1529 return failure();
1530
1531 // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1532 // common divisors of the numerator and denominator.
1533 uint64_t gcd = std::abs(rhsConst);
1534 for (int64_t lhsElt : lhs)
1535 gcd = std::gcd(gcd, (uint64_t)std::abs(lhsElt));
1536 // Simplify the numerator and the denominator.
1537 if (gcd != 1) {
1538 for (int64_t &lhsElt : lhs)
1539 lhsElt = lhsElt / static_cast<int64_t>(gcd);
1540 }
1541 int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1542 // If the divisor becomes 1, the updated LHS is the result. (The
1543 // divisor can't be negative since rhsConst is positive).
1544 if (divisor == 1)
1545 return success();
1546
1547 // If the divisor cannot be simplified to one, we will have to retain
1548 // the ceil/floor expr (simplified up until here). Add an existential
1549 // quantifier to express its result, i.e., expr1 div expr2 is replaced
1550 // by a new identifier, q.
1551 AffineExpr a =
1553 AffineExpr b = getAffineConstantExpr(divisor, context);
1554
1555 int loc;
1556 AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1557 if ((loc = findLocalId(divExpr)) == -1) {
1558 if (!isCeil) {
1559 SmallVector<int64_t, 8> dividend(lhs);
1560 addLocalFloorDivId(dividend, divisor, divExpr);
1561 } else {
1562 // lhs ceildiv c <=> (lhs + c - 1) floordiv c
1563 SmallVector<int64_t, 8> dividend(lhs);
1564 dividend.back() += divisor - 1;
1565 addLocalFloorDivId(dividend, divisor, divExpr);
1566 }
1567 }
1568 // Set the expression on stack to the local var introduced to capture the
1569 // result of the division (floor or ceil).
1570 llvm::fill(lhs, 0);
1571 if (loc == -1)
1572 lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1573 else
1574 lhs[getLocalVarStartIndex() + loc] = 1;
1575 return success();
1576}
1577
1578// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1579// The local identifier added is always a floordiv of a pure add/mul affine
1580// function of other identifiers, coefficients of which are specified in
1581// dividend and with respect to a positive constant divisor. localExpr is the
1582// simplified tree expression (AffineExpr) corresponding to the quantifier.
1584 int64_t divisor,
1585 AffineExpr localExpr) {
1586 assert(divisor > 0 && "positive constant divisor expected");
1588 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1589 localExprs.push_back(localExpr);
1590 numLocals++;
1591 // dividend and divisor are not used here; an override of this method uses it.
1592}
1593
1597 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1598 localExprs.push_back(localExpr);
1599 ++numLocals;
1600 // lhs and rhs are not used here; an override of this method uses them.
1601 return success();
1602}
1603
1604int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1606 if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
1607 return -1;
1608 return it - localExprs.begin();
1609}
1610
1611/// Simplify the affine expression by flattening it and reconstructing it.
1613 unsigned numSymbols) {
1614 // Simplify semi-affine expressions separately.
1615 if (!expr.isPureAffine())
1616 expr = simplifySemiAffine(expr, numDims, numSymbols);
1617
1618 SimpleAffineExprFlattener flattener(numDims, numSymbols);
1619 // has poison expression
1620 if (failed(flattener.walkPostOrder(expr)))
1621 return expr;
1622 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1623 if (!expr.isPureAffine() &&
1624 expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1625 flattener.localExprs,
1626 expr.getContext()))
1627 return expr;
1628 AffineExpr simplifiedExpr =
1629 expr.isPureAffine()
1630 ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1631 flattener.localExprs, expr.getContext())
1632 : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1633 flattener.localExprs,
1634 expr.getContext());
1635
1636 flattener.operandExprStack.pop_back();
1637 assert(flattener.operandExprStack.empty());
1638 return simplifiedExpr;
1639}
1640
1641std::optional<int64_t> mlir::getBoundForAffineExpr(
1642 AffineExpr expr, unsigned numDims, unsigned numSymbols,
1643 ArrayRef<std::optional<int64_t>> constLowerBounds,
1644 ArrayRef<std::optional<int64_t>> constUpperBounds, bool isUpper) {
1645 // Handle divs and mods.
1646 if (auto binOpExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
1647 // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
1648 // can compute an upper bound.
1649 if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
1650 auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
1651 if (!rhsConst || rhsConst.getValue() < 1)
1652 return std::nullopt;
1653 auto bound =
1654 getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1655 constLowerBounds, constUpperBounds, isUpper);
1656 if (!bound)
1657 return std::nullopt;
1658 return divideFloorSigned(*bound, rhsConst.getValue());
1659 }
1660 if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
1661 auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
1662 if (rhsConst && rhsConst.getValue() >= 1) {
1663 auto bound =
1664 getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1665 constLowerBounds, constUpperBounds, isUpper);
1666 if (!bound)
1667 return std::nullopt;
1668 return divideCeilSigned(*bound, rhsConst.getValue());
1669 }
1670 return std::nullopt;
1671 }
1672 if (binOpExpr.getKind() == AffineExprKind::Mod) {
1673 // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
1674 // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
1675 // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
1676 auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
1677 if (rhsConst && rhsConst.getValue() >= 1) {
1678 int64_t rhsConstVal = rhsConst.getValue();
1679 auto lb = getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1680 constLowerBounds, constUpperBounds,
1681 /*isUpper=*/false);
1682 auto ub =
1683 getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1684 constLowerBounds, constUpperBounds, isUpper);
1685 if (ub && lb &&
1686 divideFloorSigned(*lb, rhsConstVal) ==
1687 divideFloorSigned(*ub, rhsConstVal))
1688 return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
1689 return isUpper ? rhsConstVal - 1 : 0;
1690 }
1691 }
1692 }
1693 // Flatten the expression.
1694 SimpleAffineExprFlattener flattener(numDims, numSymbols);
1695 auto simpleResult = flattener.walkPostOrder(expr);
1696 // has poison expression
1697 if (failed(simpleResult))
1698 return std::nullopt;
1699 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1700 // TODO: Handle local variables. We can get hold of flattener.localExprs and
1701 // get bound on the local expr recursively.
1702 if (flattener.numLocals > 0)
1703 return std::nullopt;
1704 int64_t bound = 0;
1705 // Substitute the constant lower or upper bound for the dimensional or
1706 // symbolic input depending on `isUpper` to determine the bound.
1707 for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
1708 if (flattenedExpr[i] > 0) {
1709 auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
1710 if (!constBound)
1711 return std::nullopt;
1712 bound += *constBound * flattenedExpr[i];
1713 } else if (flattenedExpr[i] < 0) {
1714 auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
1715 if (!constBound)
1716 return std::nullopt;
1717 bound += *constBound * flattenedExpr[i];
1718 }
1719 }
1720 // Constant term.
1721 bound += flattenedExpr.back();
1722 return bound;
1723}
return success()
static int64_t product(ArrayRef< int64_t > vals)
lhs
static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs)
Simplify a multiply expression. Return nullptr if it can't be simplified.
static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs)
static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs a semi-affine expression from a flat ArrayRef. If there are local identifiers (neither dim...
static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs)
static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
Affine binary operation expression.
Definition AffineExpr.h:214
AffineExpr getLHS() const
AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
detail::AffineBinaryOpExprStorage ImplType
Definition AffineExpr.h:216
AffineExpr getRHS() const
An integer constant appearing in affine expression.
Definition AffineExpr.h:239
AffineConstantExpr(AffineExpr::ImplType *ptr=nullptr)
detail::AffineConstantExprStorage ImplType
Definition AffineExpr.h:241
int64_t getValue() const
A dimensional identifier appearing in an affine expression.
Definition AffineExpr.h:223
AffineDimExpr(AffineExpr::ImplType *ptr)
detail::AffineDimExprStorage ImplType
Definition AffineExpr.h:225
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition AffineExpr.h:68
AffineExpr replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements) const
This method substitutes any uses of dimensions and symbols (e.g.
AffineExpr shiftDims(unsigned numDims, unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
bool isSymbolicOrConstant() const
Returns true if this expression is made out of only symbols and constants, i.e., it does not involve ...
AffineExpr operator+(int64_t v) const
AffineExpr operator*(int64_t v) const
constexpr AffineExpr()
Definition AffineExpr.h:72
bool operator==(AffineExpr other) const
Definition AffineExpr.h:76
bool isPureAffine() const
Returns true if this is a pure affine expression, i.e., multiplication, floordiv, ceildiv,...
AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift, unsigned offset=0) const
Replace symbols[offset ... numSymbols) by symbols[offset + shift ... shift + numSymbols).
AffineExpr operator-() const
AffineExpr floorDiv(uint64_t v) const
ImplType * expr
Definition AffineExpr.h:196
RetT walk(FnT &&callback) const
Walk all of the AffineExpr's in this expression in postorder.
Definition AffineExpr.h:117
AffineExprKind getKind() const
Return the classification for this type.
bool isMultipleOf(int64_t factor) const
Return true if the affine expression is a multiple of 'factor'.
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
AffineExpr compose(AffineMap map) const
Compose with an AffineMap.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
bool isFunctionOfSymbol(unsigned position) const
Return true if the affine expression involves AffineSymbolExpr position.
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
AffineExpr operator%(uint64_t v) const
MLIRContext * getContext() const
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
AffineExpr replaceSymbols(ArrayRef< AffineExpr > symReplacements) const
Symbol-only version of replaceDimsAndSymbols.
detail::AffineExprStorage ImplType
Definition AffineExpr.h:70
AffineExpr ceilDiv(uint64_t v) const
void print(raw_ostream &os) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
ArrayRef< AffineExpr > getResults() const
A symbolic identifier appearing in an affine expression.
Definition AffineExpr.h:231
unsigned getPosition() const
detail::AffineDimExprStorage ImplType
Definition AffineExpr.h:233
AffineSymbolExpr(AffineExpr::ImplType *ptr)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
StorageUniquer & getAffineUniquer()
Returns the storage uniquer used for creating affine constructs.
virtual void addLocalFloorDivId(ArrayRef< int64_t > dividend, int64_t divisor, AffineExpr localExpr)
LogicalResult visitSymbolExpr(AffineSymbolExpr expr)
std::vector< SmallVector< int64_t, 8 > > operandExprStack
LogicalResult visitDimExpr(AffineDimExpr expr)
LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr)
LogicalResult visitConstantExpr(AffineConstantExpr expr)
virtual LogicalResult addLocalIdSemiAffine(ArrayRef< int64_t > lhs, ArrayRef< int64_t > rhs, AffineExpr localExpr)
Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul expr) when the rhs is a symbo...
LogicalResult visitModExpr(AffineBinaryOpExpr expr)
LogicalResult visitAddExpr(AffineBinaryOpExpr expr)
LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr)
LogicalResult visitMulExpr(AffineBinaryOpExpr expr)
SmallVector< AffineExpr, 4 > localExprs
SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols)
A utility class to get or create instances of "storage classes".
Storage * get(function_ref< void(Storage *)> initFn, TypeID id, Args &&...args)
Gets a uniqued instance of 'Storage'.
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
AttrTypeReplacer.
Include the generated interface declarations.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t > > constLowerBounds, ArrayRef< std::optional< int64_t > > constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
AffineExprKind
Definition AffineExpr.h:40
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
Definition AffineExpr.h:50
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
Definition AffineExpr.h:46
@ DimId
Dimensional identifier.
Definition AffineExpr.h:59
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
Definition AffineExpr.h:48
@ Constant
Constant integer.
Definition AffineExpr.h:57
@ SymbolId
Symbolic identifier.
Definition AffineExpr.h:61
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
AffineExpr getAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs an affine expression from a flat ArrayRef.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
SmallVector< AffineExpr > getAffineConstantExprs(ArrayRef< int64_t > constants, MLIRContext *context)
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
A binary operation appearing in an affine expression.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.