18 #include "llvm/ADT/STLExtras.h" 30 std::function<void(AffineExpr)> callback;
32 AffineExprWalker(std::function<
void(
AffineExpr)> callback)
33 : callback(std::move(callback)) {}
41 AffineExprWalker(std::move(callback)).walkPostOrder(*
this);
58 llvm_unreachable(
"unknown binary operation on affine expressions");
70 unsigned dimId = cast<AffineDimExpr>().getPosition();
71 if (dimId >= dimReplacements.size())
73 return dimReplacements[dimId];
76 unsigned symId = cast<AffineSymbolExpr>().getPosition();
77 if (symId >= symReplacements.size())
79 return symReplacements[symId];
86 auto binOp = cast<AffineBinaryOpExpr>();
87 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
88 auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
89 auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
90 if (newLHS == lhs && newRHS == rhs)
94 llvm_unreachable(
"Unknown AffineExpr");
98 return replaceDimsAndSymbols(dimReplacements, {});
103 return replaceDimsAndSymbols({}, symReplacements);
109 unsigned offset)
const {
111 for (
unsigned idx = 0; idx < offset; ++idx)
113 for (
unsigned idx = offset; idx < numDims; ++idx)
115 return replaceDimsAndSymbols(dims, {});
121 unsigned offset)
const {
123 for (
unsigned idx = 0; idx < offset; ++idx)
125 for (
unsigned idx = offset; idx < numSymbols; ++idx)
127 return replaceDimsAndSymbols({}, symbols);
133 auto it = map.find(*
this);
144 auto binOp = cast<AffineBinaryOpExpr>();
145 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
146 auto newLHS = lhs.replace(map);
147 auto newRHS = rhs.replace(map);
148 if (newLHS == lhs && newRHS == rhs)
152 llvm_unreachable(
"Unknown AffineExpr");
158 map.insert(std::make_pair(expr, replacement));
177 auto expr = this->cast<AffineBinaryOpExpr>();
178 return expr.getLHS().isSymbolicOrConstant() &&
179 expr.getRHS().isSymbolicOrConstant();
182 llvm_unreachable(
"Unknown AffineExpr");
194 auto op = cast<AffineBinaryOpExpr>();
195 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
201 auto op = cast<AffineBinaryOpExpr>();
202 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
203 (op.getLHS().template isa<AffineConstantExpr>() ||
204 op.getRHS().template isa<AffineConstantExpr>());
209 auto op = cast<AffineBinaryOpExpr>();
210 return op.getLHS().isPureAffine() &&
211 op.getRHS().template isa<AffineConstantExpr>();
214 llvm_unreachable(
"Unknown AffineExpr");
228 return std::abs(this->cast<AffineConstantExpr>().getValue());
230 binExpr = this->cast<AffineBinaryOpExpr>();
237 binExpr = cast<AffineBinaryOpExpr>();
238 return llvm::GreatestCommonDivisor64(
243 llvm_unreachable(
"Unknown AffineExpr");
253 return factor * factor == 1;
255 return cast<AffineConstantExpr>().getValue() % factor == 0;
257 binExpr = cast<AffineBinaryOpExpr>();
263 (l * u) % factor == 0;
269 binExpr = cast<AffineBinaryOpExpr>();
270 return llvm::GreatestCommonDivisor64(
277 llvm_unreachable(
"Unknown AffineExpr");
284 if (
auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
285 return expr.getLHS().isFunctionOfDim(position) ||
286 expr.getRHS().isFunctionOfDim(position);
295 if (
auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
296 return expr.getLHS().isFunctionOfSymbol(position) ||
297 expr.getRHS().isFunctionOfSymbol(position);
328 "unexpected opKind");
376 llvm_unreachable(
"Unknown AffineExpr");
387 "unexpected opKind");
416 return binaryExpr.
getLHS() *
431 llvm_unreachable(
"Unknown AffineExpr");
467 unsigned symbolPos = symbolExpr.getPosition();
475 llvm_unreachable(
"Unknown AffineExpr");
481 storage->context = context;
486 assignCtx,
static_cast<unsigned>(kind), position);
516 storage->context = context;
528 if (lhsConst && rhsConst)
562 rLhsConst = rLhsConstExpr.getValue();
563 firstExpr = lBinOpExpr.getLHS();
573 rRhsConst = rRhsConstExpr.getValue();
574 secondExpr = rBinOpExpr.getLHS();
580 if (rLhsConst && rRhsConst && firstExpr == secondExpr)
590 return lBin.getLHS() + rhs + lrhs;
603 auto lrhs = rBinOpExpr.getLHS();
604 auto rrhs = rBinOpExpr.getRHS();
613 if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
616 llrhs = lrhsBinOpExpr.getLHS();
618 rlrhs = lrhsBinOpExpr.getRHS();
622 if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
631 llrhs = lrBinOpExpr.
getLHS();
632 rlrhs = lrBinOpExpr.
getRHS();
634 if (lhs == llrhs && rlrhs == -rrhs) {
657 if (lhsConst && rhsConst)
693 return (lBin.getLHS() * rhs) * lrhs;
720 return *
this + (-other);
728 if (!rhsConst || rhsConst.
getValue() < 1)
746 if (lrhs.getValue() % rhsConst.
getValue() == 0)
755 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
757 if (llhsDiv % rhsConst.
getValue() == 0 ||
759 return lBin.getLHS().floorDiv(rhsConst.
getValue()) +
760 lBin.getRHS().floorDiv(rhsConst.
getValue());
783 if (!rhsConst || rhsConst.
getValue() < 1)
801 if (lrhs.getValue() % rhsConst.
getValue() == 0)
827 if (!rhsConst || rhsConst.
getValue() < 1)
845 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
847 if (llhsDiv % rhsConst.
getValue() == 0)
848 return lBin.getRHS() % rhsConst.
getValue();
849 if (lrhsDiv % rhsConst.
getValue() == 0)
850 return lBin.getLHS() % rhsConst.
getValue();
856 if (intermediate && intermediate.getValue() >= 1 &&
857 mod(intermediate.getValue(), rhsConst.
getValue()) == 0) {
858 return lBin.getLHS() % rhsConst.
getValue();
898 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
899 "unexpected number of local expressions");
903 for (
unsigned j = 0;
j < numDims + numSymbols;
j++) {
904 if (flatExprs[
j] == 0)
908 expr = expr +
id * flatExprs[
j];
912 for (
unsigned j = numDims + numSymbols, e = flatExprs.size() - 1;
j < e;
914 if (flatExprs[
j] == 0)
916 auto term = localExprs[
j - numDims - numSymbols] * flatExprs[
j];
921 int64_t constTerm = flatExprs[flatExprs.size() - 1];
923 expr = expr + constTerm;
940 assert(!flatExprs.empty() &&
"flatExprs cannot be empty");
943 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
944 "unexpected number of local expressions");
976 auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
978 assert(std::find(indices.begin(), indices.end(), index) == indices.end() &&
979 "Key is already present in indices vector and overwriting will " 980 "happen in `indexToExprMap` and `coefficients`!");
982 indices.push_back(index);
983 coefficients.insert({index, coefficient});
984 indexToExprMap.insert({index, expr});
991 for (
unsigned j = 0;
j < numDims; ++
j) {
992 if (flatExprs[
j] == 0)
998 std::pair<unsigned, signed> indexEntry(
j, -1);
1001 for (
unsigned j = numDims;
j < numDims + numSymbols; ++
j) {
1002 if (flatExprs[
j] == 0)
1008 std::pair<unsigned, signed> indexEntry(
j - numDims,
1010 addEntry(indexEntry, flatExprs[
j],
1018 unsigned lhsPos, rhsPos;
1025 if (flatExprs[numDims + numSymbols + it.index()] == 0)
1041 std::pair<unsigned, signed> indexEntry(lhsPos, -1);
1042 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1046 std::pair<unsigned, signed> indexEntry(lhsPos,
1048 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1059 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1060 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
expr);
1068 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1069 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
expr);
1071 addedToMap[it.index()] =
true;
1077 std::sort(indices.begin(), indices.end());
1078 for (
const std::pair<unsigned, unsigned> index : indices) {
1079 assert(indexToExprMap.lookup(index) &&
1080 "cannot find key in `indexToExprMap` map");
1081 expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index);
1085 for (
unsigned j = numDims + numSymbols, e = flatExprs.size() - 1;
j < e;
1089 if (flatExprs[
j] == 0 || addedToMap[
j - numDims - numSymbols])
1091 auto term = localExprs[
j - numDims - numSymbols] * flatExprs[
j];
1096 int64_t constTerm = flatExprs.back();
1098 expr = expr + constTerm;
1103 unsigned numSymbols)
1104 : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
1128 addLocalVariableSemiAffine(a * b, lhs, lhs.size());
1133 auto rhsConst = rhs[getConstantIndex()];
1134 for (
unsigned i = 0, e = lhs.size(); i < e; i++) {
1143 assert(lhs.size() == rhs.size());
1145 for (
unsigned i = 0, e = rhs.size(); i < e; i++) {
1178 AffineExpr modExpr = dividendExpr % divisorExpr;
1179 addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
1183 int64_t rhsConst = rhs[getConstantIndex()];
1186 assert(rhsConst > 0 &&
"RHS constant has to be positive");
1190 for (i = 0, e = lhs.size(); i < e; i++)
1191 if (lhs[i] % rhsConst != 0)
1194 if (i == lhs.size()) {
1195 std::fill(lhs.begin(), lhs.end(), 0);
1203 uint64_t gcd = rhsConst;
1204 for (
unsigned i = 0, e = lhs.size(); i < e; i++)
1205 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
1208 for (
unsigned i = 0, e = floorDividend.size(); i < e; i++)
1209 floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd);
1211 int64_t floorDivisor = rhsConst /
static_cast<int64_t
>(gcd);
1220 if ((loc = findLocalId(floorDivExpr)) == -1) {
1223 lhs[getLocalVarStartIndex() +
numLocals - 1] = -rhsConst;
1226 lhs[getLocalVarStartIndex() + loc] = -rhsConst;
1231 visitDivExpr(expr,
true);
1234 visitDivExpr(expr,
false);
1248 eq[getSymbolStartIndex() + expr.
getPosition()] = 1;
1254 eq[getConstantIndex()] = expr.
getValue();
1257 void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1259 unsigned long resultSize) {
1260 assert(result.size() == resultSize &&
1261 "`result` vector passed is not of correct size");
1263 if ((loc = findLocalId(expr)) == -1)
1265 std::fill(result.begin(), result.end(), 0);
1267 result[getLocalVarStartIndex() +
numLocals - 1] = 1;
1269 result[getLocalVarStartIndex() + loc] = 1;
1303 addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
1308 int64_t rhsConst = rhs[getConstantIndex()];
1311 assert(rhsConst > 0 &&
"RHS constant has to be positive");
1315 uint64_t gcd = std::abs(rhsConst);
1316 for (
unsigned i = 0, e = lhs.size(); i < e; i++)
1317 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
1320 for (
unsigned i = 0, e = lhs.size(); i < e; i++)
1321 lhs[i] = lhs[i] / static_cast<int64_t>(gcd);
1323 int64_t divisor = rhsConst /
static_cast<int64_t
>(gcd);
1339 if ((loc = findLocalId(divExpr)) == -1) {
1346 dividend.back() += divisor - 1;
1352 std::fill(lhs.begin(), lhs.end(), 0);
1354 lhs[getLocalVarStartIndex() +
numLocals - 1] = 1;
1356 lhs[getLocalVarStartIndex() + loc] = 1;
1367 assert(divisor > 0 &&
"positive constant divisor expected");
1369 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() +
numLocals, 0);
1377 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() +
numLocals, 0);
1382 int SimpleAffineExprFlattener::findLocalId(
AffineExpr localExpr) {
1414 return simplifiedExpr;
Affine binary operation expression.
Include the generated interface declarations.
static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs a semi-affine expression from a flat ArrayRef.
StorageUniquer & getAffineUniquer()
Returns the storage uniquer used for creating affine constructs.
RHS of mod is always a constant or a symbolic expression with a positive value.
Base storage class appearing in an affine expression.
AffineExpr replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements) const
This method substitutes any uses of dimensions and symbols (e.g.
Storage * get(function_ref< void(Storage *)> initFn, TypeID id, Args &&...args)
Gets a uniqued instance of 'Storage'.
AffineConstantExpr(AffineExpr::ImplType *ptr=nullptr)
bool isPureAffine() const
Returns true if this is a pure affine expression, i.e., multiplication, floordiv, ceildiv...
static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, AffineExprKind opKind)
Returns true if the expression is divisible by the given symbol with position symbolPos.
AffineExpr compose(AffineMap map) const
Compose with an AffineMap.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
bool operator==(AffineExpr other) const
AffineExpr shiftDims(unsigned numDims, unsigned shift, unsigned offset=0) const
Replace dims[offset ...
A binary operation appearing in an affine expression.
RetTy walkPostOrder(AffineExpr expr)
AffineSymbolExpr(AffineExpr::ImplType *ptr)
unsigned getPosition() const
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
An integer constant appearing in affine expression.
AffineExpr getAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs an affine expression from a flat ArrayRef.
void walk(std::function< void(AffineExpr)> callback) const
Walk all of the AffineExpr's in this expression in postorder.
void visitConstantExpr(AffineConstantExpr expr)
AffineExpr getRHS() const
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
bool isMultipleOf(int64_t factor) const
Return true if the affine expression is a multiple of 'factor'.
Base class for AffineExpr visitors/walkers.
bool isSymbolicOrConstant() const
Returns true if this expression is made out of only symbols and constants, i.e., it does not involve ...
static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos, AffineExprKind opKind)
Divides the given expression by the given symbol at position symbolPos.
AffineExpr operator*(int64_t v) const
static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs)
AffineDimExpr(AffineExpr::ImplType *ptr)
void visitSymbolExpr(AffineSymbolExpr expr)
static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs)
static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs)
Simplify a multiply expression. Return nullptr if it can't be simplified.
AffineExpr getLHS() const
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
virtual void addLocalIdSemiAffine(AffineExpr localExpr)
Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul expr) when the rhs is a symbo...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Base type for affine expression.
MLIRContext * getContext() const
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
RHS of mul is always a constant or a symbolic expression.
void visitCeilDivExpr(AffineBinaryOpExpr expr)
virtual void addLocalFloorDivId(ArrayRef< int64_t > dividend, int64_t divisor, AffineExpr localExpr)
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued...
A utility class to get or create instances of "storage classes".
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr ceilDiv(uint64_t v) const
ArrayRef< AffineExpr > getResults() const
bool isFunctionOfSymbol(unsigned position) const
Return true if the affine expression involves AffineSymbolExpr position.
Eliminates identifier at the specified position using Fourier-Motzkin variable elimination.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
AffineExpr floorDiv(uint64_t v) const
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
RHS of ceildiv is always a constant or a symbolic expression.
unsigned getPosition() const
void visitFloorDivExpr(AffineBinaryOpExpr expr)
AffineExpr replaceSymbols(ArrayRef< AffineExpr > symReplacements) const
Symbol-only version of replaceDimsAndSymbols.
static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs)
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
AffineExprKind getKind() const
Return the classification for this type.
A dimensional or symbolic identifier appearing in an affine expression.
static AffineExpr simplifySemiAffine(AffineExpr expr)
Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv operations when the second...
void print(raw_ostream &os) const
A dimensional identifier appearing in an affine expression.
void visitDimExpr(AffineDimExpr expr)
static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs)
Simplify add expression. Return nullptr if it can't be simplified.
SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols)
MLIRContext is the top-level object for a collection of MLIR operations.
AffineExpr operator+(int64_t v) const
An integer constant appearing in affine expression.
void visitMulExpr(AffineBinaryOpExpr expr)
void visitModExpr(AffineBinaryOpExpr expr)
static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position, MLIRContext *context)
AffineExpr operator-() const
AffineExpr operator%(uint64_t v) const
void visitAddExpr(AffineBinaryOpExpr expr)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
std::vector< SmallVector< int64_t, 8 > > operandExprStack
A symbolic identifier appearing in an affine expression.
SmallVector< AffineExpr, 4 > localExprs