23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 #include "llvm/ADT/SmallVectorExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/MathExtras.h"
36 using llvm::divideCeilSigned;
37 using llvm::divideFloorSigned;
40 #define DEBUG_TYPE "affine-ops"
42 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
49 if (
auto arg = llvm::dyn_cast<BlockArgument>(value))
50 return arg.getParentRegion() == region;
73 if (llvm::isa<BlockArgument>(value))
74 return legalityCheck(mapping.
lookup(value), dest);
81 bool isDimLikeOp = isa<ShapedDimOpInterface>(value.
getDefiningOp());
92 return llvm::all_of(values, [&](
Value v) {
99 template <
typename OpTy>
102 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
103 AffineWriteOpInterface>::value,
104 "only ops with affine read/write interface are supported");
111 dimOperands, src, dest, mapping,
115 symbolOperands, src, dest, mapping,
132 op.getMapOperands(), src, dest, mapping,
137 op.getMapOperands(), src, dest, mapping,
164 if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
169 if (!llvm::hasSingleElement(*src))
177 if (
auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
178 if (iface.hasNoEffect())
186 .Case<AffineApplyOp, AffineReadOpInterface,
187 AffineWriteOpInterface>([&](
auto op) {
212 isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
// Unconditionally answers `true` for every operation; the `op` argument is
// never inspected. The method is marked `final` without `virtual`, so it
// overrides a virtual hook declared in a base class that is not visible in
// this excerpt.
// NOTE(review): presumably a member of the inliner interface registered in
// AffineDialect::initialize() -- confirm against the enclosing struct.
216 bool shouldAnalyzeRecursively(
Operation *op)
const final {
return true; }
224 void AffineDialect::initialize() {
227 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
229 addInterfaces<AffineInlinerInterface>();
230 declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
239 if (
auto poison = dyn_cast<ub::PoisonAttr>(value))
240 return builder.
create<ub::PoisonOp>(loc, type, poison);
241 return arith::ConstantOp::materialize(builder, value, type, loc);
249 if (
auto arg = llvm::dyn_cast<BlockArgument>(value)) {
265 while (
auto *parentOp = curOp->getParentOp()) {
288 auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
290 isa<AffineForOp, AffineParallelOp>(parentOp));
311 auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->
getParentOp();
312 return isa<AffineForOp, AffineParallelOp>(parentOp);
316 if (
auto applyOp = dyn_cast<AffineApplyOp>(op))
317 return applyOp.isValidDim(region);
320 if (
auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
328 template <
typename AnyMemRefDefOp>
331 MemRefType memRefType = memrefDefOp.getType();
334 if (index >= memRefType.getRank()) {
339 if (!memRefType.isDynamicDim(index))
342 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
343 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
355 if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
363 if (!index.has_value())
367 Operation *op = dimOp.getShapedValue().getDefiningOp();
368 while (
auto castOp = dyn_cast<memref::CastOp>(op)) {
370 if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
372 op = castOp.getSource().getDefiningOp();
377 int64_t i = index.value();
379 .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
381 .Default([](
Operation *) {
return false; });
448 if (
isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](
Value operand) {
449 return affine::isValidSymbol(operand, region);
455 if (
auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
479 printer <<
'(' << operands.take_front(numDims) <<
')';
480 if (operands.size() > numDims)
481 printer <<
'[' << operands.drop_front(numDims) <<
']';
491 numDims = opInfos.size();
505 template <
typename OpTy>
510 for (
auto operand : operands) {
511 if (opIt++ < numDims) {
513 return op.emitOpError(
"operand cannot be used as a dimension id");
515 return op.emitOpError(
"operand cannot be used as a symbol");
526 return AffineValueMap(getAffineMap(), getOperands(), getResult());
533 AffineMapAttr mapAttr;
539 auto map = mapAttr.getValue();
541 if (map.getNumDims() != numDims ||
542 numDims + map.getNumSymbols() != result.
operands.size()) {
544 "dimension or symbol index mismatch");
547 result.
types.append(map.getNumResults(), indexTy);
552 p <<
" " << getMapAttr();
554 getAffineMap().getNumDims(), p);
565 "operand count and affine map dimension and symbol count must match");
569 return emitOpError(
"mapping must produce one value");
577 return llvm::all_of(getOperands(),
585 return llvm::all_of(getOperands(),
592 return llvm::all_of(getOperands(),
599 return llvm::all_of(getOperands(), [&](
Value operand) {
605 auto map = getAffineMap();
608 auto expr = map.getResult(0);
609 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
610 return getOperand(dim.getPosition());
611 if (
auto sym = dyn_cast<AffineSymbolExpr>(expr))
612 return getOperand(map.getNumDims() + sym.getPosition());
616 bool hasPoison =
false;
618 map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
621 if (failed(foldResult))
638 auto dimExpr = dyn_cast<AffineDimExpr>(e);
648 Value operand = operands[dimExpr.getPosition()];
649 int64_t operandDivisor = 1;
653 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
654 operandDivisor = forOp.getStepAsInt();
656 uint64_t lbLargestKnownDivisor =
657 forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
658 operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
661 return operandDivisor;
668 if (
auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
669 int64_t constVal = constExpr.getValue();
670 return constVal >= 0 && constVal < k;
672 auto dimExpr = dyn_cast<AffineDimExpr>(e);
675 Value operand = operands[dimExpr.getPosition()];
679 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
680 forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
696 auto bin = dyn_cast<AffineBinaryOpExpr>(e);
704 quotientTimesDiv = llhs;
710 quotientTimesDiv = rlhs;
720 if (forOp && forOp.hasConstantLowerBound())
721 return forOp.getConstantLowerBound();
728 if (!forOp || !forOp.hasConstantUpperBound())
733 if (forOp.hasConstantLowerBound()) {
734 return forOp.getConstantUpperBound() - 1 -
735 (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
736 forOp.getStepAsInt();
738 return forOp.getConstantUpperBound() - 1;
749 constLowerBounds.reserve(operands.size());
750 constUpperBounds.reserve(operands.size());
751 for (
Value operand : operands) {
756 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr))
757 return constExpr.getValue();
772 constLowerBounds.reserve(operands.size());
773 constUpperBounds.reserve(operands.size());
774 for (
Value operand : operands) {
779 std::optional<int64_t> lowerBound;
780 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
781 lowerBound = constExpr.getValue();
784 constLowerBounds, constUpperBounds,
795 auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
806 binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
814 lhs = binExpr.getLHS();
815 rhs = binExpr.getRHS();
816 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
820 int64_t rhsConstVal = rhsConst.getValue();
822 if (rhsConstVal <= 0)
827 std::optional<int64_t> lhsLbConst =
829 std::optional<int64_t> lhsUbConst =
831 if (lhsLbConst && lhsUbConst) {
832 int64_t lhsLbConstVal = *lhsLbConst;
833 int64_t lhsUbConstVal = *lhsUbConst;
837 divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
838 divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
840 divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
846 divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
847 divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
854 lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
866 if (
isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
867 if (rhsConstVal % divisor == 0 &&
869 expr = quotientTimesDiv.
floorDiv(rhsConst);
870 }
else if (divisor % rhsConstVal == 0 &&
872 expr = rem % rhsConst;
898 if (operands.empty())
904 constLowerBounds.reserve(operands.size());
905 constUpperBounds.reserve(operands.size());
906 for (
Value operand : operands) {
920 if (
auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
921 lowerBounds.push_back(constExpr.getValue());
922 upperBounds.push_back(constExpr.getValue());
924 lowerBounds.push_back(
926 constLowerBounds, constUpperBounds,
928 upperBounds.push_back(
930 constLowerBounds, constUpperBounds,
939 unsigned i = exprEn.index();
941 if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
946 if (!upperBounds[i]) {
947 irredundantExprs.push_back(e);
953 auto otherLowerBound = en.value();
954 unsigned pos = en.index();
955 if (pos == i || !otherLowerBound)
957 if (*otherLowerBound > *upperBounds[i])
959 if (*otherLowerBound < *upperBounds[i])
964 if (upperBounds[pos] && lowerBounds[i] &&
965 lowerBounds[i] == upperBounds[i] &&
966 otherLowerBound == *upperBounds[pos] && i < pos)
970 irredundantExprs.push_back(e);
972 if (!lowerBounds[i]) {
973 irredundantExprs.push_back(e);
978 auto otherUpperBound = en.value();
979 unsigned pos = en.index();
980 if (pos == i || !otherUpperBound)
982 if (*otherUpperBound < *lowerBounds[i])
984 if (*otherUpperBound > *lowerBounds[i])
986 if (lowerBounds[pos] && upperBounds[i] &&
987 lowerBounds[i] == upperBounds[i] &&
988 otherUpperBound == lowerBounds[pos] && i < pos)
992 irredundantExprs.push_back(e);
1004 static void LLVM_ATTRIBUTE_UNUSED
1006 assert(map.
getNumInputs() == operands.size() &&
"invalid operands for map");
1012 newResults.push_back(expr);
1029 unsigned dimOrSymbolPosition,
1033 bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1034 unsigned pos = isDimReplacement ? dimOrSymbolPosition
1035 : dimOrSymbolPosition - dims.size();
1036 Value &v = isDimReplacement ? dims[pos] : syms[pos];
1049 AffineMap composeMap = affineApply.getAffineMap();
1050 assert(composeMap.
getNumResults() == 1 &&
"affine.apply with >1 results");
1052 affineApply.getMapOperands().end());
1066 dims.append(composeDims.begin(), composeDims.end());
1067 syms.append(composeSyms.begin(), composeSyms.end());
1068 *map = map->
replace(toReplace, replacementExpr, dims.size(), syms.size());
1097 for (
unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1109 unsigned nDims = 0, nSyms = 0;
1111 dimReplacements.reserve(dims.size());
1112 symReplacements.reserve(syms.size());
1113 for (
auto *container : {&dims, &syms}) {
1114 bool isDim = (container == &dims);
1115 auto &repls = isDim ? dimReplacements : symReplacements;
1117 Value v = en.value();
1121 "map is function of unexpected expr@pos");
1127 operands->push_back(v);
1140 while (llvm::any_of(*operands, [](
Value v) {
1154 return b.
create<AffineApplyOp>(loc, map, valueOperands);
1176 for (
unsigned i : llvm::seq<unsigned>(0, map.
getNumResults())) {
1183 llvm::append_range(dims,
1185 llvm::append_range(symbols,
1192 operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1201 assert(map.
getNumResults() == 1 &&
"building affine.apply with !=1 result");
1211 AffineApplyOp applyOp =
1216 for (
unsigned i = 0, e = constOperands.size(); i != e; ++i)
1221 if (failed(applyOp->fold(constOperands, foldResults)) ||
1222 foldResults.empty()) {
1224 listener->notifyOperationInserted(applyOp, {});
1225 return applyOp.getResult();
1229 assert(foldResults.size() == 1 &&
"expected 1 folded result");
1230 return foldResults.front();
1248 return llvm::map_to_vector(llvm::seq<unsigned>(0, map.
getNumResults()),
1250 return makeComposedFoldedAffineApply(
1251 b, loc, map.getSubMap({i}), operands);
1255 template <
typename OpTy>
1267 return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
1270 template <
typename OpTy>
1282 auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1286 for (
unsigned i = 0, e = constOperands.size(); i != e; ++i)
1291 if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1292 foldResults.empty()) {
1294 listener->notifyOperationInserted(minMaxOp, {});
1295 return minMaxOp.getResult();
1299 assert(foldResults.size() == 1 &&
"expected 1 folded result");
1300 return foldResults.front();
1307 return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1314 return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
1319 template <
class MapOrSet>
1322 if (!mapOrSet || operands->empty())
1325 assert(mapOrSet->getNumInputs() == operands->size() &&
1326 "map/set inputs must match number of operands");
1328 auto *context = mapOrSet->getContext();
1330 resultOperands.reserve(operands->size());
1332 remappedSymbols.reserve(operands->size());
1333 unsigned nextDim = 0;
1334 unsigned nextSym = 0;
1335 unsigned oldNumSyms = mapOrSet->getNumSymbols();
1337 for (
unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1338 if (i < mapOrSet->getNumDims()) {
1342 remappedSymbols.push_back((*operands)[i]);
1345 resultOperands.push_back((*operands)[i]);
1348 resultOperands.push_back((*operands)[i]);
1352 resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1353 *operands = resultOperands;
1354 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
1355 oldNumSyms + nextSym);
1357 assert(mapOrSet->getNumInputs() == operands->size() &&
1358 "map/set inputs must match number of operands");
1362 template <
class MapOrSet>
1365 static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1366 "Argument must be either of AffineMap or IntegerSet type");
1368 if (!mapOrSet || operands->empty())
1371 assert(mapOrSet->getNumInputs() == operands->size() &&
1372 "map/set inputs must match number of operands");
1374 canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1377 llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1378 llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1380 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr))
1381 usedDims[dimExpr.getPosition()] =
true;
1382 else if (
auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
1383 usedSyms[symExpr.getPosition()] =
true;
1386 auto *context = mapOrSet->getContext();
1389 resultOperands.reserve(operands->size());
1391 llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1393 unsigned nextDim = 0;
1394 for (
unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1397 auto it = seenDims.find((*operands)[i]);
1398 if (it == seenDims.end()) {
1400 resultOperands.push_back((*operands)[i]);
1401 seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1403 dimRemapping[i] = it->second;
1407 llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1409 unsigned nextSym = 0;
1410 for (
unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1416 IntegerAttr operandCst;
1417 if (
matchPattern((*operands)[i + mapOrSet->getNumDims()],
1424 auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1425 if (it == seenSymbols.end()) {
1427 resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1428 seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1431 symRemapping[i] = it->second;
1434 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1436 *operands = resultOperands;
1441 canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1446 canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
1453 template <
typename AffineOpTy>
1462 LogicalResult matchAndRewrite(AffineOpTy affineOp,
1465 llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1466 AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1467 AffineVectorStoreOp, AffineVectorLoadOp>::value,
1468 "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1470 auto map = affineOp.getAffineMap();
1472 auto oldOperands = affineOp.getMapOperands();
1477 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1478 resultOperands.begin()))
1481 replaceAffineOp(rewriter, affineOp, map, resultOperands);
1489 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1496 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1500 prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1501 prefetch.getLocalityHint(), prefetch.getIsDataCache());
1504 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1508 store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1511 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1515 vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1519 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1523 vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1528 template <
typename AffineOpTy>
1529 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1538 results.
add<SimplifyAffineOp<AffineApplyOp>>(context);
1569 p <<
" " << getSrcMemRef() <<
'[';
1571 p <<
"], " << getDstMemRef() <<
'[';
1573 p <<
"], " << getTagMemRef() <<
'[';
1577 p <<
", " << getStride();
1578 p <<
", " << getNumElementsPerStride();
1580 p <<
" : " << getSrcMemRefType() <<
", " << getDstMemRefType() <<
", "
1581 << getTagMemRefType();
1593 AffineMapAttr srcMapAttr;
1596 AffineMapAttr dstMapAttr;
1599 AffineMapAttr tagMapAttr;
1614 getSrcMapAttrStrName(),
1618 getDstMapAttrStrName(),
1622 getTagMapAttrStrName(),
1631 if (!strideInfo.empty() && strideInfo.size() != 2) {
1633 "expected two stride related operands");
1635 bool isStrided = strideInfo.size() == 2;
1640 if (types.size() != 3)
1658 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1659 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1660 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1662 "memref operand count not equal to map.numInputs");
1666 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1667 if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).
getType()))
1668 return emitOpError(
"expected DMA source to be of memref type");
1669 if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).
getType()))
1670 return emitOpError(
"expected DMA destination to be of memref type");
1671 if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).
getType()))
1672 return emitOpError(
"expected DMA tag to be of memref type");
1674 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1675 getDstMap().getNumInputs() +
1676 getTagMap().getNumInputs();
1677 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1678 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1679 return emitOpError(
"incorrect number of operands");
1683 for (
auto idx : getSrcIndices()) {
1684 if (!idx.getType().isIndex())
1685 return emitOpError(
"src index to dma_start must have 'index' type");
1688 "src index must be a valid dimension or symbol identifier");
1690 for (
auto idx : getDstIndices()) {
1691 if (!idx.getType().isIndex())
1692 return emitOpError(
"dst index to dma_start must have 'index' type");
1695 "dst index must be a valid dimension or symbol identifier");
1697 for (
auto idx : getTagIndices()) {
1698 if (!idx.getType().isIndex())
1699 return emitOpError(
"tag index to dma_start must have 'index' type");
1702 "tag index must be a valid dimension or symbol identifier");
1713 void AffineDmaStartOp::getEffects(
1739 p <<
" " << getTagMemRef() <<
'[';
1744 p <<
" : " << getTagMemRef().getType();
1755 AffineMapAttr tagMapAttr;
1764 getTagMapAttrStrName(),
1773 if (!llvm::isa<MemRefType>(type))
1775 "expected tag to be of memref type");
1777 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1779 "tag memref operand count != to map.numInputs");
1783 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1784 if (!llvm::isa<MemRefType>(getOperand(0).
getType()))
1785 return emitOpError(
"expected DMA tag to be of memref type");
1787 for (
auto idx : getTagIndices()) {
1788 if (!idx.getType().isIndex())
1789 return emitOpError(
"index to dma_wait must have 'index' type");
1792 "index must be a valid dimension or symbol identifier");
1803 void AffineDmaWaitOp::getEffects(
1819 ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1820 assert(((!lbMap && lbOperands.empty()) ||
1822 "lower bound operand count does not match the affine map");
1823 assert(((!ubMap && ubOperands.empty()) ||
1825 "upper bound operand count does not match the affine map");
1826 assert(step > 0 &&
"step has to be a positive integer constant");
1832 getOperandSegmentSizeAttr(),
1834 static_cast<int32_t>(ubOperands.size()),
1835 static_cast<int32_t>(iterArgs.size())}));
1837 for (
Value val : iterArgs)
1859 Value inductionVar =
1861 for (
Value val : iterArgs)
1862 bodyBlock->
addArgument(val.getType(), val.getLoc());
1867 if (iterArgs.empty() && !bodyBuilder) {
1868 ensureTerminator(*bodyRegion, builder, result.
location);
1869 }
else if (bodyBuilder) {
1872 bodyBuilder(builder, result.
location, inductionVar,
1878 int64_t ub, int64_t step,
ValueRange iterArgs,
1879 BodyBuilderFn bodyBuilder) {
1882 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1886 LogicalResult AffineForOp::verifyRegions() {
1889 auto *body = getBody();
1890 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1891 return emitOpError(
"expected body to have a single index argument for the "
1892 "induction variable");
1896 if (getLowerBoundMap().getNumInputs() > 0)
1898 getLowerBoundMap().getNumDims())))
1901 if (getUpperBoundMap().getNumInputs() > 0)
1903 getUpperBoundMap().getNumDims())))
1906 unsigned opNumResults = getNumResults();
1907 if (opNumResults == 0)
1913 if (getNumIterOperands() != opNumResults)
1915 "mismatch between the number of loop-carried values and results");
1916 if (getNumRegionIterArgs() != opNumResults)
1918 "mismatch between the number of basic block args and results");
1928 bool failedToParsedMinMax =
1932 auto boundAttrStrName =
1933 isLower ? AffineForOp::getLowerBoundMapAttrName(result.
name)
1934 : AffineForOp::getUpperBoundMapAttrName(result.
name);
1941 if (!boundOpInfos.empty()) {
1943 if (boundOpInfos.size() > 1)
1945 "expected only one loop bound operand");
1970 if (
auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
1971 unsigned currentNumOperands = result.
operands.size();
1976 auto map = affineMapAttr.getValue();
1980 "dim operand count and affine map dim count must match");
1982 unsigned numDimAndSymbolOperands =
1983 result.
operands.size() - currentNumOperands;
1984 if (numDims + map.
getNumSymbols() != numDimAndSymbolOperands)
1987 "symbol operand count and affine map symbol count must match");
1993 return p.
emitError(attrLoc,
"lower loop bound affine map with "
1994 "multiple results requires 'max' prefix");
1996 return p.
emitError(attrLoc,
"upper loop bound affine map with multiple "
1997 "results requires 'min' prefix");
2003 if (
auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
2013 "expected valid affine map representation for loop bounds");
2025 int64_t numOperands = result.
operands.size();
2028 int64_t numLbOperands = result.
operands.size() - numOperands;
2031 numOperands = result.
operands.size();
2034 int64_t numUbOperands = result.
operands.size() - numOperands;
2039 getStepAttrName(result.
name),
2043 IntegerAttr stepAttr;
2045 getStepAttrName(result.
name).data(),
2049 if (stepAttr.getValue().isNegative())
2052 "expected step to be representable as a positive signed integer");
2060 regionArgs.push_back(inductionVariable);
2068 for (
auto argOperandType :
2069 llvm::zip(llvm::drop_begin(regionArgs), operands, result.
types)) {
2070 Type type = std::get<2>(argOperandType);
2071 std::get<0>(argOperandType).type = type;
2079 getOperandSegmentSizeAttr(),
2081 static_cast<int32_t>(numUbOperands),
2082 static_cast<int32_t>(operands.size())}));
2086 if (regionArgs.size() != result.
types.size() + 1)
2089 "mismatch between the number of loop-carried values and results");
2093 AffineForOp::ensureTerminator(*body, builder, result.
location);
2115 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2116 p << constExpr.getValue();
2124 if (dyn_cast<AffineSymbolExpr>(expr)) {
2140 unsigned AffineForOp::getNumIterOperands() {
2141 AffineMap lbMap = getLowerBoundMapAttr().getValue();
2142 AffineMap ubMap = getUpperBoundMapAttr().getValue();
// Returns the mutable operand list of the terminator of this loop's body
// block. The terminator is converted with cast<AffineYieldOp> (not
// dyn_cast<>), i.e. the code takes for granted that the body always ends in
// an AffineYieldOp. The only return statement yields a value, so the
// optional returned here is always engaged.
2147 std::optional<MutableArrayRef<OpOperand>>
2148 AffineForOp::getYieldedValuesMutable() {
2149 return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2161 if (getStepAsInt() != 1)
2162 p <<
" step " << getStepAsInt();
2164 bool printBlockTerminators =
false;
2165 if (getNumIterOperands() > 0) {
2167 auto regionArgs = getRegionIterArgs();
2168 auto operands = getInits();
2170 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](
auto it) {
2171 p << std::get<0>(it) <<
" = " << std::get<1>(it);
2173 p <<
") -> (" << getResultTypes() <<
")";
2174 printBlockTerminators =
true;
2179 printBlockTerminators);
2181 (*this)->getAttrs(),
2182 {getLowerBoundMapAttrName(getOperation()->getName()),
2183 getUpperBoundMapAttrName(getOperation()->getName()),
2184 getStepAttrName(getOperation()->getName()),
2185 getOperandSegmentSizeAttr()});
2190 auto foldLowerOrUpperBound = [&forOp](
bool lower) {
2194 auto boundOperands =
2195 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2196 for (
auto operand : boundOperands) {
2199 operandConstants.push_back(operandCst);
2203 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2205 "bound maps should have at least one result");
2207 if (failed(boundMap.
constantFold(operandConstants, foldedResults)))
2211 assert(!foldedResults.empty() &&
"bounds should have at least one result");
2212 auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2213 for (
unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2214 auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2215 maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2216 : llvm::APIntOps::smin(maxOrMin, foldedResult);
2218 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2219 : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2224 bool folded =
false;
2225 if (!forOp.hasConstantLowerBound())
2226 folded |= succeeded(foldLowerOrUpperBound(
true));
2229 if (!forOp.hasConstantUpperBound())
2230 folded |= succeeded(foldLowerOrUpperBound(
false));
2231 return success(folded);
2239 auto lbMap = forOp.getLowerBoundMap();
2240 auto ubMap = forOp.getUpperBoundMap();
2241 auto prevLbMap = lbMap;
2242 auto prevUbMap = ubMap;
2255 if (lbMap == prevLbMap && ubMap == prevUbMap)
2258 if (lbMap != prevLbMap)
2259 forOp.setLowerBound(lbOperands, lbMap);
2260 if (ubMap != prevUbMap)
2261 forOp.setUpperBound(ubOperands, ubMap);
// Computes the trip count of `forOp` for the case where it can be read off
// directly from constant bounds.
//
// Returns std::nullopt when the loop does not have constant bounds
// (hasConstantBounds() is false) or when its step is not strictly positive.
// Otherwise returns 0 if the upper bound does not exceed the lower bound,
// and ceil((ub - lb) / step) in all other cases.
2267 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2268 int64_t step = forOp.getStepAsInt();
2269 if (!forOp.hasConstantBounds() || step <= 0)
2270 return std::nullopt;
2271 int64_t lb = forOp.getConstantLowerBound();
2272 int64_t ub = forOp.getConstantUpperBound();
// `ub - lb <= 0` covers empty and inverted ranges. The other arm is the
// usual round-up integer division, evaluated in int64_t and then converted
// to the uint64_t payload of the result.
// NOTE(review): `ub - lb` and `ub - lb + step - 1` are signed 64-bit
// computations; bounds near the int64_t limits could overflow -- confirm
// whether such bounds are reachable here.
2273 return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2281 LogicalResult matchAndRewrite(AffineForOp forOp,
2284 if (!llvm::hasSingleElement(*forOp.getBody()))
2286 if (forOp.getNumResults() == 0)
2288 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2289 if (tripCount && *tripCount == 0) {
2292 rewriter.
replaceOp(forOp, forOp.getInits());
2296 auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2297 auto iterArgs = forOp.getRegionIterArgs();
2298 bool hasValDefinedOutsideLoop =
false;
2299 bool iterArgsNotInOrder =
false;
2300 for (
unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2301 Value val = yieldOp.getOperand(i);
2302 auto *iterArgIt = llvm::find(iterArgs, val);
2303 if (iterArgIt == iterArgs.end()) {
2305 assert(forOp.isDefinedOutsideOfLoop(val) &&
2306 "must be defined outside of the loop");
2307 hasValDefinedOutsideLoop =
true;
2308 replacements.push_back(val);
2310 unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2312 iterArgsNotInOrder =
true;
2313 replacements.push_back(forOp.getInits()[pos]);
2318 if (!tripCount.has_value() &&
2319 (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2323 if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2325 rewriter.
replaceOp(forOp, replacements);
2333 results.
add<AffineForEmptyLoopFolder>(context);
2337 assert((point.
isParent() || point == getRegion()) &&
"invalid region point");
2344 void AffineForOp::getSuccessorRegions(
2346 assert((point.
isParent() || point == getRegion()) &&
"expected loop region");
2351 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*
this);
2352 if (point.
isParent() && tripCount.has_value()) {
2353 if (tripCount.value() > 0) {
2354 regions.push_back(
RegionSuccessor(&getRegion(), getRegionIterArgs()));
2357 if (tripCount.value() == 0) {
2365 if (!point.
isParent() && tripCount && *tripCount == 1) {
2372 regions.push_back(
RegionSuccessor(&getRegion(), getRegionIterArgs()));
2378 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
2379 return tripCount && *tripCount == 0;
2382 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2392 results.assign(getInits().begin(), getInits().end());
2395 return success(folded);
2408 assert(map.
getNumResults() >= 1 &&
"bound map has at least one result");
2409 getLowerBoundOperandsMutable().assign(lbOperands);
2410 setLowerBoundMap(map);
2415 assert(map.
getNumResults() >= 1 &&
"bound map has at least one result");
2416 getUpperBoundOperandsMutable().assign(ubOperands);
2417 setUpperBoundMap(map);
// Reports whether the lower bound of this loop is a constant, which is
// decided solely by asking the lower-bound affine map whether it
// isSingleConstant().
2420 bool AffineForOp::hasConstantLowerBound() {
2421 return getLowerBoundMap().isSingleConstant();
// Reports whether the upper bound of this loop is a constant, which is
// decided solely by asking the upper-bound affine map whether it
// isSingleConstant().
2424 bool AffineForOp::hasConstantUpperBound() {
2425 return getUpperBoundMap().isSingleConstant();
// Returns the lower bound as an integer by forwarding to
// getSingleConstantResult() on the lower-bound affine map. No check is made
// here that the map actually is a single constant.
// NOTE(review): presumably callers must first establish
// hasConstantLowerBound() -- confirm against
// AffineMap::getSingleConstantResult.
2428 int64_t AffineForOp::getConstantLowerBound() {
2429 return getLowerBoundMap().getSingleConstantResult();
// Returns the upper bound as an integer by forwarding to
// getSingleConstantResult() on the upper-bound affine map. No check is made
// here that the map actually is a single constant.
// NOTE(review): presumably callers must first establish
// hasConstantUpperBound() -- confirm against
// AffineMap::getSingleConstantResult.
2432 int64_t AffineForOp::getConstantUpperBound() {
2433 return getUpperBoundMap().getSingleConstantResult();
2436 void AffineForOp::setConstantLowerBound(int64_t value) {
2440 void AffineForOp::setConstantUpperBound(int64_t value) {
2444 AffineForOp::operand_range AffineForOp::getControlOperands() {
2449 bool AffineForOp::matchingBoundOperandList() {
2450 auto lbMap = getLowerBoundMap();
2451 auto ubMap = getUpperBoundMap();
2457 for (
unsigned i = 0, e = lbMap.
getNumInputs(); i < e; i++) {
2459 if (getOperand(i) != getOperand(numOperands + i))
2467 std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2471 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2472 if (!hasConstantLowerBound())
2473 return std::nullopt;
2476 OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2479 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2485 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2486 if (!hasConstantUpperBound())
2490 OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2493 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2495 bool replaceInitOperandUsesInLoop,
2500 auto inits = llvm::to_vector(getInits());
2501 inits.append(newInitOperands.begin(), newInitOperands.end());
2502 AffineForOp newLoop = rewriter.
create<AffineForOp>(
2507 auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2509 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2514 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2515 assert(newInitOperands.size() == newYieldedValues.size() &&
2516 "expected as many new yield values as new iter operands");
2518 yieldOp.getOperandsMutable().append(newYieldedValues);
2523 rewriter.
mergeBlocks(getBody(), newLoop.getBody(),
2524 newLoop.getBody()->getArguments().take_front(
2525 getBody()->getNumArguments()));
2527 if (replaceInitOperandUsesInLoop) {
2530 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
2541 newLoop->getResults().take_front(getNumResults()));
2542 return cast<LoopLikeOpInterface>(newLoop.getOperation());
2570 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2571 if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2572 return AffineForOp();
2574 ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2576 return forOp.getInductionVar() == val ? forOp : AffineForOp();
2577 return AffineForOp();
2581 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2582 if (!ivArg || !ivArg.getOwner())
2585 auto parallelOp = dyn_cast<AffineParallelOp>(containingOp);
2586 if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2595 ivs->reserve(forInsts.size());
2596 for (
auto forInst : forInsts)
2597 ivs->push_back(forInst.getInductionVar());
2602 ivs.reserve(affineOps.size());
2605 if (
auto forOp = dyn_cast<AffineForOp>(op))
2606 ivs.push_back(forOp.getInductionVar());
2607 else if (
auto parallelOp = dyn_cast<AffineParallelOp>(op))
2608 for (
size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2609 ivs.push_back(parallelOp.getBody()->getArgument(i));
2615 template <
typename BoundListTy,
typename LoopCreatorTy>
2620 LoopCreatorTy &&loopCreatorFn) {
2621 assert(lbs.size() == ubs.size() &&
"Mismatch in number of arguments");
2622 assert(lbs.size() == steps.size() &&
"Mismatch in number of arguments");
2634 ivs.reserve(lbs.size());
2635 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
2641 if (i == e - 1 && bodyBuilderFn) {
2643 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2645 nestedBuilder.
create<AffineYieldOp>(nestedLoc);
2650 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
2658 int64_t ub, int64_t step,
2659 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2660 return builder.
create<AffineForOp>(loc, lb, ub, step,
2661 std::nullopt, bodyBuilderFn);
2668 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2671 if (lbConst && ubConst)
2673 ubConst.value(), step, bodyBuilderFn);
2676 std::nullopt, bodyBuilderFn);
2704 LogicalResult matchAndRewrite(AffineIfOp ifOp,
2706 if (ifOp.getElseRegion().empty() ||
2707 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2722 LogicalResult matchAndRewrite(AffineIfOp op,
2725 auto isTriviallyFalse = [](
IntegerSet iSet) {
2726 return iSet.isEmptyIntegerSet();
2730 return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2731 iSet.getConstraint(0) == 0);
2734 IntegerSet affineIfConditions = op.getIntegerSet();
2736 if (isTriviallyFalse(affineIfConditions)) {
2740 if (op.getNumResults() == 0 && !op.hasElse()) {
2746 blockToMove = op.getElseBlock();
2747 }
else if (isTriviallyTrue(affineIfConditions)) {
2748 blockToMove = op.getThenBlock();
2766 rewriter.
eraseOp(blockToMoveTerminator);
2774 void AffineIfOp::getSuccessorRegions(
2783 if (getElseRegion().empty()) {
2784 regions.push_back(getResults());
2800 auto conditionAttr =
2801 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2803 return emitOpError(
"requires an integer set attribute named 'condition'");
2806 IntegerSet condition = conditionAttr.getValue();
2808 return emitOpError(
"operand count and condition integer set dimension and "
2809 "symbol count must match");
2821 IntegerSetAttr conditionAttr;
2824 AffineIfOp::getConditionAttrStrName(),
2830 auto set = conditionAttr.getValue();
2831 if (set.getNumDims() != numDims)
2834 "dim operand count and integer set dim count must match");
2835 if (numDims + set.getNumSymbols() != result.
operands.size())
2838 "symbol operand count and integer set symbol count must match");
2852 AffineIfOp::ensureTerminator(*thenRegion, parser.
getBuilder(),
2859 AffineIfOp::ensureTerminator(*elseRegion, parser.
getBuilder(),
2871 auto conditionAttr =
2872 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2873 p <<
" " << conditionAttr;
2875 conditionAttr.getValue().getNumDims(), p);
2882 auto &elseRegion = this->getElseRegion();
2883 if (!elseRegion.
empty()) {
2892 getConditionAttrStrName());
2897 ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
2901 void AffineIfOp::setIntegerSet(
IntegerSet newSet) {
2907 (*this)->setOperands(operands);
2912 bool withElseRegion) {
2913 assert(resultTypes.empty() || withElseRegion);
2922 if (resultTypes.empty())
2923 AffineIfOp::ensureTerminator(*thenRegion, builder, result.
location);
2926 if (withElseRegion) {
2928 if (resultTypes.empty())
2929 AffineIfOp::ensureTerminator(*elseRegion, builder, result.
location);
2935 AffineIfOp::build(builder, result, {}, set, args,
2950 if (llvm::none_of(operands,
2961 auto set = getIntegerSet();
2967 if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
2970 setConditional(set, operands);
2976 results.
add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
2985 assert(operands.size() == 1 + map.
getNumInputs() &&
"inconsistent operands");
2989 auto memrefType = llvm::cast<MemRefType>(operands[0].
getType());
2990 result.
types.push_back(memrefType.getElementType());
2995 assert(map.
getNumInputs() == mapOperands.size() &&
"inconsistent index info");
2998 auto memrefType = llvm::cast<MemRefType>(memref.
getType());
3000 result.
types.push_back(memrefType.getElementType());
3005 auto memrefType = llvm::cast<MemRefType>(memref.
getType());
3006 int64_t rank = memrefType.getRank();
3011 build(builder, result, memref, map, indices);
3020 AffineMapAttr mapAttr;
3025 AffineLoadOp::getMapAttrStrName(),
3035 p <<
" " << getMemRef() <<
'[';
3036 if (AffineMapAttr mapAttr =
3037 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3041 {getMapAttrStrName()});
3047 static LogicalResult
3050 MemRefType memrefType,
unsigned numIndexOperands) {
3053 return op->
emitOpError(
"affine map num results must equal memref rank");
3055 return op->
emitOpError(
"expects as many subscripts as affine map inputs");
3058 for (
auto idx : mapOperands) {
3059 if (!idx.getType().isIndex())
3060 return op->
emitOpError(
"index to load must have 'index' type");
3063 "index must be a valid dimension or symbol identifier");
3071 if (
getType() != memrefType.getElementType())
3072 return emitOpError(
"result type must match element type of memref");
3076 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3077 getMapOperands(), memrefType,
3078 getNumOperands() - 1)))
3086 results.
add<SimplifyAffineOp<AffineLoadOp>>(context);
3095 auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3102 auto global = dyn_cast_or_null<memref::GlobalOp>(
3109 llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3113 if (
auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
3114 return splatAttr.getSplatValue<
Attribute>();
3116 if (!getAffineMap().isConstant())
3118 auto indices = llvm::to_vector<4>(
3119 llvm::map_range(getAffineMap().getConstantResults(),
3120 [](int64_t v) -> uint64_t {
return v; }));
3121 return cstAttr.getValues<
Attribute>()[indices];
3131 assert(map.
getNumInputs() == mapOperands.size() &&
"inconsistent index info");
3142 auto memrefType = llvm::cast<MemRefType>(memref.
getType());
3143 int64_t rank = memrefType.getRank();
3148 build(builder, result, valueToStore, memref, map, indices);
3157 AffineMapAttr mapAttr;
3162 mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3173 p <<
" " << getValueToStore();
3174 p <<
", " << getMemRef() <<
'[';
3175 if (AffineMapAttr mapAttr =
3176 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3180 {getMapAttrStrName()});
3187 if (getValueToStore().
getType() != memrefType.getElementType())
3189 "value to store must have the same type as memref element type");
3193 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3194 getMapOperands(), memrefType,
3195 getNumOperands() - 2)))
3203 results.
add<SimplifyAffineOp<AffineStoreOp>>(context);
3206 LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3216 template <
typename T>
3219 if (op.getNumOperands() !=
3220 op.getMap().getNumDims() + op.getMap().getNumSymbols())
3221 return op.emitOpError(
3222 "operand count and affine map dimension and symbol count must match");
3224 if (op.getMap().getNumResults() == 0)
3225 return op.emitOpError(
"affine map expect at least one result");
3229 template <
typename T>
3231 p <<
' ' << op->getAttr(T::getMapAttrStrName());
3232 auto operands = op.getOperands();
3233 unsigned numDims = op.getMap().getNumDims();
3234 p <<
'(' << operands.take_front(numDims) <<
')';
3236 if (operands.size() != numDims)
3237 p <<
'[' << operands.drop_front(numDims) <<
']';
3239 {T::getMapAttrStrName()});
3242 template <
typename T>
3249 AffineMapAttr mapAttr;
3265 template <
typename T>
3267 static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3268 "expected affine min or max op");
3274 auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3276 if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3277 return op.getOperand(0);
3280 if (results.empty()) {
3282 if (foldedMap == op.getMap())
3285 return op.getResult();
3289 auto resultIt = std::is_same<T, AffineMinOp>::value
3290 ? llvm::min_element(results)
3291 : llvm::max_element(results);
3292 if (resultIt == results.end())
3298 template <
typename T>
3304 AffineMap oldMap = affineOp.getAffineMap();
3310 if (!llvm::is_contained(newExprs, expr))
3311 newExprs.push_back(expr);
3341 template <
typename T>
3347 AffineMap oldMap = affineOp.getAffineMap();
3349 affineOp.getMapOperands().take_front(oldMap.
getNumDims());
3351 affineOp.getMapOperands().take_back(oldMap.
getNumSymbols());
3353 auto newDimOperands = llvm::to_vector<8>(dimOperands);
3354 auto newSymOperands = llvm::to_vector<8>(symOperands);
3362 if (
auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
3363 Value symValue = symOperands[symExpr.getPosition()];
3365 producerOps.push_back(producerOp);
3368 }
else if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
3369 Value dimValue = dimOperands[dimExpr.getPosition()];
3371 producerOps.push_back(producerOp);
3378 newExprs.push_back(expr);
3381 if (producerOps.empty())
3388 for (T producerOp : producerOps) {
3389 AffineMap producerMap = producerOp.getAffineMap();
3390 unsigned numProducerDims = producerMap.
getNumDims();
3395 producerOp.getMapOperands().take_front(numProducerDims);
3397 producerOp.getMapOperands().take_back(numProducerSyms);
3398 newDimOperands.append(dimValues.begin(), dimValues.end());
3399 newSymOperands.append(symValues.begin(), symValues.end());
3403 newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3404 .shiftSymbols(numProducerSyms, numUsedSyms));
3407 numUsedDims += numProducerDims;
3408 numUsedSyms += numProducerSyms;
3414 llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3433 if (!resultExpr.isPureAffine())
3438 if (failed(flattenResult))
3451 if (llvm::is_sorted(flattenedExprs))
3456 llvm::to_vector(llvm::seq<unsigned>(0, map.
getNumResults()));
3457 llvm::sort(resultPermutation, [&](
unsigned lhs,
unsigned rhs) {
3458 return flattenedExprs[lhs] < flattenedExprs[rhs];
3461 for (
unsigned idx : resultPermutation)
3482 template <
typename T>
3488 AffineMap map = affineOp.getAffineMap();
3496 template <
typename T>
3502 if (affineOp.getMap().getNumResults() != 1)
3505 affineOp.getOperands());
3533 return parseAffineMinMaxOp<AffineMinOp>(parser, result);
3561 return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
3580 IntegerAttr hintInfo;
3582 StringRef readOrWrite, cacheType;
3584 AffineMapAttr mapAttr;
3588 AffinePrefetchOp::getMapAttrStrName(),
3594 AffinePrefetchOp::getLocalityHintAttrStrName(),
3604 if (readOrWrite !=
"read" && readOrWrite !=
"write")
3606 "rw specifier has to be 'read' or 'write'");
3607 result.
addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
3610 if (cacheType !=
"data" && cacheType !=
"instr")
3612 "cache type has to be 'data' or 'instr'");
3614 result.
addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
3621 p <<
" " << getMemref() <<
'[';
3622 AffineMapAttr mapAttr =
3623 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3626 p <<
']' <<
", " << (getIsWrite() ?
"write" :
"read") <<
", "
3627 <<
"locality<" << getLocalityHint() <<
">, "
3628 << (getIsDataCache() ?
"data" :
"instr");
3630 (*this)->getAttrs(),
3631 {getMapAttrStrName(), getLocalityHintAttrStrName(),
3632 getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3637 auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3641 return emitOpError(
"affine.prefetch affine map num results must equal"
3644 return emitOpError(
"too few operands");
3646 if (getNumOperands() != 1)
3647 return emitOpError(
"too few operands");
3651 for (
auto idx : getMapOperands()) {
3654 "index must be a valid dimension or symbol identifier");
3662 results.
add<SimplifyAffineOp<AffinePrefetchOp>>(context);
3665 LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
3680 auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3684 build(builder, result, resultTypes, reductions, lbs, {}, ubs,
3694 assert(llvm::all_of(lbMaps,
3696 return m.getNumDims() == lbMaps[0].getNumDims() &&
3697 m.getNumSymbols() == lbMaps[0].getNumSymbols();
3699 "expected all lower bounds maps to have the same number of dimensions "
3701 assert(llvm::all_of(ubMaps,
3703 return m.getNumDims() == ubMaps[0].getNumDims() &&
3704 m.getNumSymbols() == ubMaps[0].getNumSymbols();
3706 "expected all upper bounds maps to have the same number of dimensions "
3708 assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3709 "expected lower bound maps to have as many inputs as lower bound "
3711 assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3712 "expected upper bound maps to have as many inputs as upper bound "
3720 for (arith::AtomicRMWKind reduction : reductions)
3721 reductionAttrs.push_back(
3733 groups.reserve(groups.size() + maps.size());
3734 exprs.reserve(maps.size());
3736 llvm::append_range(exprs, m.getResults());
3737 groups.push_back(m.getNumResults());
3739 return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3745 AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3746 AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
3764 for (
unsigned i = 0, e = steps.size(); i < e; ++i)
3766 if (resultTypes.empty())
3767 ensureTerminator(*bodyRegion, builder, result.
location);
3771 return {&getRegion()};
3774 unsigned AffineParallelOp::getNumDims() {
return getSteps().size(); }
3776 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
3777 return getOperands().take_front(getLowerBoundsMap().getNumInputs());
3780 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
3781 return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
3784 AffineMap AffineParallelOp::getLowerBoundMap(
unsigned pos) {
3785 auto values = getLowerBoundsGroups().getValues<int32_t>();
3787 for (
unsigned i = 0; i < pos; ++i)
3789 return getLowerBoundsMap().getSliceMap(start, values[pos]);
3792 AffineMap AffineParallelOp::getUpperBoundMap(
unsigned pos) {
3793 auto values = getUpperBoundsGroups().getValues<int32_t>();
3795 for (
unsigned i = 0; i < pos; ++i)
3797 return getUpperBoundsMap().getSliceMap(start, values[pos]);
3801 return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
3805 return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
3808 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3809 if (hasMinMaxBounds())
3810 return std::nullopt;
3815 AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
3818 for (
unsigned i = 0, e = rangesValueMap.
getNumResults(); i < e; ++i) {
3819 auto expr = rangesValueMap.
getResult(i);
3820 auto cst = dyn_cast<AffineConstantExpr>(expr);
3822 return std::nullopt;
3823 out.push_back(cst.getValue());
3828 Block *AffineParallelOp::getBody() {
return &getRegion().
front(); }
3830 OpBuilder AffineParallelOp::getBodyBuilder() {
3831 return OpBuilder(getBody(), std::prev(getBody()->end()));
3836 "operands to map must match number of inputs");
3838 auto ubOperands = getUpperBoundsOperands();
3841 newOperands.append(ubOperands.begin(), ubOperands.end());
3842 (*this)->setOperands(newOperands);
3849 "operands to map must match number of inputs");
3852 newOperands.append(ubOperands.begin(), ubOperands.end());
3853 (*this)->setOperands(newOperands);
3859 setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
3864 arith::AtomicRMWKind op) {
3866 case arith::AtomicRMWKind::addf:
3867 return isa<FloatType>(resultType);
3868 case arith::AtomicRMWKind::addi:
3869 return isa<IntegerType>(resultType);
3870 case arith::AtomicRMWKind::assign:
3872 case arith::AtomicRMWKind::mulf:
3873 return isa<FloatType>(resultType);
3874 case arith::AtomicRMWKind::muli:
3875 return isa<IntegerType>(resultType);
3876 case arith::AtomicRMWKind::maximumf:
3877 return isa<FloatType>(resultType);
3878 case arith::AtomicRMWKind::minimumf:
3879 return isa<FloatType>(resultType);
3880 case arith::AtomicRMWKind::maxs: {
3881 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3882 return intType && intType.isSigned();
3884 case arith::AtomicRMWKind::mins: {
3885 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3886 return intType && intType.isSigned();
3888 case arith::AtomicRMWKind::maxu: {
3889 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3890 return intType && intType.isUnsigned();
3892 case arith::AtomicRMWKind::minu: {
3893 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3894 return intType && intType.isUnsigned();
3896 case arith::AtomicRMWKind::ori:
3897 return isa<IntegerType>(resultType);
3898 case arith::AtomicRMWKind::andi:
3899 return isa<IntegerType>(resultType);
3906 auto numDims = getNumDims();
3909 getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
3910 return emitOpError() <<
"the number of region arguments ("
3911 << getBody()->getNumArguments()
3912 <<
") and the number of map groups for lower ("
3913 << getLowerBoundsGroups().getNumElements()
3914 <<
") and upper bound ("
3915 << getUpperBoundsGroups().getNumElements()
3916 <<
"), and the number of steps (" << getSteps().size()
3917 <<
") must all match";
3920 unsigned expectedNumLBResults = 0;
3921 for (APInt v : getLowerBoundsGroups())
3922 expectedNumLBResults += v.getZExtValue();
3923 if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
3924 return emitOpError() <<
"expected lower bounds map to have "
3925 << expectedNumLBResults <<
" results";
3926 unsigned expectedNumUBResults = 0;
3927 for (APInt v : getUpperBoundsGroups())
3928 expectedNumUBResults += v.getZExtValue();
3929 if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
3930 return emitOpError() <<
"expected upper bounds map to have "
3931 << expectedNumUBResults <<
" results";
3933 if (getReductions().size() != getNumResults())
3934 return emitOpError(
"a reduction must be specified for each output");
3940 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
3941 if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
3942 return emitOpError(
"invalid reduction attribute");
3943 auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
3945 return emitOpError(
"result type cannot match reduction attribute");
3951 getLowerBoundsMap().getNumDims())))
3955 getUpperBoundsMap().getNumDims())))
3960 LogicalResult AffineValueMap::canonicalize() {
3962 auto newMap = getAffineMap();
3964 if (newMap == getAffineMap() && newOperands == operands)
3966 reset(newMap, newOperands);
3979 if (!lbCanonicalized && !ubCanonicalized)
3982 if (lbCanonicalized)
3984 if (ubCanonicalized)
3990 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
4002 StringRef keyword) {
4005 ValueRange dimOperands = operands.take_front(numDims);
4006 ValueRange symOperands = operands.drop_front(numDims);
4008 for (llvm::APInt groupSize : group) {
4012 unsigned size = groupSize.getZExtValue();
4017 p << keyword <<
'(';
4027 p <<
" (" << getBody()->getArguments() <<
") = (";
4029 getLowerBoundsOperands(),
"max");
4032 getUpperBoundsOperands(),
"min");
4035 bool elideSteps = llvm::all_of(steps, [](int64_t step) {
return step == 1; });
4038 llvm::interleaveComma(steps, p);
4041 if (getNumResults()) {
4043 llvm::interleaveComma(getReductions(), p, [&](
auto &attr) {
4044 arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4045 llvm::cast<IntegerAttr>(attr).getInt());
4046 p <<
"\"" << arith::stringifyAtomicRMWKind(sym) <<
"\"";
4048 p <<
") -> (" << getResultTypes() <<
")";
4055 (*this)->getAttrs(),
4056 {AffineParallelOp::getReductionsAttrStrName(),
4057 AffineParallelOp::getLowerBoundsMapAttrStrName(),
4058 AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4059 AffineParallelOp::getUpperBoundsMapAttrStrName(),
4060 AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4061 AffineParallelOp::getStepsAttrStrName()});
4074 "expected operands to be dim or symbol expression");
4077 for (
const auto &list : operands) {
4081 for (
Value operand : valueOperands) {
4082 unsigned pos = std::distance(uniqueOperands.begin(),
4083 llvm::find(uniqueOperands, operand));
4084 if (pos == uniqueOperands.size())
4085 uniqueOperands.push_back(operand);
4086 replacements.push_back(
/// Discriminates between the two flavors of bound list handled in this file.
enum class MinMaxKind {
  /// Paired with the upper-bounds map/groups attribute names and the "min"
  /// keyword.
  Min,
  /// Paired with the lower-bounds map/groups attribute names and the "max"
  /// keyword.
  Max
};
4120 const llvm::StringLiteral tmpAttrStrName =
"__pseudo_bound_map";
4122 StringRef mapName = kind == MinMaxKind::Min
4123 ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4124 : AffineParallelOp::getLowerBoundsMapAttrStrName();
4125 StringRef groupsName =
4126 kind == MinMaxKind::Min
4127 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4128 : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4145 auto parseOperands = [&]() {
4147 kind == MinMaxKind::Min ?
"min" :
"max"))) {
4148 mapOperands.clear();
4155 llvm::append_range(flatExprs, map.getValue().getResults());
4157 auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4159 auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4161 flatDimOperands.append(map.getValue().getNumResults(), dims);
4162 flatSymOperands.append(map.getValue().getNumResults(), syms);
4163 numMapsPerGroup.push_back(map.getValue().getNumResults());
4166 flatSymOperands.emplace_back(),
4167 flatExprs.emplace_back())))
4169 numMapsPerGroup.push_back(1);
4176 unsigned totalNumDims = 0;
4177 unsigned totalNumSyms = 0;
4178 for (
unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4179 unsigned numDims = flatDimOperands[i].size();
4180 unsigned numSyms = flatSymOperands[i].size();
4181 flatExprs[i] = flatExprs[i]
4182 .shiftDims(numDims, totalNumDims)
4183 .shiftSymbols(numSyms, totalNumSyms);
4184 totalNumDims += numDims;
4185 totalNumSyms += numSyms;
4197 result.
operands.append(dimOperands.begin(), dimOperands.end());
4198 result.
operands.append(symOperands.begin(), symOperands.end());
4201 auto flatMap =
AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4203 flatMap = flatMap.replaceDimsAndSymbols(
4204 dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4228 AffineMapAttr stepsMapAttr;
4233 result.
addAttribute(AffineParallelOp::getStepsAttrStrName(),
4237 AffineParallelOp::getStepsAttrStrName(),
4244 auto stepsMap = stepsMapAttr.getValue();
4245 for (
const auto &result : stepsMap.getResults()) {
4246 auto constExpr = dyn_cast<AffineConstantExpr>(result);
4249 "steps must be constant integers");
4250 steps.push_back(constExpr.getValue());
4252 result.
addAttribute(AffineParallelOp::getStepsAttrStrName(),
4262 auto parseAttributes = [&]() -> ParseResult {
4272 std::optional<arith::AtomicRMWKind> reduction =
4273 arith::symbolizeAtomicRMWKind(attrVal.getValue());
4275 return parser.
emitError(loc,
"invalid reduction value: ") << attrVal;
4276 reductions.push_back(
4284 result.
addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4293 for (
auto &iv : ivs)
4294 iv.type = indexType;
4300 AffineParallelOp::ensureTerminator(*body, builder, result.
location);
4309 auto *parentOp = (*this)->getParentOp();
4310 auto results = parentOp->getResults();
4311 auto operands = getOperands();
4313 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4314 return emitOpError() <<
"only terminates affine.if/for/parallel regions";
4315 if (parentOp->getNumResults() != getNumOperands())
4316 return emitOpError() <<
"parent of yield must have same number of "
4317 "results as the yield operands";
4318 for (
auto it : llvm::zip(results, operands)) {
4320 return emitOpError() <<
"types mismatch between yield op and its parent";
4333 assert(operands.size() == 1 + map.
getNumInputs() &&
"inconsistent operands");
4337 result.
types.push_back(resultType);
4341 VectorType resultType,
Value memref,
4343 assert(map.
getNumInputs() == mapOperands.size() &&
"inconsistent index info");
4347 result.
types.push_back(resultType);
4351 VectorType resultType,
Value memref,
4353 auto memrefType = llvm::cast<MemRefType>(memref.
getType());
4354 int64_t rank = memrefType.getRank();
4359 build(builder, result, resultType, memref, map, indices);
4362 void AffineVectorLoadOp::getCanonicalizationPatterns(
RewritePatternSet &results,
4364 results.
add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4372 MemRefType memrefType;
4373 VectorType resultType;
4375 AffineMapAttr mapAttr;
4380 AffineVectorLoadOp::getMapAttrStrName(),
4391 p <<
" " << getMemRef() <<
'[';
4392 if (AffineMapAttr mapAttr =
4393 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4397 {getMapAttrStrName()});
4403 VectorType vectorType) {
4405 if (memrefType.getElementType() != vectorType.getElementType())
4407 "requires memref and vector types of the same elemental type");
4415 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4416 getMapOperands(), memrefType,
4417 getNumOperands() - 1)))
4433 assert(map.
getNumInputs() == mapOperands.size() &&
"inconsistent index info");
4444 auto memrefType = llvm::cast<MemRefType>(memref.
getType());
4445 int64_t rank = memrefType.getRank();
4450 build(builder, result, valueToStore, memref, map, indices);
4452 void AffineVectorStoreOp::getCanonicalizationPatterns(
4454 results.
add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4461 MemRefType memrefType;
4462 VectorType resultType;
4465 AffineMapAttr mapAttr;
4471 AffineVectorStoreOp::getMapAttrStrName(),
4482 p <<
" " << getValueToStore();
4483 p <<
", " << getMemRef() <<
'[';
4484 if (AffineMapAttr mapAttr =
4485 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4489 {getMapAttrStrName()});
4490 p <<
" : " <<
getMemRefType() <<
", " << getValueToStore().getType();
4496 *
this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4497 getMapOperands(), memrefType,
4498 getNumOperands() - 2)))
4511 void AffineDelinearizeIndexOp::build(
OpBuilder &odsBuilder,
4515 bool hasOuterBound) {
4517 : staticBasis.size() + 1,
4519 build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
4523 void AffineDelinearizeIndexOp::build(
OpBuilder &odsBuilder,
4526 bool hasOuterBound) {
4527 if (hasOuterBound && !basis.empty() && basis.front() ==
nullptr) {
4528 hasOuterBound =
false;
4529 basis = basis.drop_front();
4535 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4539 void AffineDelinearizeIndexOp::build(
OpBuilder &odsBuilder,
4543 bool hasOuterBound) {
4544 if (hasOuterBound && !basis.empty() && basis.front() ==
OpFoldResult()) {
4545 hasOuterBound =
false;
4546 basis = basis.drop_front();
4551 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4555 void AffineDelinearizeIndexOp::build(
OpBuilder &odsBuilder,
4558 bool hasOuterBound) {
4559 build(odsBuilder, odsState, linearIndex,
ValueRange{}, basis, hasOuterBound);
4564 if (getNumResults() != staticBasis.size() &&
4565 getNumResults() != staticBasis.size() + 1)
4566 return emitOpError(
"should return an index for each basis element and up "
4567 "to one extra index");
4569 auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
4570 if (
static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4572 "mismatch between dynamic and static basis (kDynamic marker but no "
4573 "corresponding dynamic basis entry) -- this can only happen due to an "
4574 "incorrect fold/rewrite");
4576 if (!llvm::all_of(staticBasis, [](int64_t v) {
4577 return v > 0 || ShapedType::isDynamic(v);
4579 return emitOpError(
"no basis element may be statically non-positive");
4588 static std::optional<SmallVector<int64_t>>
4592 uint64_t dynamicBasisIndex = 0;
4595 mutableDynamicBasis.
erase(dynamicBasisIndex);
4597 ++dynamicBasisIndex;
4602 if (dynamicBasisIndex == dynamicBasis.size())
4603 return std::nullopt;
4609 staticBasis.push_back(ShapedType::kDynamic);
4611 staticBasis.push_back(*basisVal);
4618 AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
4620 std::optional<SmallVector<int64_t>> maybeStaticBasis =
4622 adaptor.getDynamicBasis());
4623 if (maybeStaticBasis) {
4624 setStaticBasis(*maybeStaticBasis);
4629 if (getNumResults() == 1) {
4630 result.push_back(getLinearIndex());
4634 if (adaptor.getLinearIndex() ==
nullptr)
4637 if (!adaptor.getDynamicBasis().empty())
4640 int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
4641 Type attrType = getLinearIndex().getType();
4644 if (hasOuterBound())
4645 staticBasis = staticBasis.drop_front();
4646 for (int64_t modulus : llvm::reverse(staticBasis)) {
4647 result.push_back(
IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
4648 highPart = llvm::divideFloorSigned(highPart, modulus);
4651 std::reverse(result.begin(), result.end());
4657 if (hasOuterBound()) {
4658 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
4660 getDynamicBasis().drop_front(), builder);
4662 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
4666 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
4671 if (!hasOuterBound())
4679 struct DropUnitExtentBasis
4683 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4686 std::optional<Value> zero = std::nullopt;
4687 Location loc = delinearizeOp->getLoc();
4690 zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
4691 return zero.value();
4697 for (
auto [index, basis] :
4699 std::optional<int64_t> basisVal =
4701 if (basisVal && *basisVal == 1)
4702 replacements[index] =
getZero();
4704 newBasis.push_back(basis);
4707 if (newBasis.size() == delinearizeOp.getNumResults())
4709 "no unit basis elements");
4711 if (!newBasis.empty()) {
4713 auto newDelinearizeOp = rewriter.
create<affine::AffineDelinearizeIndexOp>(
4714 loc, delinearizeOp.getLinearIndex(), newBasis);
4717 for (
auto &replacement : replacements) {
4720 replacement = newDelinearizeOp->
getResult(newIndex++);
4724 rewriter.
replaceOp(delinearizeOp, replacements);
4739 struct CancelDelinearizeOfLinearizeDisjointExactTail
4743 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4745 auto linearizeOp = delinearizeOp.getLinearIndex()
4746 .getDefiningOp<affine::AffineLinearizeIndexOp>();
4749 "index doesn't come from linearize");
4751 if (!linearizeOp.getDisjoint())
4754 ValueRange linearizeIns = linearizeOp.getMultiIndex();
4758 size_t numMatches = 0;
4759 for (
auto [linSize, delinSize] : llvm::zip(
4760 llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
4761 if (linSize != delinSize)
4766 if (numMatches == 0)
4768 delinearizeOp,
"final basis element doesn't match linearize");
4771 if (numMatches == linearizeBasis.size() &&
4772 numMatches == delinearizeBasis.size() &&
4773 linearizeIns.size() == delinearizeOp.getNumResults()) {
4774 rewriter.
replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
4778 Value newLinearize = rewriter.
create<affine::AffineLinearizeIndexOp>(
4779 linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
4781 linearizeOp.getDisjoint());
4782 auto newDelinearize = rewriter.
create<affine::AffineDelinearizeIndexOp>(
4783 delinearizeOp.getLoc(), newLinearize,
4785 delinearizeOp.hasOuterBound());
4787 mergedResults.append(linearizeIns.take_back(numMatches).begin(),
4788 linearizeIns.take_back(numMatches).end());
4789 rewriter.
replaceOp(delinearizeOp, mergedResults);
4807 struct SplitDelinearizeSpanningLastLinearizeArg final
4811 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4813 auto linearizeOp = delinearizeOp.getLinearIndex()
4814 .getDefiningOp<affine::AffineLinearizeIndexOp>();
4817 "index doesn't come from linearize");
4819 if (!linearizeOp.getDisjoint())
4821 "linearize isn't disjoint");
4823 int64_t target = linearizeOp.getStaticBasis().back();
4824 if (ShapedType::isDynamic(target))
4826 linearizeOp,
"linearize ends with dynamic basis value");
4828 int64_t sizeToSplit = 1;
4829 size_t elemsToSplit = 0;
4831 for (int64_t basisElem : llvm::reverse(basis)) {
4832 if (ShapedType::isDynamic(basisElem))
4834 delinearizeOp,
"dynamic basis element while scanning for split");
4835 sizeToSplit *= basisElem;
4838 if (sizeToSplit > target)
4840 "overshot last argument size");
4841 if (sizeToSplit == target)
4845 if (sizeToSplit < target)
4847 delinearizeOp,
"product of known basis elements doesn't exceed last "
4848 "linearize argument");
4850 if (elemsToSplit < 2)
4853 "need at least two elements to form the basis product");
4855 Value linearizeWithoutBack =
4856 rewriter.
create<affine::AffineLinearizeIndexOp>(
4857 linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
4858 linearizeOp.getDynamicBasis(),
4859 linearizeOp.getStaticBasis().drop_back(),
4860 linearizeOp.getDisjoint());
4861 auto delinearizeWithoutSplitPart =
4862 rewriter.
create<affine::AffineDelinearizeIndexOp>(
4863 delinearizeOp.getLoc(), linearizeWithoutBack,
4864 delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
4865 delinearizeOp.hasOuterBound());
4866 auto delinearizeBack = rewriter.
create<affine::AffineDelinearizeIndexOp>(
4867 delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
4868 basis.take_back(elemsToSplit),
true);
4870 llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
4871 delinearizeBack.getResults()));
4872 rewriter.
replaceOp(delinearizeOp, results);
4879 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
4882 .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
4883 DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
4891 void AffineLinearizeIndexOp::build(
OpBuilder &odsBuilder,
4895 if (!basis.empty() && basis.front() ==
Value())
4896 basis = basis.drop_front();
4901 build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4904 void AffineLinearizeIndexOp::build(
OpBuilder &odsBuilder,
4910 basis = basis.drop_front();
4914 build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4917 void AffineLinearizeIndexOp::build(
OpBuilder &odsBuilder,
4921 build(odsBuilder, odsState, multiIndex,
ValueRange{}, basis, disjoint);
4925 size_t numIndexes = getMultiIndex().size();
4926 size_t numBasisElems = getStaticBasis().size();
4927 if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
4928 return emitOpError(
"should be passed a basis element for each index except "
4929 "possibly the first");
4931 auto dynamicMarkersCount =
4932 llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
4933 if (
static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4935 "mismatch between dynamic and static basis (kDynamic marker but no "
4936 "corresponding dynamic basis entry) -- this can only happen due to an "
4937 "incorrect fold/rewrite");
4942 OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
4943 std::optional<SmallVector<int64_t>> maybeStaticBasis =
4945 adaptor.getDynamicBasis());
4946 if (maybeStaticBasis) {
4947 setStaticBasis(*maybeStaticBasis);
4951 if (getMultiIndex().empty())
4955 if (getMultiIndex().size() == 1)
4956 return getMultiIndex().front();
4958 if (llvm::any_of(adaptor.getMultiIndex(),
4959 [](
Attribute a) { return a == nullptr; }))
4962 if (!adaptor.getDynamicBasis().empty())
4967 for (
auto [length, indexAttr] :
4968 llvm::zip_first(llvm::reverse(getStaticBasis()),
4969 llvm::reverse(adaptor.getMultiIndex()))) {
4970 result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
4971 stride = stride * length;
4974 if (!hasOuterBound())
4977 cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;
4984 if (hasOuterBound()) {
4985 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
4987 getDynamicBasis().drop_front(), builder);
4989 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
4993 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
4998 if (!hasOuterBound())
5014 struct DropLinearizeUnitComponentsIfDisjointOrZero final
5018 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5021 size_t numIndices = multiIndex.size();
5023 newIndices.reserve(numIndices);
5025 newBasis.reserve(numIndices);
5027 if (!op.hasOuterBound()) {
5028 newIndices.push_back(multiIndex.front());
5029 multiIndex = multiIndex.drop_front();
5033 for (
auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
5035 if (!basisEntry || *basisEntry != 1) {
5036 newIndices.push_back(index);
5037 newBasis.push_back(basisElem);
5042 if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
5043 newIndices.push_back(index);
5044 newBasis.push_back(basisElem);
5048 if (newIndices.size() == numIndices)
5050 "no unit basis entries to replace");
5052 if (newIndices.size() == 0) {
5057 op, newIndices, newBasis, op.getDisjoint());
5066 int64_t nDynamic = 0;
5076 dynamicPart.push_back(cast<Value>(term));
5080 if (
auto constant = dyn_cast<AffineConstantExpr>(result))
5082 return builder.
create<AffineApplyOp>(loc, result, dynamicPart).getResult();
5112 struct CancelLinearizeOfDelinearizePortion final
5123 unsigned linStart = 0;
5124 unsigned delinStart = 0;
5125 unsigned length = 0;
5129 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
5136 ValueRange multiIndex = linearizeOp.getMultiIndex();
5137 unsigned numLinArgs = multiIndex.size();
5138 unsigned linArgIdx = 0;
5142 while (linArgIdx < numLinArgs) {
5143 auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
5149 auto delinearizeOp =
5150 dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
5151 if (!delinearizeOp) {
5168 unsigned delinArgIdx = asResult.getResultNumber();
5170 OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
5172 bool boundsMatch = firstDelinBound == firstLinBound;
5173 bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
5174 bool knownByDisjoint =
5175 linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
5176 if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
5182 unsigned numDelinOuts = delinearizeOp.getNumResults();
5183 for (;
j + linArgIdx < numLinArgs &&
j + delinArgIdx < numDelinOuts;
5185 if (multiIndex[linArgIdx +
j] !=
5186 delinearizeOp.getResult(delinArgIdx +
j))
5188 if (linBasis[linArgIdx +
j] != delinBasis[delinArgIdx +
j])
5194 if (
j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
5198 matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx,
j});
5202 if (matches.empty())
5204 linearizeOp,
"no run of delinearize outputs to deal with");
5212 newIndex.reserve(numLinArgs);
5214 newBasis.reserve(numLinArgs);
5215 unsigned prevMatchEnd = 0;
5216 for (Match m : matches) {
5217 unsigned gap = m.linStart - prevMatchEnd;
5218 llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
5219 llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
5221 prevMatchEnd = m.linStart + m.length;
5223 PatternRewriter::InsertionGuard g(rewriter);
5227 linBasisRef.slice(m.linStart, m.length);
5234 if (m.length == m.delinearize.getNumResults()) {
5235 newIndex.push_back(m.delinearize.getLinearIndex());
5236 newBasis.push_back(newSize);
5244 newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
5245 newDelinBasis.begin() + m.delinStart + m.length);
5246 newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
5247 auto newDelinearize = rewriter.
create<AffineDelinearizeIndexOp>(
5248 m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
5254 Value combinedElem = newDelinearize.getResult(m.delinStart);
5255 auto residualDelinearize = rewriter.
create<AffineDelinearizeIndexOp>(
5256 m.delinearize.getLoc(), combinedElem, basisToMerge);
5261 llvm::append_range(newDelinResults,
5262 newDelinearize.getResults().take_front(m.delinStart));
5263 llvm::append_range(newDelinResults, residualDelinearize.getResults());
5266 newDelinearize.getResults().drop_front(m.delinStart + 1));
5268 delinearizeReplacements.push_back(newDelinResults);
5269 newIndex.push_back(combinedElem);
5270 newBasis.push_back(newSize);
5272 llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
5273 llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
5275 linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
5277 for (
auto [m, newResults] :
5278 llvm::zip_equal(matches, delinearizeReplacements)) {
5279 if (newResults.empty())
5281 rewriter.
replaceOp(m.delinearize, newResults);
5292 struct DropLinearizeLeadingZero final
5296 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5298 Value leadingIdx = op.getMultiIndex().front();
5302 if (op.getMultiIndex().size() == 1) {
5309 if (op.hasOuterBound())
5310 newMixedBasis = newMixedBasis.drop_front();
5313 op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
5319 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
5321 patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
5322 DropLinearizeUnitComponentsIfDisjointOrZero>(context);
5329 #define GET_OP_CLASSES
5330 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds known to be constants.
static bool hasTrivialZeroTripCount(AffineForOp op)
Returns true if the affine.for has zero iterations in trivial cases.
static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands)
Composes the given affine map with the given list of operands, pulling in the maps from any affine....
static void printAffineMinMaxOp(OpAsmPrinter &p, T op)
static bool isResultTypeMatchAtomicRMWKind(Type resultType, arith::AtomicRMWKind op)
static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)
Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)
Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.
static void LLVM_ATTRIBUTE_UNUSED simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)
Simplify the map while exploiting information on the values in operands.
static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)
Fold an affine min or max operation with the given operands.
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)
Canonicalize the bounds of the given loop.
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)
Simplify expr while exploiting information from the values in operands.
static bool isValidAffineIndexOperand(Value value, Region *region)
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)
Parse a for operation loop bounds.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands)
Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...
static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms)
Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression...
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType)
Verify common invariants of affine.vector_load and affine.vector_store.
static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)
Simplify the expressions in map while making use of lower or upper bounds of its operands.
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)
Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)
Check if e is known to be: 0 <= e < k.
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, OperationState &result, MinMaxKind kind)
Parses an affine map that can contain a min/max for groups of its results, e.g., max(expr-1,...
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds that may or may not be constants.
static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)
Prints dimension and symbol list.
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)
Returns the largest known divisor of e.
static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)
Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.
static LogicalResult foldLoopBounds(AffineForOp forOp)
Fold the constant bounds of a loop.
static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)
Utility function to verify that a set of operands are valid dimension and symbol identifiers.
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)
Returns true if the result of the dim op is a valid symbol for region.
static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr &quotientTimesDiv, AffineExpr &rem)
Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.
static ParseResult deduplicateAndResolveOperands(OpAsmParser &parser, ArrayRef< SmallVector< OpAsmParser::UnresolvedOperand >> operands, SmallVectorImpl< Value > &uniqueOperands, SmallVectorImpl< AffineExpr > &replacements, AffineExprKind kind)
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult verifyAffineMinMaxOp(T op)
static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)
static LogicalResult verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)
Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine....
static std::optional< SmallVector< int64_t > > foldCstValueToCstAttrBasis(ArrayRef< OpFoldResult > mixedBasis, MutableOperandRange mutableDynamicBasis, ArrayRef< Attribute > dynamicBasis)
Given mixed basis of affine.delinearize_index/linearize_index replace constant SSA values with the co...
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)
Canonicalize the result expression order of an affine map and return success if the order changed.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
AffineExprKind getKind() const
Return the classification for this type.
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
MLIRContext * getContext() const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
MLIRContext * getContext() const
bool isFunctionOfDim(unsigned position) const
Return true if any affine expression involves AffineDimExpr position.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isFunctionOfSymbol(unsigned position) const
Return true if any affine expression involves AffineSymbolExpr position.
unsigned getNumResults() const
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
unsigned getNumInputs() const
AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
AffineExpr getResult(unsigned idx) const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
AffineMap getDimIdentityMap()
AffineMap getMultiDimIdentityMap(unsigned rank)
AffineExpr getAffineSymbolExpr(unsigned position)
AffineExpr getAffineConstantExpr(int64_t constant)
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
AffineMap getEmptyAffineMap()
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
MLIRContext * getContext() const
AffineMap getSymbolIdentityMap()
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense integer vector or tensor object.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
An integer set representing a conjunction of one or more affine equalities and inequalities.
unsigned getNumDims() const
static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)
MLIRContext * getContext() const
unsigned getNumInputs() const
ArrayRef< AffineExpr > getConstraints() const
ArrayRef< bool > getEqFlags() const
Returns the equality bits, which specify whether each of the constraints is an equality or inequality...
unsigned getNumSymbols() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void pop_back()
Pop last element from list.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0
Parses an affine map attribute where dims and symbols are SSA operands.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0
Parses an affine expression where dims and symbols are SSA operands.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0
Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
operand_range::iterator operand_iterator
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
std::vector< SmallVector< int64_t, 8 > > operandExprStack
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
AffineBound represents a lower or upper bound in the for operation.
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
LogicalResult canonicalize()
Attempts to canonicalize the map and operands.
ArrayRef< Value > getOperands() const
AffineExpr getResult(unsigned i)
AffineMap getAffineMap() const
unsigned getNumResults() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)
Extracts the induction variables from a list of AffineForOps and places them in the output argument i...
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
bool isTopLevelValue(Value value)
A utility function to check if a value is defined at the top level of an op with trait AffineScope or...
void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)
Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.
void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)
Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns the AffineParallelOp that has the provided value among its induction variables.
Region * getAffineScope(Operation *op)
Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list.
bool isAffineParallelInductionVar(Value val)
Returns true if val is the induction variable of an AffineParallelOp.
AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t >> constLowerBounds, ArrayRef< std::optional< int64_t >> constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., it is speculatable and does not touch memory.
int64_t computeProduct(ArrayRef< int64_t > basis)
Returns the product of all elements in basis.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Canonicalize the affine map result expression order of an affine min/max operation.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Remove duplicated expressions in affine min/max ops.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Merge an affine min/max op to its consumers if its consumer is also an affine min/max op.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.