18 #include "llvm/ADT/STLExtras.h"
34 AffineExprWalker(std::function<
void(
AffineExpr)> callback)
35 : callback(std::move(callback)) {}
43 AffineExprWalker(std::move(callback)).walkPostOrder(*
this);
60 llvm_unreachable(
"unknown binary operation on affine expressions");
72 unsigned dimId = cast<AffineDimExpr>().getPosition();
73 if (dimId >= dimReplacements.size())
75 return dimReplacements[dimId];
78 unsigned symId = cast<AffineSymbolExpr>().getPosition();
79 if (symId >= symReplacements.size())
81 return symReplacements[symId];
88 auto binOp = cast<AffineBinaryOpExpr>();
89 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
90 auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
91 auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
92 if (newLHS == lhs && newRHS == rhs)
96 llvm_unreachable(
"Unknown AffineExpr");
100 return replaceDimsAndSymbols(dimReplacements, {});
105 return replaceDimsAndSymbols({}, symReplacements);
111 unsigned offset)
const {
113 for (
unsigned idx = 0; idx < offset; ++idx)
115 for (
unsigned idx = offset; idx < numDims; ++idx)
117 return replaceDimsAndSymbols(dims, {});
123 unsigned offset)
const {
125 for (
unsigned idx = 0; idx < offset; ++idx)
127 for (
unsigned idx = offset; idx < numSymbols; ++idx)
129 return replaceDimsAndSymbols({}, symbols);
135 auto it = map.find(*
this);
146 auto binOp = cast<AffineBinaryOpExpr>();
147 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
148 auto newLHS = lhs.replace(map);
149 auto newRHS = rhs.replace(map);
150 if (newLHS == lhs && newRHS == rhs)
154 llvm_unreachable(
"Unknown AffineExpr");
160 map.insert(std::make_pair(expr, replacement));
179 auto expr = this->cast<AffineBinaryOpExpr>();
180 return expr.getLHS().isSymbolicOrConstant() &&
181 expr.getRHS().isSymbolicOrConstant();
184 llvm_unreachable(
"Unknown AffineExpr");
196 auto op = cast<AffineBinaryOpExpr>();
197 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
203 auto op = cast<AffineBinaryOpExpr>();
204 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
205 (op.getLHS().template isa<AffineConstantExpr>() ||
206 op.getRHS().template isa<AffineConstantExpr>());
211 auto op = cast<AffineBinaryOpExpr>();
212 return op.getLHS().isPureAffine() &&
213 op.getRHS().template isa<AffineConstantExpr>();
216 llvm_unreachable(
"Unknown AffineExpr");
232 binExpr = this->cast<AffineBinaryOpExpr>();
235 if (rhs && rhs.getValue() != 0) {
237 if (lhsDiv % rhs.getValue() == 0)
238 return lhsDiv / rhs.getValue();
243 return std::abs(this->cast<AffineConstantExpr>().getValue());
245 binExpr = this->cast<AffineBinaryOpExpr>();
252 binExpr = cast<AffineBinaryOpExpr>();
257 llvm_unreachable(
"Unknown AffineExpr");
267 return factor * factor == 1;
269 return cast<AffineConstantExpr>().getValue() % factor == 0;
271 binExpr = cast<AffineBinaryOpExpr>();
277 (l * u) % factor == 0;
283 binExpr = cast<AffineBinaryOpExpr>();
290 llvm_unreachable(
"Unknown AffineExpr");
297 if (
auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
298 return expr.getLHS().isFunctionOfDim(position) ||
299 expr.getRHS().isFunctionOfDim(position);
308 if (
auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
309 return expr.getLHS().isFunctionOfSymbol(position) ||
310 expr.getRHS().isFunctionOfSymbol(position);
341 "unexpected opKind");
389 llvm_unreachable(
"Unknown AffineExpr");
400 "unexpected opKind");
429 return binaryExpr.
getLHS() *
444 llvm_unreachable(
"Unknown AffineExpr");
488 llvm_unreachable(
"Unknown AffineExpr");
494 storage->context = context;
499 assignCtx,
static_cast<unsigned>(kind), position);
529 storage->context = context;
541 if (lhsConst && rhsConst)
569 std::optional<int64_t> rLhsConst, rRhsConst;
575 rLhsConst = rLhsConstExpr.
getValue();
576 firstExpr = lBinOpExpr.getLHS();
586 rRhsConst = rRhsConstExpr.
getValue();
587 secondExpr = rBinOpExpr.getLHS();
593 if (rLhsConst && rRhsConst && firstExpr == secondExpr)
602 return lBin.getLHS() + rhs + lrhs;
615 auto lrhs = rBinOpExpr.getLHS();
616 auto rrhs = rBinOpExpr.getRHS();
625 if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
628 llrhs = lrhsBinOpExpr.getLHS();
630 rlrhs = lrhsBinOpExpr.getRHS();
634 if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
643 llrhs = lrBinOpExpr.
getLHS();
644 rlrhs = lrBinOpExpr.
getRHS();
646 if (lhs == llrhs && rlrhs == -rrhs) {
669 if (lhsConst && rhsConst)
705 return (lBin.getLHS() * rhs) * lrhs;
732 return *
this + (-other);
740 if (!rhsConst || rhsConst.
getValue() < 1)
758 if (lrhs.getValue() % rhsConst.
getValue() == 0)
767 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
769 if (llhsDiv % rhsConst.
getValue() == 0 ||
771 return lBin.getLHS().floorDiv(rhsConst.
getValue()) +
772 lBin.getRHS().floorDiv(rhsConst.
getValue());
795 if (!rhsConst || rhsConst.
getValue() < 1)
813 if (lrhs.getValue() % rhsConst.
getValue() == 0)
839 if (!rhsConst || rhsConst.
getValue() < 1)
857 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
859 if (llhsDiv % rhsConst.
getValue() == 0)
860 return lBin.getRHS() % rhsConst.
getValue();
861 if (lrhsDiv % rhsConst.
getValue() == 0)
862 return lBin.getLHS() % rhsConst.
getValue();
868 if (intermediate && intermediate.getValue() >= 1 &&
869 mod(intermediate.getValue(), rhsConst.
getValue()) == 0) {
870 return lBin.getLHS() % rhsConst.
getValue();
910 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
911 "unexpected number of local expressions");
915 for (
unsigned j = 0;
j < numDims + numSymbols;
j++) {
916 if (flatExprs[
j] == 0)
920 expr = expr +
id * flatExprs[
j];
924 for (
unsigned j = numDims + numSymbols, e = flatExprs.size() - 1;
j < e;
926 if (flatExprs[
j] == 0)
928 auto term = localExprs[
j - numDims - numSymbols] * flatExprs[
j];
933 int64_t constTerm = flatExprs[flatExprs.size() - 1];
935 expr = expr + constTerm;
952 assert(!flatExprs.empty() &&
"flatExprs cannot be empty");
955 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
956 "unexpected number of local expressions");
988 auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
990 assert(!llvm::is_contained(indices, index) &&
991 "Key is already present in indices vector and overwriting will "
992 "happen in `indexToExprMap` and `coefficients`!");
994 indices.push_back(index);
995 coefficients.insert({index, coefficient});
996 indexToExprMap.insert({index, expr});
1004 unsigned offsetSym = 0;
1005 signed offsetDim = -1;
1006 for (
unsigned j = numDims;
j < numDims + numSymbols; ++
j) {
1007 if (flatExprs[
j] == 0)
1013 std::pair<unsigned, signed> indexEntry(
1014 j - numDims,
std::max(numDims, numSymbols) + offsetSym++);
1015 addEntry(indexEntry, flatExprs[
j],
1023 unsigned lhsPos, rhsPos;
1030 if (flatExprs[numDims + numSymbols + it.index()] == 0)
1046 std::pair<unsigned, signed> indexEntry(lhsPos, offsetDim--);
1047 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1051 std::pair<unsigned, signed> indexEntry(
1052 lhsPos,
std::max(numDims, numSymbols) + offsetSym++);
1053 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1064 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1065 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1073 std::pair<unsigned, signed> indexEntry(
1074 lhsPos,
std::max(numDims, numSymbols) + offsetSym++);
1075 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1077 addedToMap[it.index()] =
true;
1080 for (
unsigned j = 0;
j < numDims; ++
j) {
1081 if (flatExprs[
j] == 0)
1087 std::pair<unsigned, signed> indexEntry(
j, offsetDim--);
1094 llvm::sort(indices);
1095 for (
const std::pair<unsigned, unsigned> index : indices) {
1096 assert(indexToExprMap.lookup(index) &&
1097 "cannot find key in `indexToExprMap` map");
1098 expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index);
1102 for (
unsigned j = numDims + numSymbols, e = flatExprs.size() - 1;
j < e;
1106 if (flatExprs[
j] == 0 || addedToMap[
j - numDims - numSymbols])
1108 auto term = localExprs[
j - numDims - numSymbols] * flatExprs[
j];
1113 int64_t constTerm = flatExprs.back();
1115 expr = expr + constTerm;
1120 unsigned numSymbols)
1121 : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
1145 addLocalVariableSemiAffine(a * b, lhs, lhs.size());
1150 auto rhsConst = rhs[getConstantIndex()];
1151 for (
unsigned i = 0, e = lhs.size(); i < e; i++) {
1160 assert(lhs.size() == rhs.size());
1162 for (
unsigned i = 0, e = rhs.size(); i < e; i++) {
1195 AffineExpr modExpr = dividendExpr % divisorExpr;
1196 addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
1200 int64_t rhsConst = rhs[getConstantIndex()];
1203 assert(rhsConst > 0 &&
"RHS constant has to be positive");
1207 for (i = 0, e = lhs.size(); i < e; i++)
1208 if (lhs[i] % rhsConst != 0)
1211 if (i == lhs.size()) {
1212 std::fill(lhs.begin(), lhs.end(), 0);
1220 uint64_t
gcd = rhsConst;
1221 for (
unsigned i = 0, e = lhs.size(); i < e; i++)
1225 for (
unsigned i = 0, e = floorDividend.size(); i < e; i++)
1226 floorDividend[i] = floorDividend[i] /
static_cast<int64_t
>(
gcd);
1228 int64_t floorDivisor = rhsConst /
static_cast<int64_t
>(
gcd);
1237 if ((loc = findLocalId(floorDivExpr)) == -1) {
1240 lhs[getLocalVarStartIndex() +
numLocals - 1] = -rhsConst;
1243 lhs[getLocalVarStartIndex() + loc] = -rhsConst;
1248 visitDivExpr(expr,
true);
1251 visitDivExpr(expr,
false);
1265 eq[getSymbolStartIndex() + expr.
getPosition()] = 1;
1271 eq[getConstantIndex()] = expr.
getValue();
1274 void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1276 unsigned long resultSize) {
1277 assert(result.size() == resultSize &&
1278 "`result` vector passed is not of correct size");
1280 if ((loc = findLocalId(expr)) == -1)
1282 std::fill(result.begin(), result.end(), 0);
1284 result[getLocalVarStartIndex() +
numLocals - 1] = 1;
1286 result[getLocalVarStartIndex() + loc] = 1;
1320 addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
1325 int64_t rhsConst = rhs[getConstantIndex()];
1328 assert(rhsConst > 0 &&
"RHS constant has to be positive");
1333 for (
unsigned i = 0, e = lhs.size(); i < e; i++)
1337 for (
unsigned i = 0, e = lhs.size(); i < e; i++)
1338 lhs[i] = lhs[i] /
static_cast<int64_t
>(
gcd);
1340 int64_t divisor = rhsConst /
static_cast<int64_t
>(
gcd);
1356 if ((loc = findLocalId(divExpr)) == -1) {
1363 dividend.back() += divisor - 1;
1369 std::fill(lhs.begin(), lhs.end(), 0);
1371 lhs[getLocalVarStartIndex() +
numLocals - 1] = 1;
1373 lhs[getLocalVarStartIndex() + loc] = 1;
1384 assert(divisor > 0 &&
"positive constant divisor expected");
1386 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() +
numLocals, 0);
1394 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() +
numLocals, 0);
1399 int SimpleAffineExprFlattener::findLocalId(
AffineExpr localExpr) {
1408 unsigned numSymbols) {
1431 return simplifiedExpr;
static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos, AffineExprKind opKind)
Divides the given expression by the given symbol at position symbolPos.
static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs)
Simplify a multiply expression. Return nullptr if it can't be simplified.
static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs)
static AffineExpr simplifySemiAffine(AffineExpr expr)
Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv operations when the second...
static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs)
Simplify add expression. Return nullptr if it can't be simplified.
static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs a semi-affine expression from a flat ArrayRef.
static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs)
static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs)
static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position, MLIRContext *context)
static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, AffineExprKind opKind)
Returns true if the expression is divisible by the given symbol with position symbolPos.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Affine binary operation expression.
AffineExpr getLHS() const
AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
AffineExpr getRHS() const
An integer constant appearing in affine expression.
AffineConstantExpr(AffineExpr::ImplType *ptr=nullptr)
A dimensional identifier appearing in an affine expression.
AffineDimExpr(AffineExpr::ImplType *ptr)
unsigned getPosition() const
Base class for AffineExpr visitors/walkers.
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
AffineExpr replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements) const
This method substitutes any uses of dimensions and symbols (e.g.
AffineExpr shiftDims(unsigned numDims, unsigned shift, unsigned offset=0) const
Replace dims[offset ...
AffineExpr operator+(int64_t v) const
bool isSymbolicOrConstant() const
Returns true if this expression is made out of only symbols and constants, i.e., it does not involve ...
AffineExpr operator*(int64_t v) const
bool operator==(AffineExpr other) const
bool isPureAffine() const
Returns true if this is a pure affine expression, i.e., multiplication, floordiv, ceildiv,...
AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
AffineExpr operator-() const
void walk(std::function< void(AffineExpr)> callback) const
Walk all of the AffineExpr's in this expression in postorder.
AffineExpr floorDiv(uint64_t v) const
AffineExprKind getKind() const
Return the classification for this type.
bool isMultipleOf(int64_t factor) const
Return true if the affine expression is a multiple of 'factor'.
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
AffineExpr compose(AffineMap map) const
Compose with an AffineMap.
constexpr bool isa() const
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
bool isFunctionOfSymbol(unsigned position) const
Return true if the affine expression involves AffineSymbolExpr position.
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
AffineExpr operator%(uint64_t v) const
MLIRContext * getContext() const
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
AffineExpr replaceSymbols(ArrayRef< AffineExpr > symReplacements) const
Symbol-only version of replaceDimsAndSymbols.
AffineExpr ceilDiv(uint64_t v) const
void print(raw_ostream &os) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
ArrayRef< AffineExpr > getResults() const
A symbolic identifier appearing in an affine expression.
AffineSymbolExpr(AffineExpr::ImplType *ptr)
unsigned getPosition() const
MLIRContext is the top-level object for a collection of MLIR operations.
StorageUniquer & getAffineUniquer()
Returns the storage uniquer used for creating affine constructs.
virtual void addLocalFloorDivId(ArrayRef< int64_t > dividend, int64_t divisor, AffineExpr localExpr)
void visitFloorDivExpr(AffineBinaryOpExpr expr)
void visitAddExpr(AffineBinaryOpExpr expr)
std::vector< SmallVector< int64_t, 8 > > operandExprStack
void visitDimExpr(AffineDimExpr expr)
void visitConstantExpr(AffineConstantExpr expr)
void visitSymbolExpr(AffineSymbolExpr expr)
virtual void addLocalIdSemiAffine(AffineExpr localExpr)
Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul expr) when the rhs is a symbo...
SmallVector< AffineExpr, 4 > localExprs
void visitCeilDivExpr(AffineBinaryOpExpr expr)
void visitModExpr(AffineBinaryOpExpr expr)
SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols)
void visitMulExpr(AffineBinaryOpExpr expr)
A utility class to get or create instances of "storage classes".
Storage * get(function_ref< void(Storage *)> initFn, TypeID id, Args &&...args)
Gets a uniqued instance of 'Storage'.
Detect if any of the given parameter types has a sub-element handler.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt gcd(const MPInt &a, const MPInt &b)
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
This header declares functions that assist transformations in the MemRef dialect.
int64_t floorDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's floordiv operation on constants.
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ Constant
Constant integer.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
AffineExpr getAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs an affine expression from a flat ArrayRef.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
A binary operation appearing in an affine expression.
An integer constant appearing in affine expression.
A dimensional or symbolic identifier appearing in an affine expression.
Base storage class appearing in an affine expression.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.