18 #include "llvm/ADT/STLExtras.h"
34 AffineExprWalker(std::function<
void(
AffineExpr)> callback)
35 : callback(std::move(callback)) {}
43 AffineExprWalker(std::move(callback)).walkPostOrder(*
this);
60 llvm_unreachable(
"unknown binary operation on affine expressions");
72 unsigned dimId = cast<AffineDimExpr>().getPosition();
73 if (dimId >= dimReplacements.size())
75 return dimReplacements[dimId];
78 unsigned symId = cast<AffineSymbolExpr>().getPosition();
79 if (symId >= symReplacements.size())
81 return symReplacements[symId];
88 auto binOp = cast<AffineBinaryOpExpr>();
89 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
90 auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
91 auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
92 if (newLHS == lhs && newRHS == rhs)
96 llvm_unreachable(
"Unknown AffineExpr");
100 return replaceDimsAndSymbols(dimReplacements, {});
105 return replaceDimsAndSymbols({}, symReplacements);
111 unsigned offset)
const {
113 for (
unsigned idx = 0; idx < offset; ++idx)
115 for (
unsigned idx = offset; idx < numDims; ++idx)
117 return replaceDimsAndSymbols(dims, {});
123 unsigned offset)
const {
125 for (
unsigned idx = 0; idx < offset; ++idx)
127 for (
unsigned idx = offset; idx < numSymbols; ++idx)
129 return replaceDimsAndSymbols({}, symbols);
135 auto it = map.find(*
this);
146 auto binOp = cast<AffineBinaryOpExpr>();
147 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
148 auto newLHS = lhs.replace(map);
149 auto newRHS = rhs.replace(map);
150 if (newLHS == lhs && newRHS == rhs)
154 llvm_unreachable(
"Unknown AffineExpr");
160 map.insert(std::make_pair(expr, replacement));
179 auto expr = this->cast<AffineBinaryOpExpr>();
180 return expr.getLHS().isSymbolicOrConstant() &&
181 expr.getRHS().isSymbolicOrConstant();
184 llvm_unreachable(
"Unknown AffineExpr");
196 auto op = cast<AffineBinaryOpExpr>();
197 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
203 auto op = cast<AffineBinaryOpExpr>();
204 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
205 (op.getLHS().template isa<AffineConstantExpr>() ||
206 op.getRHS().template isa<AffineConstantExpr>());
211 auto op = cast<AffineBinaryOpExpr>();
212 return op.getLHS().isPureAffine() &&
213 op.getRHS().template isa<AffineConstantExpr>();
216 llvm_unreachable(
"Unknown AffineExpr");
232 binExpr = this->cast<AffineBinaryOpExpr>();
235 if (rhs && rhs.getValue() != 0) {
237 if (lhsDiv % rhs.getValue() == 0)
238 return lhsDiv / rhs.getValue();
243 return std::abs(this->cast<AffineConstantExpr>().getValue());
245 binExpr = this->cast<AffineBinaryOpExpr>();
252 binExpr = cast<AffineBinaryOpExpr>();
257 llvm_unreachable(
"Unknown AffineExpr");
267 return factor * factor == 1;
269 return cast<AffineConstantExpr>().getValue() % factor == 0;
271 binExpr = cast<AffineBinaryOpExpr>();
277 (l * u) % factor == 0;
283 binExpr = cast<AffineBinaryOpExpr>();
290 llvm_unreachable(
"Unknown AffineExpr");
297 if (
auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
298 return expr.getLHS().isFunctionOfDim(position) ||
299 expr.getRHS().isFunctionOfDim(position);
308 if (
auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
309 return expr.getLHS().isFunctionOfSymbol(position) ||
310 expr.getRHS().isFunctionOfSymbol(position);
341 "unexpected opKind");
389 llvm_unreachable(
"Unknown AffineExpr");
400 "unexpected opKind");
429 return binaryExpr.
getLHS() *
444 llvm_unreachable(
"Unknown AffineExpr");
488 llvm_unreachable(
"Unknown AffineExpr");
494 storage->context = context;
499 assignCtx,
static_cast<unsigned>(kind), position);
529 storage->context = context;
539 return llvm::to_vector(llvm::map_range(constants, [&](int64_t constant) {
549 if (lhsConst && rhsConst)
577 std::optional<int64_t> rLhsConst, rRhsConst;
583 rLhsConst = rLhsConstExpr.
getValue();
584 firstExpr = lBinOpExpr.getLHS();
594 rRhsConst = rRhsConstExpr.
getValue();
595 secondExpr = rBinOpExpr.getLHS();
601 if (rLhsConst && rRhsConst && firstExpr == secondExpr)
610 return lBin.getLHS() + rhs + lrhs;
623 auto lrhs = rBinOpExpr.getLHS();
624 auto rrhs = rBinOpExpr.getRHS();
633 if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
636 llrhs = lrhsBinOpExpr.getLHS();
638 rlrhs = lrhsBinOpExpr.getRHS();
642 if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
651 llrhs = lrBinOpExpr.
getLHS();
652 rlrhs = lrBinOpExpr.
getRHS();
654 if (lhs == llrhs && rlrhs == -rrhs) {
677 if (lhsConst && rhsConst)
713 return (lBin.getLHS() * rhs) * lrhs;
740 return *
this + (-other);
748 if (!rhsConst || rhsConst.
getValue() < 1)
766 if (lrhs.getValue() % rhsConst.
getValue() == 0)
775 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
777 if (llhsDiv % rhsConst.
getValue() == 0 ||
779 return lBin.getLHS().floorDiv(rhsConst.
getValue()) +
780 lBin.getRHS().floorDiv(rhsConst.
getValue());
803 if (!rhsConst || rhsConst.
getValue() < 1)
821 if (lrhs.getValue() % rhsConst.
getValue() == 0)
847 if (!rhsConst || rhsConst.
getValue() < 1)
865 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
867 if (llhsDiv % rhsConst.
getValue() == 0)
868 return lBin.getRHS() % rhsConst.
getValue();
869 if (lrhsDiv % rhsConst.
getValue() == 0)
870 return lBin.getLHS() % rhsConst.
getValue();
876 if (intermediate && intermediate.getValue() >= 1 &&
877 mod(intermediate.getValue(), rhsConst.
getValue()) == 0) {
878 return lBin.getLHS() % rhsConst.
getValue();
918 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
919 "unexpected number of local expressions");
923 for (
unsigned j = 0;
j < numDims + numSymbols;
j++) {
924 if (flatExprs[
j] == 0)
928 expr = expr +
id * flatExprs[
j];
932 for (
unsigned j = numDims + numSymbols, e = flatExprs.size() - 1;
j < e;
934 if (flatExprs[
j] == 0)
936 auto term = localExprs[
j - numDims - numSymbols] * flatExprs[
j];
941 int64_t constTerm = flatExprs[flatExprs.size() - 1];
943 expr = expr + constTerm;
960 assert(!flatExprs.empty() &&
"flatExprs cannot be empty");
963 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
964 "unexpected number of local expressions");
996 auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
998 assert(!llvm::is_contained(indices, index) &&
999 "Key is already present in indices vector and overwriting will "
1000 "happen in `indexToExprMap` and `coefficients`!");
1002 indices.push_back(index);
1003 coefficients.insert({index, coefficient});
1004 indexToExprMap.insert({index, expr});
1012 unsigned offsetSym = 0;
1013 signed offsetDim = -1;
1014 for (
unsigned j = numDims;
j < numDims + numSymbols; ++
j) {
1015 if (flatExprs[
j] == 0)
1021 std::pair<unsigned, signed> indexEntry(
1022 j - numDims,
std::max(numDims, numSymbols) + offsetSym++);
1023 addEntry(indexEntry, flatExprs[
j],
1031 unsigned lhsPos, rhsPos;
1038 if (flatExprs[numDims + numSymbols + it.index()] == 0)
1054 std::pair<unsigned, signed> indexEntry(lhsPos, offsetDim--);
1055 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1059 std::pair<unsigned, signed> indexEntry(
1060 lhsPos,
std::max(numDims, numSymbols) + offsetSym++);
1061 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1072 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1073 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1081 std::pair<unsigned, signed> indexEntry(
1082 lhsPos,
std::max(numDims, numSymbols) + offsetSym++);
1083 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1085 addedToMap[it.index()] =
true;
1088 for (
unsigned j = 0;
j < numDims; ++
j) {
1089 if (flatExprs[
j] == 0)
1095 std::pair<unsigned, signed> indexEntry(
j, offsetDim--);
1102 llvm::sort(indices);
1103 for (
const std::pair<unsigned, unsigned> index : indices) {
1104 assert(indexToExprMap.lookup(index) &&
1105 "cannot find key in `indexToExprMap` map");
1106 expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index);
1110 for (
unsigned j = numDims + numSymbols, e = flatExprs.size() - 1;
j < e;
1114 if (flatExprs[
j] == 0 || addedToMap[
j - numDims - numSymbols])
1116 auto term = localExprs[
j - numDims - numSymbols] * flatExprs[
j];
1121 int64_t constTerm = flatExprs.back();
1123 expr = expr + constTerm;
1128 unsigned numSymbols)
1129 : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
1153 addLocalVariableSemiAffine(a * b, lhs, lhs.size());
1158 auto rhsConst = rhs[getConstantIndex()];
1159 for (
unsigned i = 0, e = lhs.size(); i < e; i++) {
1168 assert(lhs.size() == rhs.size());
1170 for (
unsigned i = 0, e = rhs.size(); i < e; i++) {
1203 AffineExpr modExpr = dividendExpr % divisorExpr;
1204 addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
1208 int64_t rhsConst = rhs[getConstantIndex()];
1211 assert(rhsConst > 0 &&
"RHS constant has to be positive");
1215 for (i = 0, e = lhs.size(); i < e; i++)
1216 if (lhs[i] % rhsConst != 0)
1219 if (i == lhs.size()) {
1220 std::fill(lhs.begin(), lhs.end(), 0);
1228 uint64_t
gcd = rhsConst;
1229 for (
unsigned i = 0, e = lhs.size(); i < e; i++)
1233 for (
unsigned i = 0, e = floorDividend.size(); i < e; i++)
1234 floorDividend[i] = floorDividend[i] /
static_cast<int64_t
>(
gcd);
1236 int64_t floorDivisor = rhsConst /
static_cast<int64_t
>(
gcd);
1245 if ((loc = findLocalId(floorDivExpr)) == -1) {
1248 lhs[getLocalVarStartIndex() +
numLocals - 1] = -rhsConst;
1251 lhs[getLocalVarStartIndex() + loc] = -rhsConst;
1256 visitDivExpr(expr,
true);
1259 visitDivExpr(expr,
false);
1273 eq[getSymbolStartIndex() + expr.
getPosition()] = 1;
1279 eq[getConstantIndex()] = expr.
getValue();
1282 void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1284 unsigned long resultSize) {
1285 assert(result.size() == resultSize &&
1286 "`result` vector passed is not of correct size");
1288 if ((loc = findLocalId(expr)) == -1)
1290 std::fill(result.begin(), result.end(), 0);
1292 result[getLocalVarStartIndex() +
numLocals - 1] = 1;
1294 result[getLocalVarStartIndex() + loc] = 1;
1328 addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
1333 int64_t rhsConst = rhs[getConstantIndex()];
1336 assert(rhsConst > 0 &&
"RHS constant has to be positive");
1341 for (
unsigned i = 0, e = lhs.size(); i < e; i++)
1345 for (
unsigned i = 0, e = lhs.size(); i < e; i++)
1346 lhs[i] = lhs[i] /
static_cast<int64_t
>(
gcd);
1348 int64_t divisor = rhsConst /
static_cast<int64_t
>(
gcd);
1364 if ((loc = findLocalId(divExpr)) == -1) {
1371 dividend.back() += divisor - 1;
1377 std::fill(lhs.begin(), lhs.end(), 0);
1379 lhs[getLocalVarStartIndex() +
numLocals - 1] = 1;
1381 lhs[getLocalVarStartIndex() + loc] = 1;
1392 assert(divisor > 0 &&
"positive constant divisor expected");
1394 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() +
numLocals, 0);
1402 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() +
numLocals, 0);
1407 int SimpleAffineExprFlattener::findLocalId(
AffineExpr localExpr) {
1416 unsigned numSymbols) {
1439 return simplifiedExpr;
static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos, AffineExprKind opKind)
Divides the given expression by the given symbol at position symbolPos.
static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs)
Simplify a multiply expression. Return nullptr if it can't be simplified.
static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs)
static AffineExpr simplifySemiAffine(AffineExpr expr)
Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv operations when the second...
static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs)
Simplify add expression. Return nullptr if it can't be simplified.
static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs a semi-affine expression from a flat ArrayRef.
static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs)
static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs)
static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position, MLIRContext *context)
static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, AffineExprKind opKind)
Returns true if the expression is divisible by the given symbol with position symbolPos.
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Affine binary operation expression.
AffineExpr getLHS() const
AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
AffineExpr getRHS() const
An integer constant appearing in affine expression.
AffineConstantExpr(AffineExpr::ImplType *ptr=nullptr)
A dimensional identifier appearing in an affine expression.
AffineDimExpr(AffineExpr::ImplType *ptr)
unsigned getPosition() const
Base class for AffineExpr visitors/walkers.
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
AffineExpr replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements) const
This method substitutes any uses of dimensions and symbols (e.g.
AffineExpr shiftDims(unsigned numDims, unsigned shift, unsigned offset=0) const
Replace dims[offset ...
AffineExpr operator+(int64_t v) const
bool isSymbolicOrConstant() const
Returns true if this expression is made out of only symbols and constants, i.e., it does not involve ...
AffineExpr operator*(int64_t v) const
bool operator==(AffineExpr other) const
bool isPureAffine() const
Returns true if this is a pure affine expression, i.e., multiplication, floordiv, ceildiv,...
AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
AffineExpr operator-() const
void walk(std::function< void(AffineExpr)> callback) const
Walk all of the AffineExpr's in this expression in postorder.
AffineExpr floorDiv(uint64_t v) const
AffineExprKind getKind() const
Return the classification for this type.
bool isMultipleOf(int64_t factor) const
Return true if the affine expression is a multiple of 'factor'.
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
AffineExpr compose(AffineMap map) const
Compose with an AffineMap.
constexpr bool isa() const
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
bool isFunctionOfSymbol(unsigned position) const
Return true if the affine expression involves AffineSymbolExpr position.
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
AffineExpr operator%(uint64_t v) const
MLIRContext * getContext() const
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
AffineExpr replaceSymbols(ArrayRef< AffineExpr > symReplacements) const
Symbol-only version of replaceDimsAndSymbols.
AffineExpr ceilDiv(uint64_t v) const
void print(raw_ostream &os) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
ArrayRef< AffineExpr > getResults() const
A symbolic identifier appearing in an affine expression.
AffineSymbolExpr(AffineExpr::ImplType *ptr)
unsigned getPosition() const
MLIRContext is the top-level object for a collection of MLIR operations.
StorageUniquer & getAffineUniquer()
Returns the storage uniquer used for creating affine constructs.
virtual void addLocalFloorDivId(ArrayRef< int64_t > dividend, int64_t divisor, AffineExpr localExpr)
void visitFloorDivExpr(AffineBinaryOpExpr expr)
void visitAddExpr(AffineBinaryOpExpr expr)
std::vector< SmallVector< int64_t, 8 > > operandExprStack
void visitDimExpr(AffineDimExpr expr)
void visitConstantExpr(AffineConstantExpr expr)
void visitSymbolExpr(AffineSymbolExpr expr)
virtual void addLocalIdSemiAffine(AffineExpr localExpr)
Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul expr) when the rhs is a symbo...
SmallVector< AffineExpr, 4 > localExprs
void visitCeilDivExpr(AffineBinaryOpExpr expr)
void visitModExpr(AffineBinaryOpExpr expr)
SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols)
void visitMulExpr(AffineBinaryOpExpr expr)
A utility class to get or create instances of "storage classes".
Storage * get(function_ref< void(Storage *)> initFn, TypeID id, Args &&...args)
Gets a uniqued instance of 'Storage'.
Detect if any of the given parameter types has a sub-element handler.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt gcd(const MPInt &a, const MPInt &b)
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
This header declares functions that assist transformations in the MemRef dialect.
int64_t floorDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's floordiv operation on constants.
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ Constant
Constant integer.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
AffineExpr getAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs an affine expression from a flat ArrayRef.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
SmallVector< AffineExpr > getAffineConstantExprs(ArrayRef< int64_t > constants, MLIRContext *context)
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
A binary operation appearing in an affine expression.
An integer constant appearing in affine expression.
A dimensional or symbolic identifier appearing in an affine expression.
Base storage class appearing in an affine expression.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.