18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/SmallVectorExtras.h"
20#include "llvm/Support/MathExtras.h"
27using llvm::divideCeilSigned;
28using llvm::divideFloorSigned;
29using llvm::divideSignedWouldOverflow;
39template <
typename WalkRetTy>
42 struct AffineExprWalker
47 : callback(callback) {}
50 return callback(expr);
52 WalkRetTy visitConstantExpr(AffineConstantExpr expr) {
53 return callback(expr);
55 WalkRetTy visitDimExpr(AffineDimExpr expr) {
return callback(expr); }
56 WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) {
return callback(expr); }
59 return AffineExprWalker(callback).walkPostOrder(e);
82 llvm_unreachable(
"unknown binary operation on affine expressions");
94 unsigned dimId = llvm::cast<AffineDimExpr>(*this).getPosition();
95 if (dimId >= dimReplacements.size())
97 return dimReplacements[dimId];
100 unsigned symId = llvm::cast<AffineSymbolExpr>(*this).getPosition();
101 if (symId >= symReplacements.size())
103 return symReplacements[symId];
110 auto binOp = llvm::cast<AffineBinaryOpExpr>(*
this);
111 auto lhs = binOp.getLHS(),
rhs = binOp.getRHS();
112 auto newLHS =
lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
113 auto newRHS =
rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
114 if (newLHS ==
lhs && newRHS ==
rhs)
118 llvm_unreachable(
"Unknown AffineExpr");
133 unsigned offset)
const {
135 for (
unsigned idx = 0; idx < offset; ++idx)
137 for (
unsigned idx = offset; idx < numDims; ++idx)
145 unsigned offset)
const {
147 for (
unsigned idx = 0; idx < offset; ++idx)
149 for (
unsigned idx = offset; idx < numSymbols; ++idx)
157 auto it = map.find(*
this);
168 auto binOp = llvm::cast<AffineBinaryOpExpr>(*
this);
169 auto lhs = binOp.getLHS(),
rhs = binOp.getRHS();
170 auto newLHS =
lhs.replace(map);
171 auto newRHS =
rhs.replace(map);
172 if (newLHS ==
lhs && newRHS ==
rhs)
176 llvm_unreachable(
"Unknown AffineExpr");
201 auto expr = llvm::cast<AffineBinaryOpExpr>(*
this);
202 return expr.getLHS().isSymbolicOrConstant() &&
203 expr.getRHS().isSymbolicOrConstant();
206 llvm_unreachable(
"Unknown AffineExpr");
218 auto op = llvm::cast<AffineBinaryOpExpr>(*
this);
219 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
225 auto op = llvm::cast<AffineBinaryOpExpr>(*
this);
226 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
227 (llvm::isa<AffineConstantExpr>(op.getLHS()) ||
228 llvm::isa<AffineConstantExpr>(op.getRHS()));
233 auto op = llvm::cast<AffineBinaryOpExpr>(*
this);
234 return op.getLHS().isPureAffine() &&
235 llvm::isa<AffineConstantExpr>(op.getRHS());
238 llvm_unreachable(
"Unknown AffineExpr");
254 binExpr = llvm::cast<AffineBinaryOpExpr>(*
this);
255 auto rhs = llvm::dyn_cast<AffineConstantExpr>(binExpr.
getRHS());
257 if (
rhs &&
rhs.getValue() != 0) {
259 if (lhsDiv %
rhs.getValue() == 0)
260 return std::abs(lhsDiv /
rhs.getValue());
265 return std::abs(llvm::cast<AffineConstantExpr>(*this).getValue());
267 binExpr = llvm::cast<AffineBinaryOpExpr>(*
this);
274 binExpr = llvm::cast<AffineBinaryOpExpr>(*
this);
279 llvm_unreachable(
"Unknown AffineExpr");
289 return factor * factor == 1;
291 return llvm::cast<AffineConstantExpr>(*this).getValue() % factor == 0;
293 binExpr = llvm::cast<AffineBinaryOpExpr>(*
this);
299 (l * u) % factor == 0;
305 binExpr = llvm::cast<AffineBinaryOpExpr>(*
this);
312 llvm_unreachable(
"Unknown AffineExpr");
319 if (
auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(*
this)) {
320 return expr.getLHS().isFunctionOfDim(position) ||
321 expr.getRHS().isFunctionOfDim(position);
330 if (
auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(*
this)) {
331 return expr.getLHS().isFunctionOfSymbol(position) ||
332 expr.getRHS().isFunctionOfSymbol(position);
358static bool canSimplifyDivisionBySymbol(
AffineExpr expr,
unsigned symbolPos,
360 bool fromMul =
false) {
364 "unexpected opKind");
367 return cast<AffineConstantExpr>(expr).getValue() == 0;
371 return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos);
375 return canSimplifyDivisionBySymbol(binaryExpr.
getLHS(), symbolPos,
377 canSimplifyDivisionBySymbol(binaryExpr.
getRHS(), symbolPos, opKind);
386 return canSimplifyDivisionBySymbol(binaryExpr.
getLHS(), symbolPos,
388 canSimplifyDivisionBySymbol(binaryExpr.
getRHS(), symbolPos,
394 return canSimplifyDivisionBySymbol(binaryExpr.
getLHS(), symbolPos, opKind,
396 canSimplifyDivisionBySymbol(binaryExpr.
getRHS(), symbolPos, opKind,
416 return canSimplifyDivisionBySymbol(binaryExpr.
getLHS(), symbolPos,
420 llvm_unreachable(
"Unknown AffineExpr");
431 "unexpected opKind");
434 if (cast<AffineConstantExpr>(expr).getValue() != 0)
445 expr.
getKind(), symbolicDivide(binaryExpr.
getLHS(), symbolPos, opKind),
446 symbolicDivide(binaryExpr.
getRHS(), symbolPos, opKind));
453 symbolicDivide(binaryExpr.
getLHS(), symbolPos, expr.
getKind()),
454 symbolicDivide(binaryExpr.
getRHS(), symbolPos, expr.
getKind()));
459 if (!canSimplifyDivisionBySymbol(binaryExpr.
getLHS(), symbolPos, opKind))
460 return binaryExpr.
getLHS() *
461 symbolicDivide(binaryExpr.
getRHS(), symbolPos, opKind);
462 return symbolicDivide(binaryExpr.
getLHS(), symbolPos, opKind) *
471 symbolicDivide(binaryExpr.
getLHS(), symbolPos, expr.
getKind()),
475 llvm_unreachable(
"Unknown AffineExpr");
483 auto addExpr = dyn_cast<AffineBinaryOpExpr>(expr);
488 getSummandExprs(addExpr.getLHS(),
result);
489 getSummandExprs(addExpr.getRHS(),
result);
495 auto mulExpr = dyn_cast<AffineBinaryOpExpr>(candidate);
498 if (
auto lhs = dyn_cast<AffineConstantExpr>(mulExpr.getLHS())) {
499 if (
lhs.getValue() == -1) {
500 expr = mulExpr.getRHS();
504 if (
auto rhs = dyn_cast<AffineConstantExpr>(mulExpr.getRHS())) {
505 if (
rhs.getValue() == -1) {
506 expr = mulExpr.getLHS();
522 unsigned numDims,
unsigned numSymbols) {
526 getSummandExprs(
lhs, summands);
529 for (int64_t i = 0, e = summands.size(); i < e; ++i) {
532 if (!isNegatedAffineExpr(current, beforeNegation))
542 for (int64_t j = 0; j < e; ++j)
544 diff = diff + summands[j];
545 diff = diff - innerMod.
getLHS();
547 auto constExpr = dyn_cast<AffineConstantExpr>(diff);
548 if (constExpr && constExpr.getValue() == 0)
560 unsigned numSymbols) {
571 simplifySemiAffine(binaryExpr.
getLHS(), numDims, numSymbols),
572 simplifySemiAffine(binaryExpr.
getRHS(), numDims, numSymbols));
584 simplifySemiAffine(binaryExpr.
getLHS(), numDims, numSymbols);
586 simplifySemiAffine(binaryExpr.
getRHS(), numDims, numSymbols);
587 if (isModOfModSubtraction(sLHS, sRHS, numDims, numSymbols))
590 simplifySemiAffine(binaryExpr.
getRHS(), numDims, numSymbols));
594 if (!canSimplifyDivisionBySymbol(binaryExpr.
getLHS(), symbolPos,
600 symbolicDivide(sLHS, symbolPos, expr.
getKind());
601 return simplifiedQuotient
606 llvm_unreachable(
"Unknown AffineExpr");
611 auto assignCtx = [context](AffineDimExprStorage *storage) {
612 storage->context = context;
616 return uniquer.
get<AffineDimExprStorage>(
617 assignCtx,
static_cast<unsigned>(kind), position);
645 auto assignCtx = [context](AffineConstantExprStorage *storage) {
646 storage->context = context;
650 return uniquer.
get<AffineConstantExprStorage>(assignCtx, constant);
656 return llvm::map_to_vector(constants, [&](int64_t constant) {
663 auto lhsConst = dyn_cast<AffineConstantExpr>(
lhs);
664 auto rhsConst = dyn_cast<AffineConstantExpr>(
rhs);
666 if (lhsConst && rhsConst) {
668 if (llvm::AddOverflow(lhsConst.getValue(), rhsConst.getValue(), sum)) {
676 if (isa<AffineConstantExpr>(
lhs) ||
677 (
lhs.isSymbolicOrConstant() && !
rhs.isSymbolicOrConstant())) {
685 if (rhsConst.getValue() == 0)
689 auto lBin = dyn_cast<AffineBinaryOpExpr>(
lhs);
691 if (
auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS()))
692 return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
698 std::optional<int64_t> rLhsConst, rRhsConst;
701 auto lBinOpExpr = dyn_cast<AffineBinaryOpExpr>(
lhs);
703 (rLhsConstExpr = dyn_cast<AffineConstantExpr>(lBinOpExpr.getRHS()))) {
704 rLhsConst = rLhsConstExpr.
getValue();
705 firstExpr = lBinOpExpr.getLHS();
711 auto rBinOpExpr = dyn_cast<AffineBinaryOpExpr>(
rhs);
714 (rRhsConstExpr = dyn_cast<AffineConstantExpr>(rBinOpExpr.getRHS()))) {
715 rRhsConst = rRhsConstExpr.
getValue();
716 secondExpr = rBinOpExpr.getLHS();
722 if (rLhsConst && rRhsConst && firstExpr == secondExpr)
730 if (
auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
731 return lBin.getLHS() +
rhs + lrhs;
744 auto lrhs = rBinOpExpr.getLHS();
745 auto rrhs = rBinOpExpr.getRHS();
751 auto lrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lrhs);
753 auto rrhsConstOpExpr = dyn_cast<AffineConstantExpr>(rrhs);
754 if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
757 llrhs = lrhsBinOpExpr.getLHS();
759 rlrhs = lrhsBinOpExpr.getRHS();
760 auto llrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(llrhs);
763 if (llrhsBinOpExpr.getRHS() == rlrhs &&
lhs == llrhsBinOpExpr.getLHS())
774 llrhs = lrBinOpExpr.
getLHS();
775 rlrhs = lrBinOpExpr.
getRHS();
776 auto rlrhsConstOpExpr = dyn_cast<AffineConstantExpr>(rlrhs);
778 bool isPositiveRhs = rlrhsConstOpExpr && rlrhsConstOpExpr.getValue() > 0;
780 if (isPositiveRhs &&
lhs == llrhs && rlrhs == -rrhs) {
789 if (
auto simplified = simplifyAdd(lBinOpExpr.getRHS(),
rhs))
790 return lBinOpExpr.getLHS() + simplified;
796static std::pair<AffineExpr, AffineExpr>
798 auto sym1 = dyn_cast<AffineSymbolExpr>(expr1);
799 auto sym2 = dyn_cast<AffineSymbolExpr>(expr2);
802 return sym1.getPosition() < sym2.getPosition() ? std::pair{expr1, expr2}
803 : std::pair{expr2, expr1};
805 auto dim1 = dyn_cast<AffineDimExpr>(expr1);
806 auto dim2 = dyn_cast<AffineDimExpr>(expr2);
808 return dim1.getPosition() < dim2.getPosition() ? std::pair{expr1, expr2}
809 : std::pair{expr2, expr1};
819 return {expr1, expr2};
826 if (
auto simplified = simplifyAdd(*
this, other))
829 auto [
lhs,
rhs] = orderCommutativeArgs(*
this, other);
832 return uniquer.
get<AffineBinaryOpExprStorage>(
838 auto lhsConst = dyn_cast<AffineConstantExpr>(
lhs);
839 auto rhsConst = dyn_cast<AffineConstantExpr>(
rhs);
841 if (lhsConst && rhsConst) {
843 if (llvm::MulOverflow(lhsConst.getValue(), rhsConst.getValue(),
product)) {
849 if (!
lhs.isSymbolicOrConstant() && !
rhs.isSymbolicOrConstant())
855 if (!
rhs.isSymbolicOrConstant() || isa<AffineConstantExpr>(
lhs)) {
864 if (rhsConst.getValue() == 1)
867 if (rhsConst.getValue() == 0)
872 auto lBin = dyn_cast<AffineBinaryOpExpr>(
lhs);
874 if (
auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS()))
875 return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
881 if (
auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
882 return (lBin.getLHS() *
rhs) * lrhs;
896 auto [
lhs,
rhs] = orderCommutativeArgs(*
this, other);
911 return *
this + (-other);
915 auto lhsConst = dyn_cast<AffineConstantExpr>(
lhs);
916 auto rhsConst = dyn_cast<AffineConstantExpr>(
rhs);
918 if (!rhsConst || rhsConst.getValue() == 0)
922 if (divideSignedWouldOverflow(lhsConst.getValue(), rhsConst.getValue()))
925 divideFloorSigned(lhsConst.getValue(), rhsConst.getValue()),
936 auto lBin = dyn_cast<AffineBinaryOpExpr>(
lhs);
938 if (
auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
940 if (lrhs.getValue() % rhsConst.getValue() == 0)
941 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
948 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
949 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
951 if (llhsDiv % rhsConst.getValue() == 0 ||
952 lrhsDiv % rhsConst.getValue() == 0)
953 return lBin.getLHS().floorDiv(rhsConst.getValue()) +
954 lBin.getRHS().floorDiv(rhsConst.getValue());
974 auto lhsConst = dyn_cast<AffineConstantExpr>(
lhs);
975 auto rhsConst = dyn_cast<AffineConstantExpr>(
rhs);
977 if (!rhsConst || rhsConst.getValue() == 0)
981 if (divideSignedWouldOverflow(lhsConst.getValue(), rhsConst.getValue()))
984 divideCeilSigned(lhsConst.getValue(), rhsConst.getValue()),
990 if (rhsConst.getValue() == 1)
995 auto lBin = dyn_cast<AffineBinaryOpExpr>(
lhs);
997 if (
auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
999 if (lrhs.getValue() % rhsConst.getValue() == 0)
1000 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
1021 auto lhsConst = dyn_cast<AffineConstantExpr>(
lhs);
1022 auto rhsConst = dyn_cast<AffineConstantExpr>(
rhs);
1025 if (!rhsConst || rhsConst.getValue() < 1)
1037 if (
lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
1042 auto lBin = dyn_cast<AffineBinaryOpExpr>(
lhs);
1044 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
1045 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
1047 if (llhsDiv % rhsConst.getValue() == 0)
1048 return lBin.getRHS() % rhsConst.getValue();
1049 if (lrhsDiv % rhsConst.getValue() == 0)
1050 return lBin.getLHS() % rhsConst.getValue();
1055 auto intermediate = dyn_cast<AffineConstantExpr>(lBin.getRHS());
1056 if (intermediate && intermediate.getValue() >= 1 &&
1057 mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
1058 return lBin.getLHS() % rhsConst.getValue();
1093 unsigned numSymbols,
1097 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1098 "unexpected number of local expressions");
1102 for (
unsigned j = 0;
j < numDims + numSymbols;
j++) {
1103 if (flatExprs[
j] == 0)
1107 expr = expr +
id * flatExprs[
j];
1111 for (
unsigned j = numDims + numSymbols, e = flatExprs.size() - 1;
j < e;
1113 if (flatExprs[
j] == 0)
1115 auto term = localExprs[
j - numDims - numSymbols] * flatExprs[
j];
1120 int64_t constTerm = flatExprs[flatExprs.size() - 1];
1122 expr = expr + constTerm;
1136 unsigned numSymbols,
1139 assert(!flatExprs.empty() &&
"flatExprs cannot be empty");
1142 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1143 "unexpected number of local expressions");
1175 auto addEntry = [&](std::pair<unsigned, signed>
index,
int64_t coefficient,
1178 "Key is already present in indices vector and overwriting will "
1179 "happen in `indexToExprMap` and `coefficients`!");
1182 coefficients.insert({
index, coefficient});
1183 indexToExprMap.insert({
index, expr});
1191 unsigned offsetSym = 0;
1192 signed offsetDim = -1;
1193 for (
unsigned j = numDims;
j < numDims + numSymbols; ++
j) {
1194 if (flatExprs[
j] == 0)
1200 std::pair<unsigned, signed> indexEntry(
1201 j - numDims, std::max(numDims, numSymbols) + offsetSym++);
1202 addEntry(indexEntry, flatExprs[
j],
1210 unsigned lhsPos, rhsPos;
1215 for (
const auto &it : llvm::enumerate(localExprs)) {
1216 if (flatExprs[numDims + numSymbols + it.index()] == 0)
1219 auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
1225 if (!((isa<AffineDimExpr>(
lhs) || isa<AffineSymbolExpr>(
lhs)) &&
1226 (isa<AffineDimExpr>(
rhs) || isa<AffineSymbolExpr>(
rhs) ||
1227 isa<AffineConstantExpr>(
rhs)))) {
1230 if (isa<AffineConstantExpr>(
rhs)) {
1235 if (isa<AffineDimExpr>(
lhs)) {
1236 lhsPos = cast<AffineDimExpr>(
lhs).getPosition();
1237 std::pair<unsigned, signed> indexEntry(lhsPos, offsetDim--);
1238 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1241 lhsPos = cast<AffineSymbolExpr>(
lhs).getPosition();
1242 std::pair<unsigned, signed> indexEntry(
1243 lhsPos, std::max(numDims, numSymbols) + offsetSym++);
1244 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1247 }
else if (isa<AffineDimExpr>(
lhs)) {
1253 lhsPos = cast<AffineDimExpr>(
lhs).getPosition();
1254 rhsPos = cast<AffineSymbolExpr>(
rhs).getPosition();
1255 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1256 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1262 lhsPos = cast<AffineSymbolExpr>(
lhs).getPosition();
1263 rhsPos = cast<AffineSymbolExpr>(
rhs).getPosition();
1264 std::pair<unsigned, signed> indexEntry(
1265 lhsPos, std::max(numDims, numSymbols) + offsetSym++);
1266 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1268 addedToMap[it.index()] =
true;
1271 for (
unsigned j = 0;
j < numDims; ++
j) {
1272 if (flatExprs[
j] == 0)
1278 std::pair<unsigned, signed> indexEntry(
j, offsetDim--);
1286 for (
const std::pair<unsigned, unsigned>
index :
indices) {
1287 assert(indexToExprMap.lookup(
index) &&
1288 "cannot find key in `indexToExprMap` map");
1289 expr = expr + indexToExprMap.lookup(
index) * coefficients.lookup(
index);
1293 for (
unsigned j = numDims + numSymbols, e = flatExprs.size() - 1;
j < e;
1297 if (flatExprs[
j] == 0 || addedToMap[
j - numDims - numSymbols])
1299 auto term = localExprs[
j - numDims - numSymbols] * flatExprs[
j];
1304 int64_t constTerm = flatExprs.back();
1306 expr = expr + constTerm;
1330 if (!isa<AffineConstantExpr>(expr.
getRHS())) {
1337 return addLocalVariableSemiAffine(mulLhs,
rhs, a *
b,
lhs,
lhs.size());
1352 assert(
lhs.size() ==
rhs.size());
1354 for (
unsigned i = 0, e =
rhs.size(); i < e; i++) {
1383 if (!isa<AffineConstantExpr>(expr.
getRHS())) {
1389 AffineExpr modExpr = dividendExpr % divisorExpr;
1390 return addLocalVariableSemiAffine(modLhs,
rhs, modExpr,
lhs,
lhs.size());
1399 for (i = 0, e =
lhs.size(); i < e; i++)
1400 if (
lhs[i] % rhsConst != 0)
1403 if (i ==
lhs.size()) {
1412 uint64_t gcd = rhsConst;
1414 gcd = std::gcd(gcd, (uint64_t)std::abs(lhsElt));
1417 for (
int64_t &floorDividendElt : floorDividend)
1418 floorDividendElt = floorDividendElt /
static_cast<int64_t>(gcd);
1429 if ((loc = findLocalId(floorDivExpr)) == -1) {
1432 lhs[getLocalVarStartIndex() +
numLocals - 1] = -rhsConst;
1435 lhs[getLocalVarStartIndex() + loc] -= rhsConst;
1442 return visitDivExpr(expr,
true);
1446 return visitDivExpr(expr,
false);
1462 eq[getSymbolStartIndex() + expr.
getPosition()] = 1;
1470 eq[getConstantIndex()] = expr.
getValue();
1474LogicalResult SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1477 assert(
result.size() == resultSize &&
1478 "`result` vector passed is not of correct size");
1480 if ((loc = findLocalId(localExpr)) == -1) {
1488 result[getLocalVarStartIndex() + loc] = 1;
1517 if (!isa<AffineConstantExpr>(expr.
getRHS())) {
1518 SmallVector<int64_t, 8> divLhs(
lhs);
1524 return addLocalVariableSemiAffine(divLhs,
rhs, divExpr,
lhs,
lhs.size());
1528 int64_t rhsConst =
rhs[getConstantIndex()];
1534 uint64_t gcd = std::abs(rhsConst);
1535 for (int64_t lhsElt :
lhs)
1536 gcd = std::gcd(gcd, (uint64_t)std::abs(lhsElt));
1539 for (int64_t &lhsElt :
lhs)
1540 lhsElt = lhsElt /
static_cast<int64_t
>(gcd);
1542 int64_t divisor = rhsConst /
static_cast<int64_t
>(gcd);
1558 if ((loc = findLocalId(divExpr)) == -1) {
1560 SmallVector<int64_t, 8> dividend(
lhs);
1564 SmallVector<int64_t, 8> dividend(
lhs);
1565 dividend.back() += divisor - 1;
1575 lhs[getLocalVarStartIndex() + loc] = 1;
1587 assert(divisor > 0 &&
"positive constant divisor expected");
1589 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() +
numLocals, 0);
1598 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() +
numLocals, 0);
1605int SimpleAffineExprFlattener::findLocalId(
AffineExpr localExpr) {
1614 unsigned numSymbols) {
1617 expr = simplifySemiAffine(expr, numDims, numSymbols);
1639 return simplifiedExpr;
1643 AffineExpr expr,
unsigned numDims,
unsigned numSymbols,
1644 ArrayRef<std::optional<int64_t>> constLowerBounds,
1645 ArrayRef<std::optional<int64_t>> constUpperBounds,
bool isUpper) {
1647 if (
auto binOpExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
1651 auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
1652 if (!rhsConst || rhsConst.getValue() < 1)
1653 return std::nullopt;
1656 constLowerBounds, constUpperBounds, isUpper);
1658 return std::nullopt;
1659 return divideFloorSigned(*bound, rhsConst.getValue());
1662 auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
1663 if (rhsConst && rhsConst.getValue() >= 1) {
1666 constLowerBounds, constUpperBounds, isUpper);
1668 return std::nullopt;
1669 return divideCeilSigned(*bound, rhsConst.getValue());
1671 return std::nullopt;
1677 auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
1678 if (rhsConst && rhsConst.getValue() >= 1) {
1679 int64_t rhsConstVal = rhsConst.getValue();
1681 constLowerBounds, constUpperBounds,
1685 constLowerBounds, constUpperBounds, isUpper);
1687 divideFloorSigned(*lb, rhsConstVal) ==
1688 divideFloorSigned(*
ub, rhsConstVal))
1689 return isUpper ? mod(*
ub, rhsConstVal) : mod(*lb, rhsConstVal);
1690 return isUpper ? rhsConstVal - 1 : 0;
1698 if (failed(simpleResult))
1699 return std::nullopt;
1704 return std::nullopt;
1708 for (
unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
1709 if (flattenedExpr[i] > 0) {
1710 auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
1712 return std::nullopt;
1713 bound += *constBound * flattenedExpr[i];
1714 }
else if (flattenedExpr[i] < 0) {
1715 auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
1717 return std::nullopt;
1718 bound += *constBound * flattenedExpr[i];
1722 bound += flattenedExpr.back();
static int64_t product(ArrayRef< int64_t > vals)
static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs)
Simplify a multiply expression. Return nullptr if it can't be simplified.
static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs)
static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs a semi-affine expression from a flat ArrayRef. If there are local identifiers (neither dim...
static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs)
static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs)
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
Affine binary operation expression.
AffineExpr getLHS() const
AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
detail::AffineBinaryOpExprStorage ImplType
AffineExpr getRHS() const
An integer constant appearing in affine expression.
AffineConstantExpr(AffineExpr::ImplType *ptr=nullptr)
detail::AffineConstantExprStorage ImplType
A dimensional identifier appearing in an affine expression.
AffineDimExpr(AffineExpr::ImplType *ptr)
detail::AffineDimExprStorage ImplType
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
AffineExpr replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements) const
This method substitutes any uses of dimensions and symbols (e.g.
AffineExpr shiftDims(unsigned numDims, unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
bool isSymbolicOrConstant() const
Returns true if this expression is made out of only symbols and constants, i.e., it does not involve ...
AffineExpr operator+(int64_t v) const
AffineExpr operator*(int64_t v) const
bool operator==(AffineExpr other) const
bool isPureAffine() const
Returns true if this is a pure affine expression, i.e., multiplication, floordiv, ceildiv,...
AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift, unsigned offset=0) const
Replace symbols[offset ... numSymbols) by symbols[offset + shift ... shift + numSymbols).
AffineExpr operator-() const
AffineExpr floorDiv(uint64_t v) const
RetT walk(FnT &&callback) const
Walk all of the AffineExpr's in this expression in postorder.
AffineExprKind getKind() const
Return the classification for this type.
bool isMultipleOf(int64_t factor) const
Return true if the affine expression is a multiple of 'factor'.
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
AffineExpr compose(AffineMap map) const
Compose with an AffineMap.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
bool isFunctionOfSymbol(unsigned position) const
Return true if the affine expression involves AffineSymbolExpr position.
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
AffineExpr operator%(uint64_t v) const
MLIRContext * getContext() const
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
AffineExpr replaceSymbols(ArrayRef< AffineExpr > symReplacements) const
Symbol-only version of replaceDimsAndSymbols.
detail::AffineExprStorage ImplType
AffineExpr ceilDiv(uint64_t v) const
void print(raw_ostream &os) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
ArrayRef< AffineExpr > getResults() const
A symbolic identifier appearing in an affine expression.
unsigned getPosition() const
detail::AffineDimExprStorage ImplType
AffineSymbolExpr(AffineExpr::ImplType *ptr)
MLIRContext is the top-level object for a collection of MLIR operations.
StorageUniquer & getAffineUniquer()
Returns the storage uniquer used for creating affine constructs.
virtual void addLocalFloorDivId(ArrayRef< int64_t > dividend, int64_t divisor, AffineExpr localExpr)
LogicalResult visitSymbolExpr(AffineSymbolExpr expr)
std::vector< SmallVector< int64_t, 8 > > operandExprStack
LogicalResult visitDimExpr(AffineDimExpr expr)
LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr)
LogicalResult visitConstantExpr(AffineConstantExpr expr)
virtual LogicalResult addLocalIdSemiAffine(ArrayRef< int64_t > lhs, ArrayRef< int64_t > rhs, AffineExpr localExpr)
Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul expr) when the rhs is a symbo...
LogicalResult visitModExpr(AffineBinaryOpExpr expr)
LogicalResult visitAddExpr(AffineBinaryOpExpr expr)
LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr)
LogicalResult visitMulExpr(AffineBinaryOpExpr expr)
SmallVector< AffineExpr, 4 > localExprs
SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols)
A utility class to get or create instances of "storage classes".
Storage * get(function_ref< void(Storage *)> initFn, TypeID id, Args &&...args)
Gets a uniqued instance of 'Storage'.
A utility result that is used to signal how to proceed with an ongoing walk:
Include the generated interface declarations.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t > > constLowerBounds, ArrayRef< std::optional< int64_t > > constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ Constant
Constant integer.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
AffineExpr getAffineExprFromFlatForm(ArrayRef< int64_t > flatExprs, unsigned numDims, unsigned numSymbols, ArrayRef< AffineExpr > localExprs, MLIRContext *context)
Constructs an affine expression from a flat ArrayRef.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
SmallVector< AffineExpr > getAffineConstantExprs(ArrayRef< int64_t > constants, MLIRContext *context)
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
A binary operation appearing in an affine expression.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.