#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"

using llvm::divideCeilSigned;
using llvm::divideFloorSigned;

#define DEBUG_TYPE "affine-ops"

#include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
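// Helpers used below to decide whether values remapped into another region
// (e.g. during inlining) remain valid affine dimension/symbol identifiers.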
  if (auto arg = dyn_cast<BlockArgument>(value))
    return arg.getParentRegion() == region;

  if (llvm::isa<BlockArgument>(value))
    return legalityCheck(mapping.lookup(value), dest);

  bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());

  return llvm::all_of(values, [&](Value v) {

template <typename OpTy>
  static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
                                AffineWriteOpInterface>::value,
                "only ops with affine read/write interface are supported");
          dimOperands, src, dest, mapping,
          symbolOperands, src, dest, mapping,
          op.getMapOperands(), src, dest, mapping,
          op.getMapOperands(), src, dest, mapping,
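// Dialect interface implementing inlining legality and handling for affine ops.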
struct AffineInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

                       IRMapping &valueMapping) const final {
    if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))

    for (Operation &op : srcBlock) {
      if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
        if (iface.hasNoEffect())

      llvm::TypeSwitch<Operation *, bool>(&op)
          .Case<AffineApplyOp, AffineReadOpInterface,
                AffineWriteOpInterface>([&](auto op) {
          .Default([](Operation *) {

  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       IRMapping &valueMapping) const final {
    Operation *parentOp = region->getParentOp();
    return parentOp->hasTrait<OpTrait::AffineScope>() ||
           isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);

  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
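// Dialect initialization: registers the TableGen-generated op list, the inliner
// interface, and promised ValueBoundsOpInterface implementations.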
void AffineDialect::initialize() {
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
  addInterfaces<AffineInlinerInterface>();
  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,

  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
    return ub::PoisonOp::create(builder, loc, type, poison);
  return arith::ConstantOp::materialize(builder, value, type, loc);
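// Helpers that locate the enclosing affine scope of a value and check whether a
// value may be used as an affine dimension or symbol identifier in a region.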
  if (auto arg = dyn_cast<BlockArgument>(value)) {

  while (auto *parentOp = curOp->getParentOp()) {
      return curOp->getParentRegion();
    if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(parentOp))

  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();

  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
    return applyOp.isValidDim(region);
  if (isa<AffineDelinearizeIndexOp, AffineLinearizeIndexOp>(op))
    return llvm::all_of(op->getOperands(),
                        [&](Value arg) { return ::isValidDim(arg, region); });
  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))

template <typename AnyMemRefDefOp>
  MemRefType memRefType = memrefDefOp.getType();
  if (index >= memRefType.getRank()) {
  if (!memRefType.isDynamicDim(index))
  unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
  return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),

  if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
  if (!index.has_value())
  Operation *op = dimOp.getShapedValue().getDefiningOp();
  while (auto castOp = dyn_cast<memref::CastOp>(op)) {
    if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
    op = castOp.getSource().getDefiningOp();
      .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
      .Default([](Operation *) { return false; });

  if (parentRegion == region)

  if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
        return affine::isValidSymbol(operand, region);
  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))

  printer << '(' << operands.take_front(numDims) << ')';
  if (operands.size() > numDims)
    printer << '[' << operands.drop_front(numDims) << ']';

  numDims = opInfos.size();

template <typename OpTy>
  for (auto operand : operands) {
    if (opIt++ < numDims) {
        return op.emitOpError("operand cannot be used as a dimension id");
        return op.emitOpError("operand cannot be used as a symbol");

  return AffineValueMap(getAffineMap(), getOperands(), getResult());
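// AffineApplyOp: custom parser, printer, and verifier.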
  AffineMapAttr mapAttr;
  auto map = mapAttr.getValue();
  if (map.getNumDims() != numDims ||
      numDims + map.getNumSymbols() != result.operands.size()) {
                        "dimension or symbol index mismatch");
  result.types.append(map.getNumResults(), indexTy);

  p << " " << getMapAttr();
                              getAffineMap().getNumDims(), p);

LogicalResult AffineApplyOp::verify() {
        "operand count and affine map dimension and symbol count must match");
    return emitOpError("mapping must produce one value");
  for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
      return emitError("dimensional operand cannot be used as a symbol");

bool AffineApplyOp::isValidDim() {
  return llvm::all_of(getOperands(),

bool AffineApplyOp::isValidDim(Region *region) {
  return llvm::all_of(getOperands(),
                      [&](Value op) { return ::isValidDim(op, region); });

bool AffineApplyOp::isValidSymbol() {
  return llvm::all_of(getOperands(),

bool AffineApplyOp::isValidSymbol(Region *region) {
  return llvm::all_of(getOperands(), [&](Value operand) {

  auto map = getAffineMap();
  auto expr = map.getResult(0);
  if (auto dim = dyn_cast<AffineDimExpr>(expr))
    return getOperand(dim.getPosition());
  if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
    return getOperand(map.getNumDims() + sym.getPosition());

  bool hasPoison = false;
      map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
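// Analysis helpers used when simplifying semi-affine expressions: the largest
// known divisor of an operand and whether an operand (e.g. a loop IV) is known
// to be non-negative and bounded by a constant.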
  auto dimExpr = dyn_cast<AffineDimExpr>(e);
  Value operand = operands[dimExpr.getPosition()];
    if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
      operandDivisor = forOp.getStepAsInt();
      uint64_t lbLargestKnownDivisor =
          forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
      operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
  return operandDivisor;

  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
    int64_t constVal = constExpr.getValue();
    return constVal >= 0 && constVal < k;
  auto dimExpr = dyn_cast<AffineDimExpr>(e);
  Value operand = operands[dimExpr.getPosition()];
  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
      forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {

  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
      quotientTimesDiv = llhs;
      quotientTimesDiv = rlhs;

  if (forOp && forOp.hasConstantLowerBound())
    return forOp.getConstantLowerBound();

  if (!forOp || !forOp.hasConstantUpperBound())

  if (forOp.hasConstantLowerBound()) {
    return forOp.getConstantUpperBound() - 1 -
           (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
               forOp.getStepAsInt();
  return forOp.getConstantUpperBound() - 1;

  constLowerBounds.reserve(operands.size());
  constUpperBounds.reserve(operands.size());
  for (Value operand : operands) {

  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
    return constExpr.getValue();

  constLowerBounds.reserve(operands.size());
  constUpperBounds.reserve(operands.size());
  for (Value operand : operands) {

  std::optional<int64_t> lowerBound;
  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
    lowerBound = constExpr.getValue();
                                         constLowerBounds, constUpperBounds,

  auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
  binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
  lhs = binExpr.getLHS();
  rhs = binExpr.getRHS();
  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
  int64_t rhsConstVal = rhsConst.getValue();
  if (rhsConstVal <= 0)
  std::optional<int64_t> lhsLbConst =
  std::optional<int64_t> lhsUbConst =
  if (lhsLbConst && lhsUbConst) {
    int64_t lhsLbConstVal = *lhsLbConst;
    int64_t lhsUbConstVal = *lhsUbConst;
        divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
            divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
          divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
               divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
                   divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
               lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
  if (rhsConstVal % divisor == 0 &&
    expr = quotientTimesDiv.floorDiv(rhsConst);
  } else if (divisor % rhsConstVal == 0 &&
    expr = rem % rhsConst;
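// Computes conservative constant lower/upper bounds for each map result, used
// below to drop results that are provably redundant in min/max maps.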
  if (operands.empty())

  constLowerBounds.reserve(operands.size());
  constUpperBounds.reserve(operands.size());
  for (Value operand : operands) {

    if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
      lowerBounds.push_back(constExpr.getValue());
      upperBounds.push_back(constExpr.getValue());
      lowerBounds.push_back(
                                     constLowerBounds, constUpperBounds,
      upperBounds.push_back(
                                     constLowerBounds, constUpperBounds,

  for (auto exprEn : llvm::enumerate(map.getResults())) {
    unsigned i = exprEn.index();
    if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
    if (!upperBounds[i]) {
      irredundantExprs.push_back(e);
      if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
            auto otherLowerBound = en.value();
            unsigned pos = en.index();
            if (pos == i || !otherLowerBound)
            if (*otherLowerBound > *upperBounds[i])
            if (*otherLowerBound < *upperBounds[i])
            if (upperBounds[pos] && lowerBounds[i] &&
                lowerBounds[i] == upperBounds[i] &&
                otherLowerBound == *upperBounds[pos] && i < pos)
        irredundantExprs.push_back(e);
    if (!lowerBounds[i]) {
      irredundantExprs.push_back(e);
      if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
            auto otherUpperBound = en.value();
            unsigned pos = en.index();
            if (pos == i || !otherUpperBound)
            if (*otherUpperBound < *lowerBounds[i])
            if (*otherUpperBound > *lowerBounds[i])
            if (lowerBounds[pos] && upperBounds[i] &&
                lowerBounds[i] == upperBounds[i] &&
                otherUpperBound == lowerBounds[pos] && i < pos)
        irredundantExprs.push_back(e);

  assert(map.getNumInputs() == operands.size() && "invalid operands for map");
    newResults.push_back(expr);
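// Replaces an affine.min result that appears as a map operand with a bounding
// expression, when the min's own dims/symbols can be mapped into the outer map.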
  LDBG() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`";
  AffineMap affineMinMap = minOp.getAffineMap();
  for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
                                       minOp.getOperands())))

  for (auto [i, dim] : llvm::enumerate(minOp.getDimOperands())) {
    auto it = llvm::find(dims, dim);
    if (it == dims.end()) {
      unmappedDims.push_back(i);
  for (auto [i, sym] : llvm::enumerate(minOp.getSymbolOperands())) {
    auto it = llvm::find(syms, sym);
    if (it == syms.end()) {
      unmappedSyms.push_back(i);

    if (llvm::any_of(unmappedDims,
                     [&](unsigned i) { return expr.isFunctionOfDim(i); }) ||
        llvm::any_of(unmappedSyms,
                     [&](unsigned i) { return expr.isFunctionOfSymbol(i); }))
    repl[dimOrSym.ceilDiv(convertedExpr)] = c1;
    repl[(dimOrSym + convertedExpr - 1).floorDiv(convertedExpr)] = c1;
  return success(*map != initialMap);

    AffineExpr e, const llvm::SmallDenseSet<AffineExpr, 4> &exprsToRemove,
  auto binOp = dyn_cast<AffineBinaryOpExpr>(e);
  llvm::SmallDenseSet<AffineExpr, 4> ourTracker(exprsToRemove);
    if (!ourTracker.erase(thisTerm)) {
      toPreserve.push_back(thisTerm);
    auto nextBinOp = dyn_cast_if_present<AffineBinaryOpExpr>(nextTerm);
      thisTerm = nextTerm;
      thisTerm = nextBinOp.getRHS();
      nextTerm = nextBinOp.getLHS();
  if (!ourTracker.empty())
  for (AffineExpr preserved : llvm::reverse(toPreserve))
    newExpr = newExpr + preserved;
  replacementsMap.insert({e, newExpr});
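// Support for folding affine.delinearize_index results used as map operands
// back into the map by reconstructing the equivalent linearized expression.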
    AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map,
  if (!delinOp.getDynamicBasis().empty())
  if (resultToReplace != delinOp.getMultiIndex().back())
  for (auto [pos, dim] : llvm::enumerate(dims)) {
    auto asResult = dyn_cast_if_present<OpResult>(dim);
    if (asResult.getOwner() == delinOp.getOperation())
  for (auto [pos, sym] : llvm::enumerate(syms)) {
    auto asResult = dyn_cast_if_present<OpResult>(sym);
    if (asResult.getOwner() == delinOp.getOperation())
  if (llvm::is_contained(resToExpr, AffineExpr()))
  bool isDimReplacement = llvm::all_of(resToExpr, llvm::IsaPred<AffineDimExpr>);
  llvm::SmallDenseSet<AffineExpr, 4> expectedExprs;
  for (auto [binding, size] : llvm::zip(
           llvm::reverse(resToExpr), llvm::reverse(delinOp.getStaticBasis()))) {
  if (resToExpr.size() != delinOp.getStaticBasis().size())
    expectedExprs.insert(resToExpr[0] * stride);
  if (replacements.empty())
  if (isDimReplacement)
    dims.push_back(delinOp.getLinearIndex());
    syms.push_back(delinOp.getLinearIndex());
  *map = origMap.replace(replacements, dims.size(), syms.size());

    if (auto d = dyn_cast<AffineDimExpr>(e)) {
      unsigned pos = d.getPosition();
        dims[pos] = nullptr;
    if (auto s = dyn_cast<AffineSymbolExpr>(e)) {
      unsigned pos = s.getPosition();
        syms[pos] = nullptr;

                                     unsigned dimOrSymbolPosition,
                                     bool replaceAffineMin) {
  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
  unsigned pos = isDimReplacement ? dimOrSymbolPosition
                                  : dimOrSymbolPosition - dims.size();
  Value &v = isDimReplacement ? dims[pos] : syms[pos];
  if (auto minOp = v.getDefiningOp<AffineMinOp>(); minOp && replaceAffineMin) {
  if (auto delinOp = v.getDefiningOp<affine::AffineDelinearizeIndexOp>()) {
  AffineMap composeMap = affineApply.getAffineMap();
  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
                                affineApply.getMapOperands().end());
  dims.append(composeDims.begin(), composeDims.end());
  syms.append(composeSyms.begin(), composeSyms.end());
  *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
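// Iteratively composes affine.apply (and optionally affine.min /
// delinearize_index) producers of the operands into the map itself.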
                                        bool composeAffineMin = false) {
  for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)

  unsigned nDims = 0, nSyms = 0;
  dimReplacements.reserve(dims.size());
  symReplacements.reserve(syms.size());
  for (auto *container : {&dims, &syms}) {
    bool isDim = (container == &dims);
    auto &repls = isDim ? dimReplacements : symReplacements;
    for (const auto &en : llvm::enumerate(*container)) {
      Value v = en.value();
               "map is function of unexpected expr@pos");
        operands->push_back(v);

  while (llvm::any_of(*operands, [](Value v) {
  if (composeAffineMin && llvm::any_of(*operands, [](Value v) {
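// Public entry points that build affine.apply ops with maximally composed maps
// and canonicalized operands, optionally folding to a constant result.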
                                            bool composeAffineMin) {
  return AffineApplyOp::create(b, loc, map, valueOperands);

                                            bool composeAffineMin) {
                                operands, composeAffineMin);

                                      bool composeAffineMin = false) {
  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
    llvm::append_range(dims,
    llvm::append_range(symbols,
  operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));

                                  bool composeAffineMin) {
  assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
  AffineApplyOp applyOp =
  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
  if (failed(applyOp->fold(constOperands, foldResults)) ||
      foldResults.empty()) {
      listener->notifyOperationInserted(applyOp, {});
    return applyOp.getResult();
  return llvm::getSingleElement(foldResults);

                                       operands, composeAffineMin);

                                   bool composeAffineMin) {
  return llvm::map_to_vector(
      llvm::seq<unsigned>(0, map.getNumResults()), [&](unsigned i) {
        return makeComposedFoldedAffineApply(b, loc, map.getSubMap({i}),
                                             operands, composeAffineMin);
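// The same compose-then-fold flow, specialized for building affine.min and
// affine.max ops, followed by helpers that promote/demote dims and symbols.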
template <typename OpTy>
  return OpTy::create(b, loc, b.getIndexType(), map, valueOperands);

template <typename OpTy>
  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
  if (failed(minMaxOp->fold(constOperands, foldResults)) ||
      foldResults.empty()) {
      listener->notifyOperationInserted(minMaxOp, {});
    return minMaxOp.getResult();
  return llvm::getSingleElement(foldResults);

template <class MapOrSet>
  if (!mapOrSet || operands->empty())

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  auto *context = mapOrSet->getContext();
  resultOperands.reserve(operands->size());
  remappedSymbols.reserve(operands->size());
  unsigned nextDim = 0;
  unsigned nextSym = 0;
  unsigned oldNumSyms = mapOrSet->getNumSymbols();
  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
    if (i < mapOrSet->getNumDims()) {
        remappedSymbols.push_back((*operands)[i]);
        resultOperands.push_back((*operands)[i]);
      resultOperands.push_back((*operands)[i]);
  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
  *operands = resultOperands;
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(
      dimRemapping, {}, nextDim, oldNumSyms + nextSym);

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

template <class MapOrSet>
  if (!mapOrSet || operands.empty())

  unsigned numOperands = operands.size();
  assert(mapOrSet.getNumInputs() == numOperands &&
         "map/set inputs must match number of operands");

  auto *context = mapOrSet.getContext();
  resultOperands.reserve(numOperands);
  remappedDims.reserve(numOperands);
  symOperands.reserve(mapOrSet.getNumSymbols());
  unsigned nextSym = 0;
  unsigned nextDim = 0;
  unsigned oldNumDims = mapOrSet.getNumDims();
  resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);
  for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
      symRemapping[i - oldNumDims] =
      remappedDims.push_back(operands[i]);
      symOperands.push_back(operands[i]);
  append_range(resultOperands, remappedDims);
  append_range(resultOperands, symOperands);
  operands = resultOperands;
  mapOrSet = mapOrSet.replaceDimsAndSymbols(
      {}, symRemapping, oldNumDims + nextDim, nextSym);

  assert(mapOrSet.getNumInputs() == operands.size() &&
         "map/set inputs must match number of operands");
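// canonicalizeMapOrSetAndOperands: drops unused dims/symbols and deduplicates
// or constant-folds operands of an affine map or integer set.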
template <class MapOrSet>
  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
                "Argument must be either of AffineMap or IntegerSet type");

  if (!mapOrSet || operands->empty())

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      usedDims[dimExpr.getPosition()] = true;
    else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
      usedSyms[symExpr.getPosition()] = true;

  auto *context = mapOrSet->getContext();
  resultOperands.reserve(operands->size());

  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
  unsigned nextDim = 0;
  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
      auto it = seenDims.find((*operands)[i]);
      if (it == seenDims.end()) {
        resultOperands.push_back((*operands)[i]);
        seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
        dimRemapping[i] = it->second;

  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
  unsigned nextSym = 0;
  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
    IntegerAttr operandCst;
    if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
    auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
    if (it == seenSymbols.end()) {
      resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
      seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
      symRemapping[i] = it->second;
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
  *operands = resultOperands;
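// Rewrite pattern that composes maps and canonicalizes operands for affine
// load/store/vector load/store/apply/prefetch/min/max ops.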
template <typename AffineOpTy>
  LogicalResult matchAndRewrite(AffineOpTy affineOp,
        llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
                        AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
                        AffineVectorStoreOp, AffineVectorLoadOp>::value,
        "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
    auto map = affineOp.getAffineMap();
    auto oldOperands = affineOp.getMapOperands();
    if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
                                    resultOperands.begin()))
    replaceAffineOp(rewriter, affineOp, map, resultOperands);

void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(

void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
      prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
      prefetch.getLocalityHint(), prefetch.getIsDataCache());

void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
      store, store.getValueToStore(), store.getMemRef(), map, mapOperands);

void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
      vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,

void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
      vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,

template <typename AffineOpTy>
void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(

  results.add<SimplifyAffineOp<AffineApplyOp>>(context);
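// AffineDmaStartOp: builders, custom parser, and verifier.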
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands(tagMemRef);
  result.addOperands(tagIndices);
  result.addOperands(numElements);
    result.addOperands({stride, elementsPerStride});

    Value elementsPerStride) {
  build(builder, state, srcMemRef, srcMap, srcIndices, destMemRef, dstMap,
        destIndices, tagMemRef, tagMap, tagIndices, numElements, stride,
  auto result = dyn_cast<AffineDmaStartOp>(builder.create(state));
  assert(result && "builder didn't return the right type");

    Value elementsPerStride) {
  return create(builder, builder.getLoc(), srcMemRef, srcMap, srcIndices,
                destMemRef, dstMap, destIndices, tagMemRef, tagMap, tagIndices,
                numElements, stride, elementsPerStride);

  AffineMapAttr srcMapAttr;
  AffineMapAttr dstMapAttr;
  AffineMapAttr tagMapAttr;
  if (!strideInfo.empty() && strideInfo.size() != 2) {
                            "expected two stride related operands");
  bool isStrided = strideInfo.size() == 2;
  if (types.size() != 3)
  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
      dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
      tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
                            "memref operand count not equal to map.numInputs");

    return emitOpError("expected DMA source to be of memref type");
    return emitOpError("expected DMA destination to be of memref type");
    return emitOpError("expected DMA tag to be of memref type");
  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
      getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
    return emitOpError("incorrect number of operands");
    if (!idx.getType().isIndex())
      return emitOpError("src index to dma_start must have 'index' type");
          "src index must be a valid dimension or symbol identifier");
    if (!idx.getType().isIndex())
      return emitOpError("dst index to dma_start must have 'index' type");
          "dst index must be a valid dimension or symbol identifier");
    if (!idx.getType().isIndex())
      return emitOpError("tag index to dma_start must have 'index' type");
          "tag index must be a valid dimension or symbol identifier");
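// AffineDmaWaitOp: builders, custom parser, and verifier.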
  result.addOperands(tagMemRef);
  result.addOperands(tagIndices);
  result.addOperands(numElements);

    Value numElements) {
  build(builder, state, tagMemRef, tagMap, tagIndices, numElements);
  auto result = dyn_cast<AffineDmaWaitOp>(builder.create(state));
  assert(result && "builder didn't return the right type");

    Value numElements) {
  return create(builder, builder.getLoc(), tagMemRef, tagMap, tagIndices,

  AffineMapAttr tagMapAttr;
  if (!llvm::isa<MemRefType>(type))
                            "expected tag to be of memref type");
  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
                            "tag memref operand count != to map.numInputs");

  if (!llvm::isa<MemRefType>(getOperand(0).getType()))
    return emitOpError("expected DMA tag to be of memref type");
    if (!idx.getType().isIndex())
      return emitOpError("index to dma_wait must have 'index' type");
          "index must be a valid dimension or symbol identifier");
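// AffineForOp builders: set up the bound maps/operands, iter_args, and the body
// block with its induction variable.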
                        ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
  assert(((!lbMap && lbOperands.empty()) ||
         "lower bound operand count does not match the affine map");
  assert(((!ubMap && ubOperands.empty()) ||
         "upper bound operand count does not match the affine map");
  assert(step > 0 && "step has to be a positive integer constant");

      getOperandSegmentSizeAttr(),
                           static_cast<int32_t>(ubOperands.size()),
                           static_cast<int32_t>(iterArgs.size())}));
  for (Value val : iterArgs)
    result.addTypes(val.getType());

  result.addAttribute(getLowerBoundMapAttrName(result.name),
                      AffineMapAttr::get(lbMap));
  result.addOperands(lbOperands);

  result.addAttribute(getUpperBoundMapAttrName(result.name),
                      AffineMapAttr::get(ubMap));
  result.addOperands(ubOperands);

  result.addOperands(iterArgs);

  Value inductionVar =
  for (Value val : iterArgs)
    bodyBlock->addArgument(val.getType(), val.getLoc());

  if (iterArgs.empty() && !bodyBuilder) {
    ensureTerminator(*bodyRegion, builder, result.location);
  } else if (bodyBuilder) {
    bodyBuilder(builder, result.location, inductionVar,

void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
                        int64_t ub, int64_t step, ValueRange iterArgs,
                        BodyBuilderFn bodyBuilder) {
  return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
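// AffineForOp region verification plus custom parsing/printing of loop bounds.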
LogicalResult AffineForOp::verifyRegions() {
  auto *body = getBody();
  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
    return emitOpError("expected body to have a single index argument for the "
                       "induction variable");

  if (getLowerBoundMap().getNumInputs() > 0)
                                   getLowerBoundMap().getNumDims())))
  if (getUpperBoundMap().getNumInputs() > 0)
                                   getUpperBoundMap().getNumDims())))

  if (getLowerBoundMap().getNumResults() < 1)
    return emitOpError("expected lower bound map to have at least one result");
  if (getUpperBoundMap().getNumResults() < 1)
    return emitOpError("expected upper bound map to have at least one result");

  unsigned opNumResults = getNumResults();
  if (opNumResults == 0)

  if (getNumIterOperands() != opNumResults)
        "mismatch between the number of loop-carried values and results");
  if (getNumRegionIterArgs() != opNumResults)
        "mismatch between the number of basic block args and results");

  bool failedToParsedMinMax =
  auto boundAttrStrName =
      isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
              : AffineForOp::getUpperBoundMapAttrName(result.name);

  if (!boundOpInfos.empty()) {
    if (boundOpInfos.size() > 1)
                            "expected only one loop bound operand");
    result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));

  if (auto affineMapAttr = dyn_cast<AffineMapAttr>(boundAttr)) {
    unsigned currentNumOperands = result.operands.size();
    auto map = affineMapAttr.getValue();
    if (map.getNumDims() != numDims)
          "dim operand count and affine map dim count must match");
    unsigned numDimAndSymbolOperands =
        result.operands.size() - currentNumOperands;
    if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
          "symbol operand count and affine map symbol count must match");
    if (map.getNumResults() > 1 && failedToParsedMinMax) {
        return p.emitError(attrLoc, "lower loop bound affine map with "
                                    "multiple results requires 'max' prefix");
      return p.emitError(attrLoc, "upper loop bound affine map with multiple "
                                  "results requires 'min' prefix");

  if (auto integerAttr = dyn_cast<IntegerAttr>(boundAttr)) {
    result.attributes.pop_back();
      "expected valid affine map representation for loop bounds");

ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::Argument inductionVariable;
  int64_t numOperands = result.operands.size();
  int64_t numLbOperands = result.operands.size() - numOperands;
  numOperands = result.operands.size();
  int64_t numUbOperands = result.operands.size() - numOperands;
                        getStepAttrName(result.name),
    IntegerAttr stepAttr;
            getStepAttrName(result.name).data(),
    if (stepAttr.getValue().isNegative())
          "expected step to be representable as a positive signed integer");

  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  regionArgs.push_back(inductionVariable);
    for (auto argOperandType :
         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
      Type type = std::get<2>(argOperandType);
      std::get<0>(argOperandType).type = type;

      getOperandSegmentSizeAttr(),
                           static_cast<int32_t>(numUbOperands),
                           static_cast<int32_t>(operands.size())}));

  Region *body = result.addRegion();
  if (regionArgs.size() != result.types.size() + 1)
        "mismatch between the number of loop-carried values and results");
  AffineForOp::ensureTerminator(*body, builder, result.location);

  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
    p << constExpr.getValue();
  if (isa<AffineSymbolExpr>(expr)) {
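// Accessors for loop-carried values and the AffineForOp custom printer.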
unsigned AffineForOp::getNumIterOperands() {
  AffineMap lbMap = getLowerBoundMapAttr().getValue();
  AffineMap ubMap = getUpperBoundMapAttr().getValue();

std::optional<MutableArrayRef<OpOperand>>
AffineForOp::getYieldedValuesMutable() {
  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();

void AffineForOp::print(OpAsmPrinter &p) {
  if (getStepAsInt() != 1)
    p << " step " << getStepAsInt();

  bool printBlockTerminators = false;
  if (getNumIterOperands() > 0) {
    auto regionArgs = getRegionIterArgs();
    auto operands = getInits();
    llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
      p << std::get<0>(it) << " = " << std::get<1>(it);
    p << ") -> (" << getResultTypes() << ")";
    printBlockTerminators = true;
                printBlockTerminators);
      (*this)->getAttrs(),
      {getLowerBoundMapAttrName(getOperation()->getName()),
       getUpperBoundMapAttrName(getOperation()->getName()),
       getStepAttrName(getOperation()->getName()),
       getOperandSegmentSizeAttr()});

  auto foldLowerOrUpperBound = [&forOp](bool lower) {
    auto boundOperands =
        lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
    for (auto operand : boundOperands) {
      operandConstants.push_back(operandCst);
        lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
           "bound maps should have at least one result");
    if (failed(boundMap.constantFold(operandConstants, foldedResults)))
    assert(!foldedResults.empty() && "bounds should have at least one result");
    auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
    for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
      auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
      maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
                       : llvm::APIntOps::smin(maxOrMin, foldedResult);
    lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
          : forOp.setConstantUpperBound(maxOrMin.getSExtValue());

  bool folded = false;
  if (!forOp.hasConstantLowerBound())
    folded |= succeeded(foldLowerOrUpperBound(true));
  if (!forOp.hasConstantUpperBound())
    folded |= succeeded(foldLowerOrUpperBound(false));

  int64_t step = forOp.getStepAsInt();
  if (!forOp.hasConstantBounds() || step <= 0)
    return std::nullopt;
  int64_t lb = forOp.getConstantLowerBound();
  int64_t ub = forOp.getConstantUpperBound();
  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
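// Folders and canonicalizations that remove empty loops (forwarding inits or
// yielded values to the results) and simplify the loop's bound maps.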
  if (!llvm::hasSingleElement(*forOp.getBody()))
  if (forOp.getNumResults() == 0)
  if (tripCount == 0) {
    return forOp.getInits();

  auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
  auto iterArgs = forOp.getRegionIterArgs();
  bool hasValDefinedOutsideLoop = false;
  bool iterArgsNotInOrder = false;
  for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
    Value val = yieldOp.getOperand(i);
    if (val == forOp.getInductionVar())
    if (iterArgIt == iterArgs.end()) {
      assert(forOp.isDefinedOutsideOfLoop(val) &&
             "must be defined outside of the loop");
      hasValDefinedOutsideLoop = true;
      replacements.push_back(val);
      unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
        iterArgsNotInOrder = true;
      replacements.push_back(forOp.getInits()[pos]);

  if (!tripCount.has_value() &&
      (hasValDefinedOutsideLoop || iterArgsNotInOrder))
  if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
  return llvm::to_vector_of<OpFoldResult>(replacements);

  auto lbMap = forOp.getLowerBoundMap();
  auto ubMap = forOp.getUpperBoundMap();
  auto prevLbMap = lbMap;
  auto prevUbMap = ubMap;
  if (lbMap == prevLbMap && ubMap == prevUbMap)
  if (lbMap != prevLbMap)
    forOp.setLowerBound(lbOperands, lbMap);
  if (ubMap != prevUbMap)
    forOp.setUpperBound(ubOperands, ubMap);

LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &results) {
    results.assign(getInits().begin(), getInits().end());
  if (!foldResults.empty()) {
    results.assign(foldResults);

OperandRange AffineForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
         "invalid region point");

void AffineForOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
         "expected loop region");
  if (tripCount.has_value()) {
    if (tripCount == 1) {
    if (tripCount.value() > 0) {
      regions.push_back(RegionSuccessor(&getRegion()));
    if (tripCount.value() == 0) {
  regions.push_back(RegionSuccessor(&getRegion()));

ValueRange AffineForOp::getSuccessorInputs(RegionSuccessor successor) {
    return getResults();
  return getRegionIterArgs();
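// Bound setters/getters and LoopLikeOpInterface implementations for AffineForOp.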
void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  getLowerBoundOperandsMutable().assign(lbOperands);
  setLowerBoundMap(map);

void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  getUpperBoundOperandsMutable().assign(ubOperands);
  setUpperBoundMap(map);

bool AffineForOp::hasConstantLowerBound() {
  return getLowerBoundMap().isSingleConstant();

bool AffineForOp::hasConstantUpperBound() {
  return getUpperBoundMap().isSingleConstant();

int64_t AffineForOp::getConstantLowerBound() {
  return getLowerBoundMap().getSingleConstantResult();

int64_t AffineForOp::getConstantUpperBound() {
  return getUpperBoundMap().getSingleConstantResult();

void AffineForOp::setConstantLowerBound(int64_t value) {

void AffineForOp::setConstantUpperBound(int64_t value) {

AffineForOp::operand_range AffineForOp::getControlOperands() {

bool AffineForOp::matchingBoundOperandList() {
  auto lbMap = getLowerBoundMap();
  auto ubMap = getUpperBoundMap();
  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
    if (getOperand(i) != getOperand(numOperands + i))

SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }

std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
  return SmallVector<Value>{getInductionVar()};

std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
  if (!hasConstantLowerBound())
    return std::nullopt;
  return SmallVector<OpFoldResult>{
      OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};

std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
  return SmallVector<OpFoldResult>{
      OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};

std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
  if (!hasConstantUpperBound())
  return SmallVector<OpFoldResult>{
      OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};

FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
    RewriterBase &rewriter, ValueRange newInitOperands,
    bool replaceInitOperandUsesInLoop,
  OpBuilder::InsertionGuard g(rewriter);
  auto inits = llvm::to_vector(getInits());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  AffineForOp newLoop = AffineForOp::create(

  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
  ArrayRef<BlockArgument> newIterArgs =
      newLoop.getBody()->getArguments().take_back(newInitOperands.size());
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<Value> newYieldedValues =
        newYieldValuesFn(rewriter, getLoc(), newIterArgs);
    assert(newInitOperands.size() == newYieldedValues.size() &&
           "expected as many new yield values as new iter operands");
    yieldOp.getOperandsMutable().append(newYieldedValues);

  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
                       newLoop.getBody()->getArguments().take_front(
                           getBody()->getNumArguments()));

  if (replaceInitOperandUsesInLoop) {
    for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
          [&](OpOperand &use) {
            return newLoop->isProperAncestor(user);

      newLoop->getResults().take_front(getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
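// Utilities that recognize affine induction variables and collect the IVs of
// enclosing affine.for / affine.parallel ops.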
  auto ivArg = dyn_cast<BlockArgument>(val);
  if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
    return AffineForOp();
          ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
    return forOp.getInductionVar() == val ? forOp : AffineForOp();
  return AffineForOp();

  auto ivArg = dyn_cast<BlockArgument>(val);
  if (!ivArg || !ivArg.getOwner())
  auto parallelOp = dyn_cast_if_present<AffineParallelOp>(containingOp);
  if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))

  ivs->reserve(forInsts.size());
  for (auto forInst : forInsts)
    ivs->push_back(forInst.getInductionVar());

  ivs.reserve(affineOps.size());
    if (auto forOp = dyn_cast<AffineForOp>(op))
      ivs.push_back(forOp.getInductionVar());
    else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
      for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
        ivs.push_back(parallelOp.getBody()->getArgument(i));

template <typename BoundListTy, typename LoopCreatorTy>
                                 LoopCreatorTy &&loopCreatorFn) {
  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");

  ivs.reserve(lbs.size());
  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
      if (i == e - 1 && bodyBuilderFn) {
        bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
      AffineYieldOp::create(nestedBuilder, nestedLoc);
    auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);

                                   AffineForOp::BodyBuilderFn bodyBuilderFn) {
  return AffineForOp::create(builder, loc, lb, ub, step,

                                 AffineForOp::BodyBuilderFn bodyBuilderFn) {
  if (lbConst && ubConst)
                                     ubConst.value(), step, bodyBuilderFn);
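// Canonicalization patterns for affine.if: removing empty else regions and
// folding trivially true/false conditions.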
  LogicalResult matchAndRewrite(AffineIfOp ifOp,
    if (ifOp.getElseRegion().empty() ||
        !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())

struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
  using OpRewritePattern<AffineIfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineIfOp op,
                                PatternRewriter &rewriter) const override {
    auto isTriviallyFalse = [](IntegerSet iSet) {
      return iSet.isEmptyIntegerSet();
    auto isTriviallyTrue = [](IntegerSet iSet) {
      return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
              iSet.getConstraint(0) == 0);

    IntegerSet affineIfConditions = op.getIntegerSet();
    if (isTriviallyFalse(affineIfConditions)) {
      if (op.getNumResults() == 0 && !op.hasElse()) {
      blockToMove = op.getElseBlock();
    } else if (isTriviallyTrue(affineIfConditions)) {
      blockToMove = op.getThenBlock();
    Operation *blockToMoveTerminator = blockToMove->getTerminator();
    rewriter.eraseOp(blockToMoveTerminator);
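// AffineIfOp: region successor handling, verifier, parser, printer, builders,
// and folding of the condition set and its operands.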
void AffineIfOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
    regions.push_back(RegionSuccessor(&getThenRegion()));
    if (getElseRegion().empty()) {
      regions.push_back(RegionSuccessor(&getElseRegion()));

ValueRange AffineIfOp::getSuccessorInputs(RegionSuccessor successor) {
    return getResults();
  if (successor == &getThenRegion())
    return getThenRegion().getArguments();
  if (successor == &getElseRegion())
    return getElseRegion().getArguments();
  llvm_unreachable("invalid region successor");

LogicalResult AffineIfOp::verify() {
  auto conditionAttr =
      (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
    return emitOpError("requires an integer set attribute named 'condition'");

  IntegerSet condition = conditionAttr.getValue();
    return emitOpError("operand count and condition integer set dimension and "
                       "symbol count must match");

ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
  IntegerSetAttr conditionAttr;
                                AffineIfOp::getConditionAttrStrName(),
  auto set = conditionAttr.getValue();
  if (set.getNumDims() != numDims)
        "dim operand count and integer set dim count must match");
  if (numDims + set.getNumSymbols() != result.operands.size())
        "symbol operand count and integer set symbol count must match");

  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();
  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
    AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),

void AffineIfOp::print(OpAsmPrinter &p) {
  auto conditionAttr =
      (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
  p << " " << conditionAttr;
                          conditionAttr.getValue().getNumDims(), p);
  auto &elseRegion = this->getElseRegion();
  if (!elseRegion.empty()) {
                                getConditionAttrStrName());

IntegerSet AffineIfOp::getIntegerSet() {
      ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())

void AffineIfOp::setIntegerSet(IntegerSet newSet) {
  (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));

void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
  (*this)->setOperands(operands);

void AffineIfOp::build(OpBuilder &builder, OperationState &result,
                       bool withElseRegion) {
  assert(resultTypes.empty() || withElseRegion);
  OpBuilder::InsertionGuard guard(builder);
  result.addTypes(resultTypes);
  result.addOperands(args);
  result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));

  Region *thenRegion = result.addRegion();
  if (resultTypes.empty())
    AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);

  Region *elseRegion = result.addRegion();
  if (withElseRegion) {
    if (resultTypes.empty())
      AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);

void AffineIfOp::build(OpBuilder &builder, OperationState &result,
                       IntegerSet set, ValueRange args, bool withElseRegion) {
  AffineIfOp::build(builder, result, {}, set, args,

                                  bool composeAffineMin = false) {
  if (llvm::none_of(operands,

LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  auto set = getIntegerSet();
  SmallVector<Value, 4> operands(getOperands());
  if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
  setConditional(set, operands);

void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
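// AffineLoadOp: builders, parser, printer, verification, canonicalization, and
// folding (including loads from constant memref.global initializers).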
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
  result.addOperands(operands);
  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
  auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
  result.types.push_back(memrefType.getElementType());

void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
                         Value memref, AffineMap map, ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(memref);
  result.addOperands(mapOperands);
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
  result.types.push_back(memrefType.getElementType());

void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  int64_t rank = memrefType.getRank();

ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
                                    AffineLoadOp::getMapAttrStrName(),

void AffineLoadOp::print(OpAsmPrinter &p) {
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
                          {getMapAttrStrName()});

template <typename AffineMemOpTy>
                                           MemRefType memrefType,
                                           unsigned numIndexOperands) {
    return op->emitOpError("affine map num results must equal memref rank");
    return op->emitOpError("expects as many subscripts as affine map inputs");
  for (auto idx : mapOperands) {
    if (!idx.getType().isIndex())
      return op->emitOpError("index to load must have 'index' type");

LogicalResult AffineLoadOp::verify() {
  if (getType() != memrefType.getElementType())
    return emitOpError("result type must match element type of memref");
          *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
          getMapOperands(), memrefType,
          getNumOperands() - 1)))

void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineLoadOp>>(context);

OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
  auto global = dyn_cast_or_null<memref::GlobalOp>(
      dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
  if (auto splatAttr = dyn_cast<SplatElementsAttr>(cstAttr))
    return splatAttr.getSplatValue<Attribute>();
  if (!getAffineMap().isConstant())
  auto indices = llvm::to_vector<4>(
      llvm::map_range(getAffineMap().getConstantResults(),
                      [](int64_t v) -> uint64_t { return v; }));
  return cstAttr.getValues<Attribute>()[indices];
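// AffineStoreOp: builders, parser, printer, verification, and canonicalization.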
void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
                          Value valueToStore, Value memref, AffineMap map,
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(valueToStore);
  result.addOperands(memref);
  result.addOperands(mapOperands);
  result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);

void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
                          Value valueToStore, Value memref,
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  int64_t rank = memrefType.getRank();

ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand storeValueInfo;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
          mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),

void AffineStoreOp::print(OpAsmPrinter &p) {
  p << " " << getValueToStore();
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
                          {getMapAttrStrName()});

LogicalResult AffineStoreOp::verify() {
  if (getValueToStore().getType() != memrefType.getElementType())
        "value to store must have the same type as memref element type");
          *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
          getMapOperands(), memrefType,
          getNumOperands() - 2)))

void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineStoreOp>>(context);

LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
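// Shared verification, printing, parsing, and constant folding for affine.min
// and affine.max.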
template <typename T>
  if (op.getNumOperands() !=
      op.getMap().getNumDims() + op.getMap().getNumSymbols())
    return op.emitOpError(
        "operand count and affine map dimension and symbol count must match");

  if (op.getMap().getNumResults() == 0)
    return op.emitOpError("affine map expect at least one result");

template <typename T>
  p << ' ' << op->getAttr(T::getMapAttrStrName());
  auto operands = op.getOperands();
  unsigned numDims = op.getMap().getNumDims();
  p << '(' << operands.take_front(numDims) << ')';
  if (operands.size() != numDims)
    p << '[' << operands.drop_front(numDims) << ']';
                          {T::getMapAttrStrName()});

template <typename T>
  AffineMapAttr mapAttr;

template <typename T>
  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
                "expected affine min or max op");
  auto foldedMap = op.getMap().partialConstantFold(operands, &results);
  if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
    return op.getOperand(0);

  if (results.empty()) {
    if (foldedMap == op.getMap())
    op->setAttr("map", AffineMapAttr::get(foldedMap));
    return op.getResult();

  auto resultIt = std::is_same<T, AffineMinOp>::value
                      ? llvm::min_element(results)
                      : llvm::max_element(results);
  if (resultIt == results.end())
  return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
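// Canonicalization patterns for affine.min/affine.max: deduplicating result
// expressions, merging producer min/max ops, and normalizing term order.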
template <typename T>
  AffineMap oldMap = affineOp.getAffineMap();
    if (!llvm::is_contained(newExprs, expr))
      newExprs.push_back(expr);

template <typename T>
  AffineMap oldMap = affineOp.getAffineMap();
      affineOp.getMapOperands().take_front(oldMap.getNumDims());
      affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
  auto newDimOperands = llvm::to_vector<8>(dimOperands);
  auto newSymOperands = llvm::to_vector<8>(symOperands);

    if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
      Value symValue = symOperands[symExpr.getPosition()];
        producerOps.push_back(producerOp);
    } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      Value dimValue = dimOperands[dimExpr.getPosition()];
        producerOps.push_back(producerOp);
    newExprs.push_back(expr);
  if (producerOps.empty())

  for (T producerOp : producerOps) {
    AffineMap producerMap = producerOp.getAffineMap();
    unsigned numProducerDims = producerMap.getNumDims();
        producerOp.getMapOperands().take_front(numProducerDims);
        producerOp.getMapOperands().take_back(numProducerSyms);
    newDimOperands.append(dimValues.begin(), dimValues.end());
    newSymOperands.append(symValues.begin(), symValues.end());
      newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
                             .shiftSymbols(numProducerSyms, numUsedSyms));
    numUsedDims += numProducerDims;
    numUsedSyms += numProducerSyms;

      llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));

    if (!resultExpr.isPureAffine())
    if (failed(flattenResult))
  if (llvm::is_sorted(flattenedExprs))
      llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
  llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
    return flattenedExprs[lhs] < flattenedExprs[rhs];
  for (unsigned idx : resultPermutation)

template <typename T>
  AffineMap map = affineOp.getAffineMap();

template <typename T>
  if (affineOp.getMap().getNumResults() != 1)
                                     affineOp.getOperands());

OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {

void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>,
               DeduplicateAffineMinMaxExpressions<AffineMinOp>,
               MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
               CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMinOp>>(

ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {

OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {

void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>,
               DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
               MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
               CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMaxOp>>(

ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
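// AffinePrefetchOp: custom parser, printer, verifier, and canonicalization.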
ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::UnresolvedOperand memrefInfo;
  IntegerAttr hintInfo;
  StringRef readOrWrite, cacheType;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
                                    AffinePrefetchOp::getMapAttrStrName(),
                            AffinePrefetchOp::getLocalityHintAttrStrName(),

  if (readOrWrite != "read" && readOrWrite != "write")
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),

  if (cacheType != "data" && cacheType != "instr")
                            "cache type has to be 'data' or 'instr'");
  result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),

void AffinePrefetchOp::print(OpAsmPrinter &p) {
  p << " " << getMemref() << '[';
  AffineMapAttr mapAttr =
      (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
    << "locality<" << getLocalityHint() << ">, "
    << (getIsDataCache() ? "data" : "instr");
      (*this)->getAttrs(),
      {getMapAttrStrName(), getLocalityHintAttrStrName(),
       getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});

LogicalResult AffinePrefetchOp::verify() {
  auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
    AffineMap map = mapAttr.getValue();
      return emitOpError("affine.prefetch affine map num results must equal"
    if (getNumOperands() != 1)
  for (auto idx : getMapOperands()) {
          "index must be a valid dimension or symbol identifier");

void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);

LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
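// AffineParallelOp: builders (including concatenation of per-IV bound maps into
// grouped lower/upper bound maps) and bound/step accessors.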
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
                             TypeRange resultTypes,
                             ArrayRef<arith::AtomicRMWKind> reductions,
                             ArrayRef<int64_t> ranges) {
  auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {

  SmallVector<int64_t> steps(ranges.size(), 1);
  build(builder, result, resultTypes, reductions, lbs, {}, ubs,

void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
                             TypeRange resultTypes,
                             ArrayRef<arith::AtomicRMWKind> reductions,
                             ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
                             ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
                             ArrayRef<int64_t> steps) {
  assert(llvm::all_of(lbMaps,
                      [lbMaps](AffineMap m) {
                        return m.getNumDims() == lbMaps[0].getNumDims() &&
         "expected all lower bounds maps to have the same number of dimensions "
  assert(llvm::all_of(ubMaps,
                      [ubMaps](AffineMap m) {
                        return m.getNumDims() == ubMaps[0].getNumDims() &&
         "expected all upper bounds maps to have the same number of dimensions "
  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
         "expected lower bound maps to have as many inputs as lower bound "
  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
         "expected upper bound maps to have as many inputs as upper bound "

  OpBuilder::InsertionGuard guard(builder);
  result.addTypes(resultTypes);

  SmallVector<Attribute, 4> reductionAttrs;
  for (arith::AtomicRMWKind reduction : reductions)
    reductionAttrs.push_back(
  result.addAttribute(getReductionsAttrStrName(),

  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
                                        SmallVectorImpl<int32_t> &groups) {
    SmallVector<AffineExpr> exprs;
    groups.reserve(groups.size() + maps.size());
    exprs.reserve(maps.size());
    for (AffineMap m : maps) {
    return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,

  SmallVector<int32_t> lbGroups, ubGroups;
  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
  result.addAttribute(getLowerBoundsMapAttrStrName(),
                      AffineMapAttr::get(lbMap));
  result.addAttribute(getLowerBoundsGroupsAttrStrName(),
  result.addAttribute(getUpperBoundsMapAttrStrName(),
                      AffineMapAttr::get(ubMap));
  result.addAttribute(getUpperBoundsGroupsAttrStrName(),

  result.addOperands(lbArgs);
  result.addOperands(ubArgs);

  auto *bodyRegion = result.addRegion();
  for (unsigned i = 0, e = steps.size(); i < e; ++i)
  if (resultTypes.empty())
    ensureTerminator(*bodyRegion, builder, result.location);
SmallVector<Region *> AffineParallelOp::getLoopRegions() {
  return {&getRegion()};
}

unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }

AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
  return getOperands().take_front(getLowerBoundsMap().getNumInputs());
}

AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
}

AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
  auto values = getLowerBoundsGroups().getValues<int32_t>();
  unsigned start = 0;
  for (unsigned i = 0; i < pos; ++i)
    start += values[i];
  return getLowerBoundsMap().getSliceMap(start, values[pos]);
}

AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
  auto values = getUpperBoundsGroups().getValues<int32_t>();
  unsigned start = 0;
  for (unsigned i = 0; i < pos; ++i)
    start += values[i];
  return getUpperBoundsMap().getSliceMap(start, values[pos]);
}

AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
  return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
}

AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
  return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
}

std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
  if (hasMinMaxBounds())
    return std::nullopt;

  SmallVector<int64_t, 8> out;

  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
    auto expr = rangesValueMap.getResult(i);
    auto cst = dyn_cast<AffineConstantExpr>(expr);
    if (!cst)
      return std::nullopt;
    out.push_back(cst.getValue());

Block *AffineParallelOp::getBody() { return &getRegion().front(); }

OpBuilder AffineParallelOp::getBodyBuilder() {
  return OpBuilder(getBody(), std::prev(getBody()->end()));
}

void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
  assert(lbOperands.size() == map.getNumInputs() &&
         "operands to map must match number of inputs");

  auto ubOperands = getUpperBoundsOperands();

  SmallVector<Value, 4> newOperands(lbOperands);
  newOperands.append(ubOperands.begin(), ubOperands.end());
  (*this)->setOperands(newOperands);

  setLowerBoundsMapAttr(AffineMapAttr::get(map));
}

void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
  assert(ubOperands.size() == map.getNumInputs() &&
         "operands to map must match number of inputs");

  SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
  newOperands.append(ubOperands.begin(), ubOperands.end());
  (*this)->setOperands(newOperands);

  setUpperBoundsMapAttr(AffineMapAttr::get(map));
}

void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {

static bool isResultTypeMatchAtomicRMWKind(Type resultType,
                                           arith::AtomicRMWKind op) {
  switch (op) {
  case arith::AtomicRMWKind::addf:
    return isa<FloatType>(resultType);
  case arith::AtomicRMWKind::addi:
    return isa<IntegerType>(resultType);
  case arith::AtomicRMWKind::assign:
    return true;
  case arith::AtomicRMWKind::mulf:
    return isa<FloatType>(resultType);
  case arith::AtomicRMWKind::muli:
    return isa<IntegerType>(resultType);
  case arith::AtomicRMWKind::maximumf:
    return isa<FloatType>(resultType);
  case arith::AtomicRMWKind::minimumf:
    return isa<FloatType>(resultType);
  case arith::AtomicRMWKind::maxs: {
    auto intType = dyn_cast<IntegerType>(resultType);
    return intType && intType.isSigned();
  }
  case arith::AtomicRMWKind::mins: {
    auto intType = dyn_cast<IntegerType>(resultType);
    return intType && intType.isSigned();
  }
  case arith::AtomicRMWKind::maxu: {
    auto intType = dyn_cast<IntegerType>(resultType);
    return intType && intType.isUnsigned();
  }
  case arith::AtomicRMWKind::minu: {
    auto intType = dyn_cast<IntegerType>(resultType);
    return intType && intType.isUnsigned();
  }
  case arith::AtomicRMWKind::ori:
    return isa<IntegerType>(resultType);
  case arith::AtomicRMWKind::andi:
    return isa<IntegerType>(resultType);
LogicalResult AffineParallelOp::verify() {
  auto numDims = getNumDims();
      getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
    return emitOpError() << "the number of region arguments ("
                         << getBody()->getNumArguments()
                         << ") and the number of map groups for lower ("
                         << getLowerBoundsGroups().getNumElements()
                         << ") and upper bound ("
                         << getUpperBoundsGroups().getNumElements()
                         << "), and the number of steps (" << getSteps().size()
                         << ") must all match";

  unsigned expectedNumLBResults = 0;
  for (APInt v : getLowerBoundsGroups()) {
    unsigned results = v.getZExtValue();
             << "expected lower bound map to have at least one result";
    expectedNumLBResults += results;
  if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
    return emitOpError() << "expected lower bounds map to have "
                         << expectedNumLBResults << " results";
  unsigned expectedNumUBResults = 0;
  for (APInt v : getUpperBoundsGroups()) {
    unsigned results = v.getZExtValue();
             << "expected upper bound map to have at least one result";
    expectedNumUBResults += results;
  if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
    return emitOpError() << "expected upper bounds map to have "
                         << expectedNumUBResults << " results";

  if (getReductions().size() != getNumResults())
    return emitOpError("a reduction must be specified for each output");

  for (auto it : llvm::enumerate(getReductions())) {
    Attribute attr = it.value();
    auto intAttr = dyn_cast<IntegerAttr>(attr);
    if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
      return emitOpError("invalid reduction attribute");
    auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
      return emitOpError("result type cannot match reduction attribute");

          getLowerBoundsMap().getNumDims())))

          getUpperBoundsMap().getNumDims())))

  if (newMap == getAffineMap() && newOperands == operands)
  reset(newMap, newOperands);

  bool ubCanonicalized = succeeded(ub.canonicalize());

  if (!lbCanonicalized && !ubCanonicalized)
  if (lbCanonicalized)
  if (ubCanonicalized)
    op.setUpperBounds(ub.getOperands(), ub.getAffineMap());

LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {

static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
                             DenseIntElementsAttr group, ValueRange operands,
                             StringRef keyword) {
  ValueRange dimOperands = operands.take_front(numDims);
  ValueRange symOperands = operands.drop_front(numDims);

  for (llvm::APInt groupSize : group) {
    unsigned size = groupSize.getZExtValue();
      p << keyword << '(';
void AffineParallelOp::print(OpAsmPrinter &p) {
  p << " (" << getBody()->getArguments() << ") = (";
                   getLowerBoundsOperands(), "max");
                   getUpperBoundsOperands(), "min");

  SmallVector<int64_t, 8> steps = getSteps();
  bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
    llvm::interleaveComma(steps, p);

  if (getNumResults()) {
    llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
      arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
          llvm::cast<IntegerAttr>(attr).getInt());
      p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
    });
    p << ") -> (" << getResultTypes() << ")";

      (*this)->getAttrs(),
      {AffineParallelOp::getReductionsAttrStrName(),
       AffineParallelOp::getLowerBoundsMapAttrStrName(),
       AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
       AffineParallelOp::getUpperBoundsMapAttrStrName(),
       AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
       AffineParallelOp::getStepsAttrStrName()});
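// Illustrative example of the printed form (assumed/sketched syntax): unit
// steps are elided and the reduce clause appears only when the op has
// results, mirroring the logic above.
//   %sum = affine.parallel (%i) = (0) to (128) step (2)
//       reduce ("addf") -> (f32) {
//     %v = affine.load %buf[%i] : memref<128xf32>
//     affine.yield %v : f32
//   }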
static ParseResult deduplicateAndResolveOperands(
    OpAsmParser &parser,
    ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> operands,
    SmallVectorImpl<Value> &uniqueOperands,
    SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
         "expected operands to be dim or symbol expression");

  for (const auto &list : operands) {
    SmallVector<Value> valueOperands;
    for (Value operand : valueOperands) {
      unsigned pos = std::distance(uniqueOperands.begin(),
                                   llvm::find(uniqueOperands, operand));
      if (pos == uniqueOperands.size())
        uniqueOperands.push_back(operand);
      replacements.push_back(

enum class MinMaxKind { Min, Max };

static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
                                            OperationState &result,
                                            MinMaxKind kind) {
  const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";

  StringRef mapName = kind == MinMaxKind::Min
                          ? AffineParallelOp::getUpperBoundsMapAttrStrName()
                          : AffineParallelOp::getLowerBoundsMapAttrStrName();
  StringRef groupsName =
      kind == MinMaxKind::Min
          ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
          : AffineParallelOp::getLowerBoundsGroupsAttrStrName();

    result.addAttribute(
        mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
    result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));

  SmallVector<AffineExpr> flatExprs;
  SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatDimOperands;
  SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatSymOperands;
  SmallVector<int32_t> numMapsPerGroup;
  SmallVector<OpAsmParser::UnresolvedOperand> mapOperands;
  auto parseOperands = [&]() {
            kind == MinMaxKind::Min ? "min" : "max"))) {
      mapOperands.clear();
      result.attributes.erase(tmpAttrStrName);
      llvm::append_range(flatExprs, map.getValue().getResults());
      auto operandsRef = llvm::ArrayRef(mapOperands);
      auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
      SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef);
      auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
      SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef);
      flatDimOperands.append(map.getValue().getNumResults(), dims);
      flatSymOperands.append(map.getValue().getNumResults(), syms);
      numMapsPerGroup.push_back(map.getValue().getNumResults());
                       flatSymOperands.emplace_back(),
                       flatExprs.emplace_back())))
      numMapsPerGroup.push_back(1);

  unsigned totalNumDims = 0;
  unsigned totalNumSyms = 0;
  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
    unsigned numDims = flatDimOperands[i].size();
    unsigned numSyms = flatSymOperands[i].size();
    flatExprs[i] = flatExprs[i]
                       .shiftDims(numDims, totalNumDims)
                       .shiftSymbols(numSyms, totalNumSyms);
    totalNumDims += numDims;
    totalNumSyms += numSyms;

  SmallVector<Value> dimOperands, symOperands;
  SmallVector<AffineExpr> dimRplacements, symRepacements;
  if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
      deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,

  result.operands.append(dimOperands.begin(), dimOperands.end());
  result.operands.append(symOperands.begin(), symOperands.end());

  auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
  flatMap = flatMap.replaceDimsAndSymbols(
      dimRplacements, symRepacements, dimOperands.size(), symOperands.size());

  result.addAttribute(mapName, AffineMapAttr::get(flatMap));
ParseResult AffineParallelOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  SmallVector<OpAsmParser::Argument, 4> ivs;
      parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
      parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))

  AffineMapAttr stepsMapAttr;
  NamedAttrList stepsAttrs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> stepsMapOperands;
    SmallVector<int64_t, 4> steps(ivs.size(), 1);
    result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
                                       AffineParallelOp::getStepsAttrStrName(),
    SmallVector<int64_t, 4> steps;
    auto stepsMap = stepsMapAttr.getValue();
    for (const auto &result : stepsMap.getResults()) {
      auto constExpr = dyn_cast<AffineConstantExpr>(result);
                                "steps must be constant integers");
      steps.push_back(constExpr.getValue());
    result.addAttribute(AffineParallelOp::getStepsAttrStrName(),

  SmallVector<Attribute, 4> reductions;
    auto parseAttributes = [&]() -> ParseResult {
      NamedAttrList attrStorage;
      std::optional<arith::AtomicRMWKind> reduction =
          arith::symbolizeAtomicRMWKind(attrVal.getValue());
        return parser.emitError(loc, "invalid reduction value: ") << attrVal;
      reductions.push_back(
    result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),

  Region *body = result.addRegion();
  for (auto &iv : ivs)
    iv.type = indexType;

  AffineParallelOp::ensureTerminator(*body, builder, result.location);
LogicalResult AffineYieldOp::verify() {
  auto *parentOp = (*this)->getParentOp();
  auto results = parentOp->getResults();
  auto operands = getOperands();

  if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
    return emitOpError() << "only terminates affine.if/for/parallel regions";
  if (parentOp->getNumResults() != getNumOperands())
    return emitOpError() << "parent of yield must have same number of "
                            "results as the yield operands";
  for (auto it : llvm::zip(results, operands)) {
      return emitOpError() << "types mismatch between yield op and its parent";

void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
                               VectorType resultType, AffineMap map,
  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
  result.addOperands(operands);
  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
  result.types.push_back(resultType);
}

void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
                               VectorType resultType, Value memref,
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(memref);
  result.addOperands(mapOperands);
  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
  result.types.push_back(resultType);
}

void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
                               VectorType resultType, Value memref,
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  int64_t rank = memrefType.getRank();
void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
}

ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
  MemRefType memrefType;
  VectorType resultType;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
                                    AffineVectorLoadOp::getMapAttrStrName(),

void AffineVectorLoadOp::print(OpAsmPrinter &p) {
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
                          {getMapAttrStrName()});

static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
                                          VectorType vectorType) {
  if (memrefType.getElementType() != vectorType.getElementType())
        "requires memref and vector types of the same elemental type");

LogicalResult AffineVectorLoadOp::verify() {
          *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
          getMapOperands(), memrefType,
          getNumOperands() - 1)))

void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
                                Value valueToStore, Value memref, AffineMap map,
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(valueToStore);
  result.addOperands(memref);
  result.addOperands(mapOperands);
  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
}

void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
                                Value valueToStore, Value memref,
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  int64_t rank = memrefType.getRank();

void AffineVectorStoreOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
}

ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  MemRefType memrefType;
  VectorType resultType;
  OpAsmParser::UnresolvedOperand storeValueInfo;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
                                    AffineVectorStoreOp::getMapAttrStrName(),

void AffineVectorStoreOp::print(OpAsmPrinter &p) {
  p << " " << getValueToStore();
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
                          {getMapAttrStrName()});
  p << " : " << getMemRefType() << ", " << getValueToStore().getType();
}

LogicalResult AffineVectorStoreOp::verify() {
          *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
          getMapOperands(), memrefType,
          getNumOperands() - 2)))
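// Illustrative round-trip examples for the vector load/store parse/print
// methods above (assumed/sketched syntax):
//   %v = affine.vector_load %mem[%i, %j] : memref<100x100xf32>, vector<8xf32>
//   affine.vector_store %v, %mem[%i, %j] : memref<100x100xf32>, vector<8xf32>
// verifyVectorMemoryOp additionally requires the memref and vector element
// types to match.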
void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState,
                                     ArrayRef<int64_t> staticBasis,
                                     bool hasOuterBound) {
  SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
                                              : staticBasis.size() + 1,
  build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,

void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState,
                                     bool hasOuterBound) {
  if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
    hasOuterBound = false;
    basis = basis.drop_front();

  SmallVector<Value> dynamicBasis;
  SmallVector<int64_t> staticBasis;
  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,

void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState,
                                     ArrayRef<OpFoldResult> basis,
                                     bool hasOuterBound) {
  if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
    hasOuterBound = false;
    basis = basis.drop_front();

  SmallVector<Value> dynamicBasis;
  SmallVector<int64_t> staticBasis;
  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,

void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState,
                                     Value linearIndex, ArrayRef<int64_t> basis,
                                     bool hasOuterBound) {
  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
}

LogicalResult AffineDelinearizeIndexOp::verify() {
  ArrayRef<int64_t> staticBasis = getStaticBasis();
  if (getNumResults() != staticBasis.size() &&
      getNumResults() != staticBasis.size() + 1)
    return emitOpError("should return an index for each basis element and up "
                       "to one extra index");

  auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
        "mismatch between dynamic and static basis (kDynamic marker but no "
        "corresponding dynamic basis entry) -- this can only happen due to an "
        "incorrect fold/rewrite");

  if (!llvm::all_of(staticBasis, [](int64_t v) {
        return v > 0 || ShapedType::isDynamic(v);
    return emitOpError("no basis element may be statically non-positive");

static std::optional<SmallVector<int64_t>>
foldCstValueToCstAttrBasis(ArrayRef<OpFoldResult> mixedBasis,
                           MutableOperandRange mutableDynamicBasis,
                           ArrayRef<Attribute> dynamicBasis) {
  uint64_t dynamicBasisIndex = 0;
        mutableDynamicBasis.erase(dynamicBasisIndex);
        ++dynamicBasisIndex;

  if (dynamicBasisIndex == dynamicBasis.size())
    return std::nullopt;

      staticBasis.push_back(ShapedType::kDynamic);
      staticBasis.push_back(*basisVal);
LogicalResult
AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &result) {
  std::optional<SmallVector<int64_t>> maybeStaticBasis =
                                 adaptor.getDynamicBasis());
  if (maybeStaticBasis) {
    setStaticBasis(*maybeStaticBasis);

  if (getNumResults() == 1) {
    result.push_back(getLinearIndex());

  if (adaptor.getLinearIndex() == nullptr)

  if (!adaptor.getDynamicBasis().empty())

  int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
  Type attrType = getLinearIndex().getType();

  ArrayRef<int64_t> staticBasis = getStaticBasis();
  if (hasOuterBound())
    staticBasis = staticBasis.drop_front();
  for (int64_t modulus : llvm::reverse(staticBasis)) {
    result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
    highPart = llvm::divideFloorSigned(highPart, modulus);
  }
  result.push_back(IntegerAttr::get(attrType, highPart));
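// Illustrative standalone sketch (not part of this dialect implementation) of
// the digit-peeling arithmetic the constant fold above performs: walk the
// basis from the innermost (fastest-varying) element outwards, emitting the
// remainder at each position and carrying the quotient upwards. The helper
// name `delinearizeConstant` is hypothetical; plain C++ `%` and `/` are used,
// so it matches the op only for non-negative inputs (the fold itself uses
// floor semantics via llvm::mod / llvm::divideFloorSigned).
static SmallVector<int64_t> delinearizeConstant(int64_t linear,
                                                ArrayRef<int64_t> basis) {
  SmallVector<int64_t> digits(basis.size() + 1);
  for (size_t i = basis.size(); i > 0; --i) {
    digits[i] = linear % basis[i - 1]; // innermost digit first
    linear /= basis[i - 1];            // carry into the next position
  }
  digits[0] = linear; // unbounded outermost component
  return digits;      // e.g. delinearizeConstant(5, {3}) == {1, 2}: 5 = 1*3 + 2
}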
SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
  if (hasOuterBound()) {
    if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
                           getDynamicBasis().drop_front(), builder);
    return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),

  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
}

SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
  SmallVector<OpFoldResult> ret = getMixedBasis();
  if (!hasOuterBound())
    ret.insert(ret.begin(), OpFoldResult());

struct DropUnitExtentBasis
    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
    std::optional<Value> zero = std::nullopt;
    Location loc = delinearizeOp->getLoc();
    auto getZero = [&]() -> Value {
      return zero.value();

    SmallVector<OpFoldResult> newBasis;
    for (auto [index, basis] :
         llvm::enumerate(delinearizeOp.getPaddedBasis())) {
      std::optional<int64_t> basisVal =
        replacements[index] = getZero();
        newBasis.push_back(basis);

    if (newBasis.size() == delinearizeOp.getNumResults())
                                         "no unit basis elements");

    if (!newBasis.empty()) {
      auto newDelinearizeOp = affine::AffineDelinearizeIndexOp::create(
          rewriter, loc, delinearizeOp.getLinearIndex(), newBasis);
        replacement = newDelinearizeOp->getResult(newIndex++);

    rewriter.replaceOp(delinearizeOp, replacements);

struct CancelDelinearizeOfLinearizeDisjointExactTail
    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                PatternRewriter &rewriter) const override {
    auto linearizeOp = delinearizeOp.getLinearIndex()
                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
                                         "index doesn't come from linearize");
    if (!linearizeOp.getDisjoint())

    ValueRange linearizeIns = linearizeOp.getMultiIndex();
    SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
    SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
    size_t numMatches = 0;
    for (auto [linSize, delinSize] : llvm::zip(
             llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
      if (linSize != delinSize)

    if (numMatches == 0)
          delinearizeOp, "final basis element doesn't match linearize");

    if (numMatches == linearizeBasis.size() &&
        numMatches == delinearizeBasis.size() &&
        linearizeIns.size() == delinearizeOp.getNumResults()) {
      rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());

    Value newLinearize = affine::AffineLinearizeIndexOp::create(
        rewriter, linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
        ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
        linearizeOp.getDisjoint());
    auto newDelinearize = affine::AffineDelinearizeIndexOp::create(
        rewriter, delinearizeOp.getLoc(), newLinearize,
        ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
        delinearizeOp.hasOuterBound());
    SmallVector<Value> mergedResults(newDelinearize.getResults());
    mergedResults.append(linearizeIns.take_back(numMatches).begin(),
                         linearizeIns.take_back(numMatches).end());
    rewriter.replaceOp(delinearizeOp, mergedResults);
struct SplitDelinearizeSpanningLastLinearizeArg final
    : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                PatternRewriter &rewriter) const override {
    auto linearizeOp = delinearizeOp.getLinearIndex()
                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
                                         "index doesn't come from linearize");
    if (!linearizeOp.getDisjoint())
                                         "linearize isn't disjoint");

    int64_t target = linearizeOp.getStaticBasis().back();
    if (ShapedType::isDynamic(target))
          linearizeOp, "linearize ends with dynamic basis value");

    int64_t sizeToSplit = 1;
    size_t elemsToSplit = 0;
    ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
    for (int64_t basisElem : llvm::reverse(basis)) {
      if (ShapedType::isDynamic(basisElem))
            delinearizeOp, "dynamic basis element while scanning for split");
      sizeToSplit *= basisElem;

    if (sizeToSplit > target)
                                         "overshot last argument size");
    if (sizeToSplit == target)

    if (sizeToSplit < target)
          delinearizeOp, "product of known basis elements doesn't exceed last "
                         "linearize argument");

    if (elemsToSplit < 2)
          "need at least two elements to form the basis product");

    Value linearizeWithoutBack = affine::AffineLinearizeIndexOp::create(
        rewriter, linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
        linearizeOp.getDynamicBasis(), linearizeOp.getStaticBasis().drop_back(),
        linearizeOp.getDisjoint());
    auto delinearizeWithoutSplitPart = affine::AffineDelinearizeIndexOp::create(
        rewriter, delinearizeOp.getLoc(), linearizeWithoutBack,
        delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
        delinearizeOp.hasOuterBound());
    auto delinearizeBack = affine::AffineDelinearizeIndexOp::create(
        rewriter, delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
        basis.take_back(elemsToSplit), true);
    SmallVector<Value> results = llvm::to_vector(
        llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
                            delinearizeBack.getResults()));
    rewriter.replaceOp(delinearizeOp, results);

void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns
      .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
              DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
          context);
}
void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                   OperationState &odsState,
  if (!basis.empty() && basis.front() == Value())
    basis = basis.drop_front();
  SmallVector<Value> dynamicBasis;
  SmallVector<int64_t> staticBasis;
  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
}

void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                   OperationState &odsState,
                                   ArrayRef<OpFoldResult> basis,
  if (!basis.empty() && basis.front() == OpFoldResult())
    basis = basis.drop_front();
  SmallVector<Value> dynamicBasis;
  SmallVector<int64_t> staticBasis;
  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
}

void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                   OperationState &odsState,
                                   ArrayRef<int64_t> basis, bool disjoint) {
  build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
}

LogicalResult AffineLinearizeIndexOp::verify() {
  size_t numIndexes = getMultiIndex().size();
  size_t numBasisElems = getStaticBasis().size();
  if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
    return emitOpError("should be passed a basis element for each index except "
                       "possibly the first");

  auto dynamicMarkersCount =
      llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
        "mismatch between dynamic and static basis (kDynamic marker but no "
        "corresponding dynamic basis entry) -- this can only happen due to an "
        "incorrect fold/rewrite");

OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
  std::optional<SmallVector<int64_t>> maybeStaticBasis =
                                 adaptor.getDynamicBasis());
  if (maybeStaticBasis) {
    setStaticBasis(*maybeStaticBasis);

  if (getMultiIndex().empty())
    return IntegerAttr::get(getResult().getType(), 0);

  if (getMultiIndex().size() == 1)
    return getMultiIndex().front();

  if (llvm::is_contained(adaptor.getMultiIndex(), nullptr))

  if (!adaptor.getDynamicBasis().empty())

  for (auto [length, indexAttr] :
       llvm::zip_first(llvm::reverse(getStaticBasis()),
                       llvm::reverse(adaptor.getMultiIndex()))) {
    result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
    stride = stride * length;

  if (!hasOuterBound())
        cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;
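// Illustrative standalone sketch (not part of this dialect implementation) of
// the strided accumulation used by the constant fold above: starting from the
// innermost position, each index contributes index * stride, and the stride
// is then scaled by that position's basis element. The helper name
// `linearizeConstant` is hypothetical and assumes indices.size() ==
// basis.size(); the op itself also allows an unbounded leading index.
static int64_t linearizeConstant(ArrayRef<int64_t> indices,
                                 ArrayRef<int64_t> basis) {
  int64_t result = 0;
  int64_t stride = 1;
  for (size_t i = indices.size(); i > 0; --i) {
    result += indices[i - 1] * stride; // contribution of this position
    stride *= basis[i - 1];            // row-major stride growth
  }
  return result; // e.g. linearizeConstant({1, 2}, {2, 3}) == 1 * 3 + 2 == 5
}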
SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
  if (hasOuterBound()) {
    if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
                           getDynamicBasis().drop_front(), builder);
    return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),

  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
}

SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
  SmallVector<OpFoldResult> ret = getMixedBasis();
  if (!hasOuterBound())
    ret.insert(ret.begin(), OpFoldResult());

struct DropLinearizeUnitComponentsIfDisjointOrZero final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    size_t numIndices = multiIndex.size();
    SmallVector<Value> newIndices;
    newIndices.reserve(numIndices);
    SmallVector<OpFoldResult> newBasis;
    newBasis.reserve(numIndices);

    if (!op.hasOuterBound()) {
      newIndices.push_back(multiIndex.front());
      multiIndex = multiIndex.drop_front();

    SmallVector<OpFoldResult> basis = op.getMixedBasis();
    for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
      if (!basisEntry || *basisEntry != 1) {
        newIndices.push_back(index);
        newBasis.push_back(basisElem);

      if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
        newIndices.push_back(index);
        newBasis.push_back(basisElem);

    if (newIndices.size() == numIndices)
                                         "no unit basis entries to replace");

    if (newIndices.empty()) {

        op, newIndices, newBasis, op.getDisjoint());

                                  ArrayRef<OpFoldResult> terms) {
  int64_t nDynamic = 0;
  SmallVector<Value> dynamicPart;

  for (OpFoldResult term : terms) {
      dynamicPart.push_back(cast<Value>(term));

  if (auto constant = dyn_cast<AffineConstantExpr>(result))
  return AffineApplyOp::create(builder, loc, result, dynamicPart).getResult();
struct CancelLinearizeOfDelinearizePortion final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
  struct Match {
    AffineDelinearizeIndexOp delinearize;
    unsigned linStart = 0;
    unsigned delinStart = 0;
    unsigned length = 0;
  };

  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Match> matches;

    const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
    ArrayRef<OpFoldResult> linBasisRef = linBasis;

    ValueRange multiIndex = linearizeOp.getMultiIndex();
    unsigned numLinArgs = multiIndex.size();
    unsigned linArgIdx = 0;
    llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize;
    while (linArgIdx < numLinArgs) {
      auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);

      auto delinearizeOp =
          dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
      if (!delinearizeOp) {

      unsigned delinArgIdx = asResult.getResultNumber();
      SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
      OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
      OpFoldResult firstLinBound = linBasis[linArgIdx];
      bool boundsMatch = firstDelinBound == firstLinBound;
      bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
      bool knownByDisjoint =
          linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
      if (!boundsMatch && !bothAtFront && !knownByDisjoint) {

      unsigned numDelinOuts = delinearizeOp.getNumResults();
      for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
        if (multiIndex[linArgIdx + j] !=
            delinearizeOp.getResult(delinArgIdx + j))
        if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])

      if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {

      matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});

    if (matches.empty())
          linearizeOp, "no run of delinearize outputs to deal with");

    SmallVector<SmallVector<Value>> delinearizeReplacements;

    SmallVector<Value> newIndex;
    newIndex.reserve(numLinArgs);
    SmallVector<OpFoldResult> newBasis;
    newBasis.reserve(numLinArgs);
    unsigned prevMatchEnd = 0;
    for (Match m : matches) {
      unsigned gap = m.linStart - prevMatchEnd;
      llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
      llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
      prevMatchEnd = m.linStart + m.length;

      PatternRewriter::InsertionGuard g(rewriter);

      ArrayRef<OpFoldResult> basisToMerge =
          linBasisRef.slice(m.linStart, m.length);
      OpFoldResult newSize =

        newIndex.push_back(m.delinearize.getLinearIndex());
        newBasis.push_back(newSize);
        delinearizeReplacements.push_back(SmallVector<Value>());

      SmallVector<Value> newDelinResults;
      SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
      newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
                          newDelinBasis.begin() + m.delinStart + m.length);
      newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
      auto newDelinearize = AffineDelinearizeIndexOp::create(
          rewriter, m.delinearize.getLoc(), m.delinearize.getLinearIndex(),

      Value combinedElem = newDelinearize.getResult(m.delinStart);
      auto residualDelinearize = AffineDelinearizeIndexOp::create(
          rewriter, m.delinearize.getLoc(), combinedElem, basisToMerge);

      llvm::append_range(newDelinResults,
                         newDelinearize.getResults().take_front(m.delinStart));
      llvm::append_range(newDelinResults, residualDelinearize.getResults());
          newDelinearize.getResults().drop_front(m.delinStart + 1));
      delinearizeReplacements.push_back(newDelinResults);
      newIndex.push_back(combinedElem);
      newBasis.push_back(newSize);

    llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
    llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
        linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());

    for (auto [m, newResults] :
         llvm::zip_equal(matches, delinearizeReplacements)) {
      if (newResults.empty())
      rewriter.replaceOp(m.delinearize, newResults);

struct DropLinearizeLeadingZero final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    Value leadingIdx = op.getMultiIndex().front();

    if (op.getMultiIndex().size() == 1) {

    SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
    ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
    if (op.hasOuterBound())
      newMixedBasis = newMixedBasis.drop_front();

        op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
               DropLinearizeUnitComponentsIfDisjointOrZero>(context);
}
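// Illustrative effect of these patterns at the IR level (assumed/sketched
// syntax): a disjoint linearize that exactly re-packs the results of a
// delinearize over the same basis, e.g.
//   %parts:2 = affine.delinearize_index %idx into (8, 16) : index, index
//   %repacked = affine.linearize_index disjoint [%parts#0, %parts#1]
//       by (8, 16) : index
// is canonicalized so that uses of %repacked are replaced by %idx directly.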
#define GET_OP_CLASSES
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds known to be constants.
static bool hasTrivialZeroTripCount(AffineForOp op)
Returns true if the affine.for has zero iterations in trivial cases.
static LogicalResult verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)
Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine....
static void printAffineMinMaxOp(OpAsmPrinter &p, T op)
static bool isResultTypeMatchAtomicRMWKind(Type resultType, arith::AtomicRMWKind op)
static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)
Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)
Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.
static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)
Fold an affine min or max operation with the given operands.
static bool isTopLevelValueOrAbove(Value value, Region *region)
A utility function to check if a value is defined at the top level of region or is an argument of reg...
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)
Canonicalize the bounds of the given loop.
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)
Simplify expr while exploiting information from the values in operands.
static bool isValidAffineIndexOperand(Value value, Region *region)
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)
Parse a for operation loop bounds.
static std::optional< SmallVector< int64_t > > foldCstValueToCstAttrBasis(ArrayRef< OpFoldResult > mixedBasis, MutableOperandRange mutableDynamicBasis, ArrayRef< Attribute > dynamicBasis)
Given mixed basis of affine.delinearize_index/linearize_index replace constant SSA values with the co...
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)
Simplify the expressions in map while making use of lower or upper bounds of its operands.
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)
static LogicalResult replaceAffineDelinearizeIndexInverseExpression(AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms)
If this map contains of the expression x_1 + x_1 * C_1 + ... x_n * C_N + / ... (not necessarily in or...
static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands, bool composeAffineMin=false)
Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)
Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)
Check if e is known to be: 0 <= e < k.
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds that may or may not be constants.
static void simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)
Simplify the map while exploiting information on the values in operands.
static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)
Prints dimension and symbol list.
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)
Returns the largest known divisor of e.
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.
static void legalizeDemotedDims(MapOrSet &mapOrSet, SmallVectorImpl< Value > &operands)
A valid affine dimension may appear as a symbol in affine.apply operations.
static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)
Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.
static LogicalResult foldLoopBounds(AffineForOp forOp)
Fold the constant bounds of a loop.
static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp, AffineExpr dimOrSym, AffineMap *map, ValueRange dims, ValueRange syms)
Assuming dimOrSym is a quantity in the apply op map map and defined by minOp = affine_min(x_1,...
static SmallVector< OpFoldResult > AffineForEmptyLoopFolder(AffineForOp forOp)
Fold the empty loop.
static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)
Utility function to verify that a set of operands are valid dimension and symbol identifiers.
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)
Returns true if the result of the dim op is a valid symbol for region.
static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr "ientTimesDiv, AffineExpr &rem)
Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.
static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms, bool replaceAffineMin)
Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression...
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static std::optional< uint64_t > getTrivialConstantTripCount(AffineForOp forOp)
Returns constant trip count in trivial cases.
static LogicalResult verifyAffineMinMaxOp(T op)
static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)
static void shortenAddChainsContainingAll(AffineExpr e, const llvm::SmallDenseSet< AffineExpr, 4 > &exprsToRemove, AffineExpr newVal, DenseMap< AffineExpr, AffineExpr > &replacementsMap)
Recursively traverse e.
static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands, bool composeAffineMin=false)
Composes the given affine map with the given list of operands, pulling in the maps from any affine....
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)
Canonicalize the result expression order of an affine map and return success if the order changed.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value getMemRef(Operation *memOp)
Returns the memref being read/written by a memref/affine load/store op.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
AffineExprKind getKind() const
Return the classification for this type.
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
MLIRContext * getContext() const
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
MLIRContext * getContext() const
bool isFunctionOfDim(unsigned position) const
Return true if any affine expression involves AffineDimExpr position.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isFunctionOfSymbol(unsigned position) const
Return true if any affine expression involves AffineSymbolExpr position.
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
unsigned getNumInputs() const
AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const
Replace symbols[offset ... numSymbols) by symbols[offset + shift ... shift + numSymbols).
AffineExpr getResult(unsigned idx) const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
AffineMap getDimIdentityMap()
AffineMap getMultiDimIdentityMap(unsigned rank)
AffineExpr getAffineSymbolExpr(unsigned position)
AffineExpr getAffineConstantExpr(int64_t constant)
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
AffineMap getEmptyAffineMap()
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
AffineMap getSymbolIdentityMap()
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense integer vector or tensor object.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
An integer set representing a conjunction of one or more affine equalities and inequalities.
unsigned getNumDims() const
static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)
MLIRContext * getContext() const
unsigned getNumInputs() const
ArrayRef< AffineExpr > getConstraints() const
ArrayRef< bool > getEqFlags() const
Returns the equality bits, which specify whether each of the constraints is an equality or inequality...
unsigned getNumSymbols() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0
Parses an affine map attribute where dims and symbols are SSA operands.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0
Parses an affine expression where dims and symbols are SSA operands.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0
Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This class provides the API for ops that are known to be isolated from above.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
operand_range getOperands()
Returns an iterator over the underlying Values.
Region * getParentRegion()
Returns the region to which the instruction belongs.
operand_range::iterator operand_iterator
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool isParent() const
Returns true if branching from the parent op.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static DefaultResource * get()
std::vector< SmallVector< int64_t, 8 > > operandExprStack
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
A variable that can be added to the constraint set as a "column".
static bool compare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
Return "true" if "lhs cmp rhs" was proven to hold.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
AffineBound represents a lower or upper bound in the for operation.
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
OpOperand & getTagMemRefMutable()
Value getTagMemRef()
Returns the Tag MemRef for this DMA operation.
static void build(OpBuilder &builder, OperationState &result, Value srcMemRef, AffineMap srcMap, ValueRange srcIndices, Value destMemRef, AffineMap dstMap, ValueRange destIndices, Value tagMemRef, AffineMap tagMap, ValueRange tagIndices, Value numElements, Value stride=nullptr, Value elementsPerStride=nullptr)
operand_range getDstIndices()
Returns the destination memref indices for this DMA operation.
Value getNumElementsPerStride()
Returns the number of elements to transfer per stride for this DMA op.
AffineMapAttr getTagMapAttr()
operand_range getSrcIndices()
Returns the source memref affine map indices for this DMA operation.
AffineMapAttr getSrcMapAttr()
bool isStrided()
Returns true if this DMA operation is strided, returns false otherwise.
AffineMap getDstMap()
Returns the affine map used to access the destination memref.
void print(OpAsmPrinter &p)
OpOperand & getDstMemRefMutable()
Value getDstMemRef()
Returns the destination memref operand for this DMA operation.
static StringRef getSrcMapAttrStrName()
AffineMapAttr getDstMapAttr()
unsigned getSrcMemRefOperandIndex()
Returns the operand index of the source memref.
unsigned getTagMemRefOperandIndex()
Returns the operand index of the tag memref.
static StringRef getTagMapAttrStrName()
LogicalResult verifyInvariantsImpl()
void getEffects(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects)
MemRefType getSrcMemRefType()
MemRefType getTagMemRefType()
AffineMap getSrcMap()
Returns the affine map used to access the source memref.
Value getNumElements()
Returns the number of elements being transferred by this DMA operation.
static AffineDmaStartOp create(OpBuilder &builder, Location location, Value srcMemRef, AffineMap srcMap, ValueRange srcIndices, Value destMemRef, AffineMap dstMap, ValueRange destIndices, Value tagMemRef, AffineMap tagMap, ValueRange tagIndices, Value numElements, Value stride=nullptr, Value elementsPerStride=nullptr)
AffineMap getTagMap()
Returns the affine map used to access the tag memref.
static ParseResult parse(OpAsmParser &parser, OperationState &result)
Value getStride()
Returns the stride value for this DMA operation.
unsigned getDstMemRefOperandIndex()
Returns the operand index of the destination memref.
static StringRef getDstMapAttrStrName()
static StringRef getOperationName()
Value getSrcMemRef()
Returns the source memref operand for this DMA operation.
OpOperand & getSrcMemRefMutable()
operand_range getTagIndices()
Returns the tag memref indices for this DMA operation.
MemRefType getDstMemRefType()
LogicalResult fold(ArrayRef< Attribute > cstOperands, SmallVectorImpl< OpFoldResult > &results)
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
Value getNumElements()
Returns the number of elements transferred by the associated DMA op.
LogicalResult verifyInvariantsImpl()
static StringRef getOperationName()
Value getTagMemRef()
Returns the Tag MemRef associated with the DMA operation being waited on.
static ParseResult parse(OpAsmParser &parser, OperationState &result)
static StringRef getTagMapAttrStrName()
void getEffects(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects)
LogicalResult fold(ArrayRef< Attribute > cstOperands, SmallVectorImpl< OpFoldResult > &results)
AffineMapAttr getTagMapAttr()
void print(OpAsmPrinter &p)
static AffineDmaWaitOp create(OpBuilder &builder, Location location, Value tagMemRef, AffineMap tagMap, ValueRange tagIndices, Value numElements)
static void build(OpBuilder &builder, OperationState &result, Value tagMemRef, AffineMap tagMap, ValueRange tagIndices, Value numElements)
OpOperand & getTagMemRefMutable()
operand_range getTagIndices()
Returns the tag memref indices for this DMA operation.
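A minimal sketch of driving the two DMA ops listed above from C++, using the create() signatures shown for AffineDmaStartOp and AffineDmaWaitOp. Not from the source: the helper name emitSimpleDma and every Value parameter are hypothetical placeholders supplied by the caller.

// Emit a 1-D affine.dma_start / affine.dma_wait pair with identity access maps.
void emitSimpleDma(mlir::OpBuilder &builder, mlir::Location loc,
                   mlir::Value srcMemRef, mlir::Value dstMemRef,
                   mlir::Value tagMemRef, mlir::Value iv, mlir::Value zero,
                   mlir::Value numElements) {
  mlir::AffineMap id = builder.getDimIdentityMap(); // (d0) -> (d0)
  mlir::affine::AffineDmaStartOp::create(
      builder, loc, srcMemRef, id, /*srcIndices=*/{iv}, dstMemRef, id,
      /*destIndices=*/{iv}, tagMemRef, id, /*tagIndices=*/{zero}, numElements);
  // Block until the transfer tagged at tag[zero] completes.
  mlir::affine::AffineDmaWaitOp::create(builder, loc, tagMemRef, id,
                                        /*tagIndices=*/{zero}, numElements);
}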
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
LogicalResult canonicalize()
Attempts to canonicalize the map and operands.
ArrayRef< Value > getOperands() const
AffineExpr getResult(unsigned i)
AffineMap getAffineMap() const
void reset(AffineMap map, ValueRange operands, ValueRange results={})
unsigned getNumResults() const
static void difference(const AffineValueMap &a, const AffineValueMap &b, AffineValueMap *res)
Return the value map that is the difference of value maps 'a' and 'b', represented as an affine map a...
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
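For illustration (a sketch, not from the source), buildAffineLoopNest can emit a fixed 2-D nest; the 128x64 bounds, unit steps, and the empty body callback are assumptions:

void buildExampleNest(mlir::OpBuilder &builder, mlir::Location loc) {
  mlir::affine::buildAffineLoopNest(
      builder, loc, /*lbs=*/{0, 0}, /*ubs=*/{128, 64}, /*steps=*/{1, 1},
      [](mlir::OpBuilder &nested, mlir::Location nestedLoc,
         mlir::ValueRange ivs) {
        // ivs[0] and ivs[1] are the induction variables of the outer and
        // inner loop; the loop body would be emitted here with `nested`.
      });
}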
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)
Extracts the induction variables from a list of AffineForOps and places them in the output argument i...
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
OpFoldResult computeProduct(Location loc, OpBuilder &builder, ArrayRef< OpFoldResult > terms)
Return the product of terms, creating an affine.apply if any of them are non-constant values.
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to drop duplicate operands, drop unused dims and symbols from the map, and promote valid symbols to symbolic operands (a usage sketch follows).
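A hedged usage sketch; the helper name buildCanonicalApply and its inputs are hypothetical:

mlir::Value buildCanonicalApply(mlir::OpBuilder &b, mlir::Location loc,
                                mlir::AffineMap map,
                                llvm::SmallVector<mlir::Value> operands) {
  // Canonicalize the map/operand pair in place, then simplify the map before
  // materializing an affine.apply.
  mlir::affine::canonicalizeMapAndOperands(&map, &operands);
  map = mlir::simplifyAffineMap(map);
  return mlir::affine::AffineApplyOp::create(b, loc, map, operands);
}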
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
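A hedged sketch of the folded-apply helper: compute lhs + rhs as an OpFoldResult, folding to an attribute when both inputs are constant. The helper name addIndexValues is an assumption, not part of the source.

mlir::OpFoldResult addIndexValues(mlir::OpBuilder &b, mlir::Location loc,
                                  mlir::OpFoldResult lhs,
                                  mlir::OpFoldResult rhs) {
  mlir::AffineExpr d0, d1;
  mlir::bindDims(b.getContext(), d0, d1);
  // Map (d0, d1) -> (d0 + d1); constant operands fold away entirely.
  mlir::AffineMap addMap =
      mlir::AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 + d1);
  return mlir::affine::makeComposedFoldedAffineApply(b, loc, addMap,
                                                     {lhs, rhs});
}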
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
bool isTopLevelValue(Value value)
A utility function to check if a value is defined at the top level of an op with trait AffineScope or...
Region * getAffineAnalysisScope(Operation *op)
Returns the closest region enclosing op that is held by a non-affine operation; nullptr if there is n...
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)
Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.
void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)
Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns the AffineParallelOp whose induction variables include the provided value; returns null otherwise.
Region * getAffineScope(Operation *op)
Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list.
bool isAffineParallelInductionVar(Value val)
Returns true if val is the induction variable of an AffineParallelOp.
AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
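A hedged sketch combining getMixedValues and getAsIndexOpFoldResult; the 4xN shape, the dynSize value, and the helper name mixSizes are hypothetical:

llvm::SmallVector<mlir::OpFoldResult> mixSizes(mlir::MLIRContext *ctx,
                                               mlir::Value dynSize) {
  // kDynamic marks the position that is taken from the dynamic value list.
  llvm::SmallVector<int64_t> staticSizes = {4, mlir::ShapedType::kDynamic};
  llvm::SmallVector<mlir::OpFoldResult> mixed =
      mlir::getMixedValues(staticSizes, /*dynamicValues=*/{dynSize}, ctx);
  // A plain constant can also be wrapped directly as an index attribute.
  mixed.push_back(mlir::getAsIndexOpFoldResult(ctx, /*val=*/8));
  return mixed;
}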
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t > > constLowerBounds, ArrayRef< std::optional< int64_t > > constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
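A short worked example of delinearize (the concrete values are chosen only for illustration):

// Delinearize index 6 in a row-major 3x4 space, i.e. strides {4, 1}:
//   6 = 1 * 4 + 2 * 1, so the multi-dimensional offset is {1, 2}.
llvm::SmallVector<int64_t> offsets =
    mlir::delinearize(/*linearIndex=*/6, /*strides=*/{4, 1});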
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., it is speculatable and does not touch memory.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
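A hedged sketch tying the AffineExprKind values above to getAffineBinaryOpExpr: build (d0 floordiv 4) mod 2. The expression and the helper name buildTiledExpr are illustrative assumptions.

mlir::AffineExpr buildTiledExpr(mlir::MLIRContext *ctx) {
  mlir::AffineExpr d0 = mlir::getAffineDimExpr(0, ctx);
  mlir::AffineExpr c4 = mlir::getAffineConstantExpr(4, ctx);
  mlir::AffineExpr c2 = mlir::getAffineConstantExpr(2, ctx);
  mlir::AffineExpr quot =
      mlir::getAffineBinaryOpExpr(mlir::AffineExprKind::FloorDiv, d0, c4);
  return mlir::getAffineBinaryOpExpr(mlir::AffineExprKind::Mod, quot, c2);
}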
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
llvm::function_ref< Fn > function_ref
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Canonicalize the affine map result expression order of an affine min/max operation.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Remove duplicated expressions in affine min/max ops.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Merge an affine min/max op to its consumers if its consumer is also an affine min/max op.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.