22 #include "llvm/ADT/ScopeExit.h"
23 #include "llvm/ADT/SmallBitVector.h"
24 #include "llvm/ADT/SmallVectorExtras.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/MathExtras.h"
34 using llvm::divideCeilSigned;
35 using llvm::divideFloorSigned;
38 #define DEBUG_TYPE "affine-ops"
40 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
47 if (
auto arg = llvm::dyn_cast<BlockArgument>(value))
48 return arg.getParentRegion() == region;
71 if (llvm::isa<BlockArgument>(value))
72 return legalityCheck(mapping.
lookup(value), dest);
79 bool isDimLikeOp = isa<ShapedDimOpInterface>(value.
getDefiningOp());
90 return llvm::all_of(values, [&](
Value v) {
97 template <
typename OpTy>
100 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
101 AffineWriteOpInterface>::value,
102 "only ops with affine read/write interface are supported");
109 dimOperands, src, dest, mapping,
113 symbolOperands, src, dest, mapping,
130 op.getMapOperands(), src, dest, mapping,
135 op.getMapOperands(), src, dest, mapping,
162 if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
167 if (!llvm::hasSingleElement(*src))
175 if (
auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
176 if (iface.hasNoEffect())
184 .Case<AffineApplyOp, AffineReadOpInterface,
185 AffineWriteOpInterface>([&](
auto op) {
210 isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
214 bool shouldAnalyzeRecursively(
Operation *op)
const final {
return true; }
222 void AffineDialect::initialize() {
225 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
227 addInterfaces<AffineInlinerInterface>();
228 declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
237 if (
auto poison = dyn_cast<ub::PoisonAttr>(value))
238 return builder.
create<ub::PoisonOp>(loc, type, poison);
239 return arith::ConstantOp::materialize(builder, value, type, loc);
247 if (
auto arg = llvm::dyn_cast<BlockArgument>(value)) {
263 while (
auto *parentOp = curOp->getParentOp()) {
286 auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
288 isa<AffineForOp, AffineParallelOp>(parentOp));
309 auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->
getParentOp();
310 return isa<AffineForOp, AffineParallelOp>(parentOp);
314 if (
auto applyOp = dyn_cast<AffineApplyOp>(op))
315 return applyOp.isValidDim(region);
318 if (
auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
326 template <
typename AnyMemRefDefOp>
329 MemRefType memRefType = memrefDefOp.getType();
332 if (index >= memRefType.getRank()) {
337 if (!memRefType.isDynamicDim(index))
340 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
341 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
353 if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
361 if (!index.has_value())
365 Operation *op = dimOp.getShapedValue().getDefiningOp();
366 while (
auto castOp = dyn_cast<memref::CastOp>(op)) {
368 if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
370 op = castOp.getSource().getDefiningOp();
375 int64_t i = index.value();
377 .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
379 .Default([](
Operation *) {
return false; });
445 if (
auto applyOp = dyn_cast<AffineApplyOp>(defOp))
446 return applyOp.isValidSymbol(region);
449 if (
auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
473 printer <<
'(' << operands.take_front(numDims) <<
')';
474 if (operands.size() > numDims)
475 printer <<
'[' << operands.drop_front(numDims) <<
']';
485 numDims = opInfos.size();
499 template <
typename OpTy>
504 for (
auto operand : operands) {
505 if (opIt++ < numDims) {
507 return op.
emitOpError(
"operand cannot be used as a dimension id");
509 return op.
emitOpError(
"operand cannot be used as a symbol");
520 return AffineValueMap(getAffineMap(), getOperands(), getResult());
527 AffineMapAttr mapAttr;
533 auto map = mapAttr.getValue();
535 if (map.getNumDims() != numDims ||
536 numDims + map.getNumSymbols() != result.
operands.size()) {
538 "dimension or symbol index mismatch");
541 result.
types.append(map.getNumResults(), indexTy);
546 p <<
" " << getMapAttr();
548 getAffineMap().getNumDims(), p);
559 "operand count and affine map dimension and symbol count must match");
563 return emitOpError(
"mapping must produce one value");
571 return llvm::all_of(getOperands(),
579 return llvm::all_of(getOperands(),
586 return llvm::all_of(getOperands(),
593 return llvm::all_of(getOperands(), [&](
Value operand) {
599 auto map = getAffineMap();
602 auto expr = map.getResult(0);
603 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
604 return getOperand(dim.getPosition());
605 if (
auto sym = dyn_cast<AffineSymbolExpr>(expr))
606 return getOperand(map.getNumDims() + sym.getPosition());
610 bool hasPoison =
false;
612 map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
615 if (failed(foldResult))
632 auto dimExpr = dyn_cast<AffineDimExpr>(e);
642 Value operand = operands[dimExpr.getPosition()];
643 int64_t operandDivisor = 1;
647 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
648 operandDivisor = forOp.getStepAsInt();
650 uint64_t lbLargestKnownDivisor =
651 forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
652 operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
655 return operandDivisor;
662 if (
auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
663 int64_t constVal = constExpr.getValue();
664 return constVal >= 0 && constVal < k;
666 auto dimExpr = dyn_cast<AffineDimExpr>(e);
669 Value operand = operands[dimExpr.getPosition()];
673 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
674 forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
690 auto bin = dyn_cast<AffineBinaryOpExpr>(e);
698 quotientTimesDiv = llhs;
704 quotientTimesDiv = rlhs;
714 if (forOp && forOp.hasConstantLowerBound())
715 return forOp.getConstantLowerBound();
722 if (!forOp || !forOp.hasConstantUpperBound())
727 if (forOp.hasConstantLowerBound()) {
728 return forOp.getConstantUpperBound() - 1 -
729 (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
730 forOp.getStepAsInt();
732 return forOp.getConstantUpperBound() - 1;
743 constLowerBounds.reserve(operands.size());
744 constUpperBounds.reserve(operands.size());
745 for (
Value operand : operands) {
750 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr))
751 return constExpr.getValue();
766 constLowerBounds.reserve(operands.size());
767 constUpperBounds.reserve(operands.size());
768 for (
Value operand : operands) {
773 std::optional<int64_t> lowerBound;
774 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
775 lowerBound = constExpr.getValue();
778 constLowerBounds, constUpperBounds,
789 auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
800 binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
808 lhs = binExpr.getLHS();
809 rhs = binExpr.getRHS();
810 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
814 int64_t rhsConstVal = rhsConst.getValue();
816 if (rhsConstVal <= 0)
821 std::optional<int64_t> lhsLbConst =
823 std::optional<int64_t> lhsUbConst =
825 if (lhsLbConst && lhsUbConst) {
826 int64_t lhsLbConstVal = *lhsLbConst;
827 int64_t lhsUbConstVal = *lhsUbConst;
831 divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
832 divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
834 divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
840 divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
841 divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
848 lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
860 if (
isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
861 if (rhsConstVal % divisor == 0 &&
863 expr = quotientTimesDiv.
floorDiv(rhsConst);
864 }
else if (divisor % rhsConstVal == 0 &&
866 expr = rem % rhsConst;
892 if (operands.empty())
898 constLowerBounds.reserve(operands.size());
899 constUpperBounds.reserve(operands.size());
900 for (
Value operand : operands) {
914 if (
auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
915 lowerBounds.push_back(constExpr.getValue());
916 upperBounds.push_back(constExpr.getValue());
918 lowerBounds.push_back(
920 constLowerBounds, constUpperBounds,
922 upperBounds.push_back(
924 constLowerBounds, constUpperBounds,
933 unsigned i = exprEn.index();
935 if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
940 if (!upperBounds[i]) {
941 irredundantExprs.push_back(e);
947 auto otherLowerBound = en.value();
948 unsigned pos = en.index();
949 if (pos == i || !otherLowerBound)
951 if (*otherLowerBound > *upperBounds[i])
953 if (*otherLowerBound < *upperBounds[i])
958 if (upperBounds[pos] && lowerBounds[i] &&
959 lowerBounds[i] == upperBounds[i] &&
960 otherLowerBound == *upperBounds[pos] && i < pos)
964 irredundantExprs.push_back(e);
966 if (!lowerBounds[i]) {
967 irredundantExprs.push_back(e);
972 auto otherUpperBound = en.value();
973 unsigned pos = en.index();
974 if (pos == i || !otherUpperBound)
976 if (*otherUpperBound < *lowerBounds[i])
978 if (*otherUpperBound > *lowerBounds[i])
980 if (lowerBounds[pos] && upperBounds[i] &&
981 lowerBounds[i] == upperBounds[i] &&
982 otherUpperBound == lowerBounds[pos] && i < pos)
986 irredundantExprs.push_back(e);
998 static void LLVM_ATTRIBUTE_UNUSED
1000 assert(map.
getNumInputs() == operands.size() &&
"invalid operands for map");
1006 newResults.push_back(expr);
1023 unsigned dimOrSymbolPosition,
1027 bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1028 unsigned pos = isDimReplacement ? dimOrSymbolPosition
1029 : dimOrSymbolPosition - dims.size();
1030 Value &v = isDimReplacement ? dims[pos] : syms[pos];
1043 AffineMap composeMap = affineApply.getAffineMap();
1044 assert(composeMap.
getNumResults() == 1 &&
"affine.apply with >1 results");
1046 affineApply.getMapOperands().end());
1060 dims.append(composeDims.begin(), composeDims.end());
1061 syms.append(composeSyms.begin(), composeSyms.end());
1062 *map = map->
replace(toReplace, replacementExpr, dims.size(), syms.size());
1090 bool changed =
false;
1091 for (
unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1103 unsigned nDims = 0, nSyms = 0;
1105 dimReplacements.reserve(dims.size());
1106 symReplacements.reserve(syms.size());
1107 for (
auto *container : {&dims, &syms}) {
1108 bool isDim = (container == &dims);
1109 auto &repls = isDim ? dimReplacements : symReplacements;
1111 Value v = en.value();
1115 "map is function of unexpected expr@pos");
1121 operands->push_back(v);
1134 while (llvm::any_of(*operands, [](
Value v) {
1148 return b.
create<AffineApplyOp>(loc, map, valueOperands);
1170 for (
unsigned i : llvm::seq<unsigned>(0, map.
getNumResults())) {
1177 llvm::append_range(dims,
1179 llvm::append_range(symbols,
1186 operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1195 assert(map.
getNumResults() == 1 &&
"building affine.apply with !=1 result");
1205 AffineApplyOp applyOp =
1210 for (
unsigned i = 0, e = constOperands.size(); i != e; ++i)
1215 if (failed(applyOp->fold(constOperands, foldResults)) ||
1216 foldResults.empty()) {
1218 listener->notifyOperationInserted(applyOp, {});
1219 return applyOp.getResult();
1223 assert(foldResults.size() == 1 &&
"expected 1 folded result");
1224 return foldResults.front();
1242 return llvm::map_to_vector(llvm::seq<unsigned>(0, map.
getNumResults()),
1244 return makeComposedFoldedAffineApply(
1245 b, loc, map.getSubMap({i}), operands);
1249 template <
typename OpTy>
1261 return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
1264 template <
typename OpTy>
1276 auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1280 for (
unsigned i = 0, e = constOperands.size(); i != e; ++i)
1285 if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1286 foldResults.empty()) {
1288 listener->notifyOperationInserted(minMaxOp, {});
1289 return minMaxOp.getResult();
1293 assert(foldResults.size() == 1 &&
"expected 1 folded result");
1294 return foldResults.front();
1301 return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1308 return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
1313 template <
class MapOrSet>
1316 if (!mapOrSet || operands->empty())
1319 assert(mapOrSet->getNumInputs() == operands->size() &&
1320 "map/set inputs must match number of operands");
1322 auto *context = mapOrSet->getContext();
1324 resultOperands.reserve(operands->size());
1326 remappedSymbols.reserve(operands->size());
1327 unsigned nextDim = 0;
1328 unsigned nextSym = 0;
1329 unsigned oldNumSyms = mapOrSet->getNumSymbols();
1331 for (
unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1332 if (i < mapOrSet->getNumDims()) {
1336 remappedSymbols.push_back((*operands)[i]);
1339 resultOperands.push_back((*operands)[i]);
1342 resultOperands.push_back((*operands)[i]);
1346 resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1347 *operands = resultOperands;
1348 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
1349 oldNumSyms + nextSym);
1351 assert(mapOrSet->getNumInputs() == operands->size() &&
1352 "map/set inputs must match number of operands");
1356 template <
class MapOrSet>
1359 static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1360 "Argument must be either of AffineMap or IntegerSet type");
1362 if (!mapOrSet || operands->empty())
1365 assert(mapOrSet->getNumInputs() == operands->size() &&
1366 "map/set inputs must match number of operands");
1368 canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1371 llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1372 llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1374 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr))
1375 usedDims[dimExpr.getPosition()] =
true;
1376 else if (
auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
1377 usedSyms[symExpr.getPosition()] =
true;
1380 auto *context = mapOrSet->getContext();
1383 resultOperands.reserve(operands->size());
1385 llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1387 unsigned nextDim = 0;
1388 for (
unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1391 auto it = seenDims.find((*operands)[i]);
1392 if (it == seenDims.end()) {
1394 resultOperands.push_back((*operands)[i]);
1395 seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1397 dimRemapping[i] = it->second;
1401 llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1403 unsigned nextSym = 0;
1404 for (
unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1410 IntegerAttr operandCst;
1411 if (
matchPattern((*operands)[i + mapOrSet->getNumDims()],
1418 auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1419 if (it == seenSymbols.end()) {
1421 resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1422 seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1425 symRemapping[i] = it->second;
1428 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1430 *operands = resultOperands;
1435 canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1440 canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
1447 template <
typename AffineOpTy>
1456 LogicalResult matchAndRewrite(AffineOpTy affineOp,
1459 llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1460 AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1461 AffineVectorStoreOp, AffineVectorLoadOp>::value,
1462 "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1464 auto map = affineOp.getAffineMap();
1466 auto oldOperands = affineOp.getMapOperands();
1471 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1472 resultOperands.begin()))
1475 replaceAffineOp(rewriter, affineOp, map, resultOperands);
1483 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1490 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1494 prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1495 prefetch.getLocalityHint(), prefetch.getIsDataCache());
1498 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1502 store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1505 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1509 vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1513 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1517 vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1522 template <
typename AffineOpTy>
1523 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1532 results.
add<SimplifyAffineOp<AffineApplyOp>>(context);
1563 p <<
" " << getSrcMemRef() <<
'[';
1565 p <<
"], " << getDstMemRef() <<
'[';
1567 p <<
"], " << getTagMemRef() <<
'[';
1571 p <<
", " << getStride();
1572 p <<
", " << getNumElementsPerStride();
1574 p <<
" : " << getSrcMemRefType() <<
", " << getDstMemRefType() <<
", "
1575 << getTagMemRefType();
1587 AffineMapAttr srcMapAttr;
1590 AffineMapAttr dstMapAttr;
1593 AffineMapAttr tagMapAttr;
1608 getSrcMapAttrStrName(),
1612 getDstMapAttrStrName(),
1616 getTagMapAttrStrName(),
1625 if (!strideInfo.empty() && strideInfo.size() != 2) {
1627 "expected two stride related operands");
1629 bool isStrided = strideInfo.size() == 2;
1634 if (types.size() != 3)
1652 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1653 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1654 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1656 "memref operand count not equal to map.numInputs");
1660 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1661 if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).
getType()))
1662 return emitOpError(
"expected DMA source to be of memref type");
1663 if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).
getType()))
1664 return emitOpError(
"expected DMA destination to be of memref type");
1665 if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).
getType()))
1666 return emitOpError(
"expected DMA tag to be of memref type");
1668 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1669 getDstMap().getNumInputs() +
1670 getTagMap().getNumInputs();
1671 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1672 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1673 return emitOpError(
"incorrect number of operands");
1677 for (
auto idx : getSrcIndices()) {
1678 if (!idx.getType().isIndex())
1679 return emitOpError(
"src index to dma_start must have 'index' type");
1682 "src index must be a valid dimension or symbol identifier");
1684 for (
auto idx : getDstIndices()) {
1685 if (!idx.getType().isIndex())
1686 return emitOpError(
"dst index to dma_start must have 'index' type");
1689 "dst index must be a valid dimension or symbol identifier");
1691 for (
auto idx : getTagIndices()) {
1692 if (!idx.getType().isIndex())
1693 return emitOpError(
"tag index to dma_start must have 'index' type");
1696 "tag index must be a valid dimension or symbol identifier");
1707 void AffineDmaStartOp::getEffects(
1733 p <<
" " << getTagMemRef() <<
'[';
1738 p <<
" : " << getTagMemRef().getType();
1749 AffineMapAttr tagMapAttr;
1758 getTagMapAttrStrName(),
1767 if (!llvm::isa<MemRefType>(type))
1769 "expected tag to be of memref type");
1771 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1773 "tag memref operand count != to map.numInputs");
1777 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1778 if (!llvm::isa<MemRefType>(getOperand(0).
getType()))
1779 return emitOpError(
"expected DMA tag to be of memref type");
1781 for (
auto idx : getTagIndices()) {
1782 if (!idx.getType().isIndex())
1783 return emitOpError(
"index to dma_wait must have 'index' type");
1786 "index must be a valid dimension or symbol identifier");
1797 void AffineDmaWaitOp::getEffects(
1813 ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1814 assert(((!lbMap && lbOperands.empty()) ||
1816 "lower bound operand count does not match the affine map");
1817 assert(((!ubMap && ubOperands.empty()) ||
1819 "upper bound operand count does not match the affine map");
1820 assert(step > 0 &&
"step has to be a positive integer constant");
1826 getOperandSegmentSizeAttr(),
1828 static_cast<int32_t>(ubOperands.size()),
1829 static_cast<int32_t>(iterArgs.size())}));
1831 for (
Value val : iterArgs)
1853 Value inductionVar =
1855 for (
Value val : iterArgs)
1856 bodyBlock->
addArgument(val.getType(), val.getLoc());
1861 if (iterArgs.empty() && !bodyBuilder) {
1862 ensureTerminator(*bodyRegion, builder, result.
location);
1863 }
else if (bodyBuilder) {
1866 bodyBuilder(builder, result.
location, inductionVar,
1872 int64_t ub, int64_t step,
ValueRange iterArgs,
1873 BodyBuilderFn bodyBuilder) {
1876 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1880 LogicalResult AffineForOp::verifyRegions() {
1883 auto *body = getBody();
1884 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1885 return emitOpError(
"expected body to have a single index argument for the "
1886 "induction variable");
1890 if (getLowerBoundMap().getNumInputs() > 0)
1892 getLowerBoundMap().getNumDims())))
1895 if (getUpperBoundMap().getNumInputs() > 0)
1897 getUpperBoundMap().getNumDims())))
1900 unsigned opNumResults = getNumResults();
1901 if (opNumResults == 0)
1907 if (getNumIterOperands() != opNumResults)
1909 "mismatch between the number of loop-carried values and results");
1910 if (getNumRegionIterArgs() != opNumResults)
1912 "mismatch between the number of basic block args and results");
1922 bool failedToParsedMinMax =
1926 auto boundAttrStrName =
1927 isLower ? AffineForOp::getLowerBoundMapAttrName(result.
name)
1928 : AffineForOp::getUpperBoundMapAttrName(result.
name);
1935 if (!boundOpInfos.empty()) {
1937 if (boundOpInfos.size() > 1)
1939 "expected only one loop bound operand");
1964 if (
auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
1965 unsigned currentNumOperands = result.
operands.size();
1970 auto map = affineMapAttr.getValue();
1974 "dim operand count and affine map dim count must match");
1976 unsigned numDimAndSymbolOperands =
1977 result.
operands.size() - currentNumOperands;
1978 if (numDims + map.
getNumSymbols() != numDimAndSymbolOperands)
1981 "symbol operand count and affine map symbol count must match");
1987 return p.
emitError(attrLoc,
"lower loop bound affine map with "
1988 "multiple results requires 'max' prefix");
1990 return p.
emitError(attrLoc,
"upper loop bound affine map with multiple "
1991 "results requires 'min' prefix");
1997 if (
auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
2007 "expected valid affine map representation for loop bounds");
2019 int64_t numOperands = result.
operands.size();
2022 int64_t numLbOperands = result.
operands.size() - numOperands;
2025 numOperands = result.
operands.size();
2028 int64_t numUbOperands = result.
operands.size() - numOperands;
2033 getStepAttrName(result.
name),
2037 IntegerAttr stepAttr;
2039 getStepAttrName(result.
name).data(),
2043 if (stepAttr.getValue().isNegative())
2046 "expected step to be representable as a positive signed integer");
2054 regionArgs.push_back(inductionVariable);
2062 for (
auto argOperandType :
2063 llvm::zip(llvm::drop_begin(regionArgs), operands, result.
types)) {
2064 Type type = std::get<2>(argOperandType);
2065 std::get<0>(argOperandType).type = type;
2073 getOperandSegmentSizeAttr(),
2075 static_cast<int32_t>(numUbOperands),
2076 static_cast<int32_t>(operands.size())}));
2080 if (regionArgs.size() != result.
types.size() + 1)
2083 "mismatch between the number of loop-carried values and results");
2087 AffineForOp::ensureTerminator(*body, builder, result.
location);
2109 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2110 p << constExpr.getValue();
2118 if (dyn_cast<AffineSymbolExpr>(expr)) {
2134 unsigned AffineForOp::getNumIterOperands() {
2135 AffineMap lbMap = getLowerBoundMapAttr().getValue();
2136 AffineMap ubMap = getUpperBoundMapAttr().getValue();
2141 std::optional<MutableArrayRef<OpOperand>>
2142 AffineForOp::getYieldedValuesMutable() {
2143 return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2155 if (getStepAsInt() != 1)
2156 p <<
" step " << getStepAsInt();
2158 bool printBlockTerminators =
false;
2159 if (getNumIterOperands() > 0) {
2161 auto regionArgs = getRegionIterArgs();
2162 auto operands = getInits();
2164 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](
auto it) {
2165 p << std::get<0>(it) <<
" = " << std::get<1>(it);
2167 p <<
") -> (" << getResultTypes() <<
")";
2168 printBlockTerminators =
true;
2173 printBlockTerminators);
2175 (*this)->getAttrs(),
2176 {getLowerBoundMapAttrName(getOperation()->getName()),
2177 getUpperBoundMapAttrName(getOperation()->getName()),
2178 getStepAttrName(getOperation()->getName()),
2179 getOperandSegmentSizeAttr()});
2184 auto foldLowerOrUpperBound = [&forOp](
bool lower) {
2188 auto boundOperands =
2189 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2190 for (
auto operand : boundOperands) {
2193 operandConstants.push_back(operandCst);
2197 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2199 "bound maps should have at least one result");
2201 if (failed(boundMap.
constantFold(operandConstants, foldedResults)))
2205 assert(!foldedResults.empty() &&
"bounds should have at least one result");
2206 auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2207 for (
unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2208 auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2209 maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2210 : llvm::APIntOps::smin(maxOrMin, foldedResult);
2212 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2213 : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2218 bool folded =
false;
2219 if (!forOp.hasConstantLowerBound())
2220 folded |= succeeded(foldLowerOrUpperBound(
true));
2223 if (!forOp.hasConstantUpperBound())
2224 folded |= succeeded(foldLowerOrUpperBound(
false));
2225 return success(folded);
2233 auto lbMap = forOp.getLowerBoundMap();
2234 auto ubMap = forOp.getUpperBoundMap();
2235 auto prevLbMap = lbMap;
2236 auto prevUbMap = ubMap;
2249 if (lbMap == prevLbMap && ubMap == prevUbMap)
2252 if (lbMap != prevLbMap)
2253 forOp.setLowerBound(lbOperands, lbMap);
2254 if (ubMap != prevUbMap)
2255 forOp.setUpperBound(ubOperands, ubMap);
2261 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2262 int64_t step = forOp.getStepAsInt();
2263 if (!forOp.hasConstantBounds() || step <= 0)
2264 return std::nullopt;
2265 int64_t lb = forOp.getConstantLowerBound();
2266 int64_t ub = forOp.getConstantUpperBound();
2267 return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2275 LogicalResult matchAndRewrite(AffineForOp forOp,
2278 if (!llvm::hasSingleElement(*forOp.getBody()))
2280 if (forOp.getNumResults() == 0)
2282 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2283 if (tripCount && *tripCount == 0) {
2286 rewriter.
replaceOp(forOp, forOp.getInits());
2290 auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2291 auto iterArgs = forOp.getRegionIterArgs();
2292 bool hasValDefinedOutsideLoop =
false;
2293 bool iterArgsNotInOrder =
false;
2294 for (
unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2295 Value val = yieldOp.getOperand(i);
2296 auto *iterArgIt = llvm::find(iterArgs, val);
2297 if (iterArgIt == iterArgs.end()) {
2299 assert(forOp.isDefinedOutsideOfLoop(val) &&
2300 "must be defined outside of the loop");
2301 hasValDefinedOutsideLoop =
true;
2302 replacements.push_back(val);
2304 unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2306 iterArgsNotInOrder =
true;
2307 replacements.push_back(forOp.getInits()[pos]);
2312 if (!tripCount.has_value() &&
2313 (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2317 if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2319 rewriter.
replaceOp(forOp, replacements);
2327 results.
add<AffineForEmptyLoopFolder>(context);
2331 assert((point.
isParent() || point == getRegion()) &&
"invalid region point");
2338 void AffineForOp::getSuccessorRegions(
2340 assert((point.
isParent() || point == getRegion()) &&
"expected loop region");
2345 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*
this);
2346 if (point.
isParent() && tripCount.has_value()) {
2347 if (tripCount.value() > 0) {
2348 regions.push_back(
RegionSuccessor(&getRegion(), getRegionIterArgs()));
2351 if (tripCount.value() == 0) {
2359 if (!point.
isParent() && tripCount && *tripCount == 1) {
2366 regions.push_back(
RegionSuccessor(&getRegion(), getRegionIterArgs()));
2372 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
2373 return tripCount && *tripCount == 0;
2376 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2386 results.assign(getInits().begin(), getInits().end());
2389 return success(folded);
2402 assert(map.
getNumResults() >= 1 &&
"bound map has at least one result");
2403 getLowerBoundOperandsMutable().assign(lbOperands);
2404 setLowerBoundMap(map);
2409 assert(map.
getNumResults() >= 1 &&
"bound map has at least one result");
2410 getUpperBoundOperandsMutable().assign(ubOperands);
2411 setUpperBoundMap(map);
2414 bool AffineForOp::hasConstantLowerBound() {
2415 return getLowerBoundMap().isSingleConstant();
2418 bool AffineForOp::hasConstantUpperBound() {
2419 return getUpperBoundMap().isSingleConstant();
2422 int64_t AffineForOp::getConstantLowerBound() {
2423 return getLowerBoundMap().getSingleConstantResult();
2426 int64_t AffineForOp::getConstantUpperBound() {
2427 return getUpperBoundMap().getSingleConstantResult();
2430 void AffineForOp::setConstantLowerBound(int64_t value) {
2434 void AffineForOp::setConstantUpperBound(int64_t value) {
2438 AffineForOp::operand_range AffineForOp::getControlOperands() {
2443 bool AffineForOp::matchingBoundOperandList() {
2444 auto lbMap = getLowerBoundMap();
2445 auto ubMap = getUpperBoundMap();
2451 for (
unsigned i = 0, e = lbMap.
getNumInputs(); i < e; i++) {
2453 if (getOperand(i) != getOperand(numOperands + i))
2461 std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2465 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2466 if (!hasConstantLowerBound())
2467 return std::nullopt;
2470 OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2473 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2479 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2480 if (!hasConstantUpperBound())
2484 OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2487 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2489 bool replaceInitOperandUsesInLoop,
2494 auto inits = llvm::to_vector(getInits());
2495 inits.append(newInitOperands.begin(), newInitOperands.end());
2496 AffineForOp newLoop = rewriter.
create<AffineForOp>(
2501 auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2503 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2508 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2509 assert(newInitOperands.size() == newYieldedValues.size() &&
2510 "expected as many new yield values as new iter operands");
2512 yieldOp.getOperandsMutable().append(newYieldedValues);
2517 rewriter.
mergeBlocks(getBody(), newLoop.getBody(),
2518 newLoop.getBody()->getArguments().take_front(
2519 getBody()->getNumArguments()));
2521 if (replaceInitOperandUsesInLoop) {
2524 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
2535 newLoop->getResults().take_front(getNumResults()));
2536 return cast<LoopLikeOpInterface>(newLoop.getOperation());
2564 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2565 if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2566 return AffineForOp();
2568 ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2570 return forOp.getInductionVar() == val ? forOp : AffineForOp();
2571 return AffineForOp();
2575 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2576 if (!ivArg || !ivArg.getOwner())
2579 auto parallelOp = dyn_cast<AffineParallelOp>(containingOp);
2580 if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2589 ivs->reserve(forInsts.size());
2590 for (
auto forInst : forInsts)
2591 ivs->push_back(forInst.getInductionVar());
2596 ivs.reserve(affineOps.size());
2599 if (
auto forOp = dyn_cast<AffineForOp>(op))
2600 ivs.push_back(forOp.getInductionVar());
2601 else if (
auto parallelOp = dyn_cast<AffineParallelOp>(op))
2602 for (
size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2603 ivs.push_back(parallelOp.getBody()->getArgument(i));
2609 template <
typename BoundListTy,
typename LoopCreatorTy>
2614 LoopCreatorTy &&loopCreatorFn) {
2615 assert(lbs.size() == ubs.size() &&
"Mismatch in number of arguments");
2616 assert(lbs.size() == steps.size() &&
"Mismatch in number of arguments");
2628 ivs.reserve(lbs.size());
2629 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
2635 if (i == e - 1 && bodyBuilderFn) {
2637 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2639 nestedBuilder.
create<AffineYieldOp>(nestedLoc);
2644 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
2652 int64_t ub, int64_t step,
2653 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2654 return builder.
create<AffineForOp>(loc, lb, ub, step,
2655 std::nullopt, bodyBuilderFn);
2662 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2665 if (lbConst && ubConst)
2667 ubConst.value(), step, bodyBuilderFn);
2670 std::nullopt, bodyBuilderFn);
2698 LogicalResult matchAndRewrite(AffineIfOp ifOp,
2700 if (ifOp.getElseRegion().empty() ||
2701 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2716 LogicalResult matchAndRewrite(AffineIfOp op,
2719 auto isTriviallyFalse = [](
IntegerSet iSet) {
2720 return iSet.isEmptyIntegerSet();
2724 return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2725 iSet.getConstraint(0) == 0);
2728 IntegerSet affineIfConditions = op.getIntegerSet();
2730 if (isTriviallyFalse(affineIfConditions)) {
2740 blockToMove = op.getElseBlock();
2741 }
else if (isTriviallyTrue(affineIfConditions)) {
2742 blockToMove = op.getThenBlock();
2760 rewriter.
eraseOp(blockToMoveTerminator);
2768 void AffineIfOp::getSuccessorRegions(
2777 if (getElseRegion().empty()) {
2778 regions.push_back(getResults());
2794 auto conditionAttr =
2795 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2797 return emitOpError(
"requires an integer set attribute named 'condition'");
2800 IntegerSet condition = conditionAttr.getValue();
2802 return emitOpError(
"operand count and condition integer set dimension and "
2803 "symbol count must match");
2815 IntegerSetAttr conditionAttr;
2818 AffineIfOp::getConditionAttrStrName(),
2824 auto set = conditionAttr.getValue();
2825 if (set.getNumDims() != numDims)
2828 "dim operand count and integer set dim count must match");
2829 if (numDims + set.getNumSymbols() != result.
operands.size())
2832 "symbol operand count and integer set symbol count must match");
2846 AffineIfOp::ensureTerminator(*thenRegion, parser.
getBuilder(),
2853 AffineIfOp::ensureTerminator(*elseRegion, parser.
getBuilder(),
2865 auto conditionAttr =
2866 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2867 p <<
" " << conditionAttr;
2869 conditionAttr.getValue().getNumDims(), p);
2876 auto &elseRegion = this->getElseRegion();
2877 if (!elseRegion.
empty()) {
2886 getConditionAttrStrName());
2891 ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
2895 void AffineIfOp::setIntegerSet(
IntegerSet newSet) {
2901 (*this)->setOperands(operands);
2906 bool withElseRegion) {
2907 assert(resultTypes.empty() || withElseRegion);
2916 if (resultTypes.empty())
2917 AffineIfOp::ensureTerminator(*thenRegion, builder, result.
location);
2920 if (withElseRegion) {
2922 if (resultTypes.empty())
2923 AffineIfOp::ensureTerminator(*elseRegion, builder, result.
location);
2929 AffineIfOp::build(builder, result, {}, set, args,
2944 if (llvm::none_of(operands,
2955 auto set = getIntegerSet();
2961 if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
2964 setConditional(set, operands);
2970 results.
add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
2979 assert(operands.size() == 1 + map.
getNumInputs() &&
"inconsistent operands");
2983 auto memrefType = llvm::cast<MemRefType>(operands[0].
getType());
2984 result.
types.push_back(memrefType.getElementType());
2989 assert(map.
getNumInputs() == mapOperands.size() &&
"inconsistent index info");
2992 auto memrefType = llvm::cast<MemRefType>(memref.
getType());
2994 result.
types.push_back(memrefType.getElementType());
2999 auto memrefType = llvm::cast<MemRefType>(memref.
getType());
3000 int64_t rank = memrefType.getRank();
3005 build(builder, result, memref, map, indices);
3014 AffineMapAttr mapAttr;
3019 AffineLoadOp::getMapAttrStrName(),
3029 p <<
" " << getMemRef() <<
'[';
3030 if (AffineMapAttr mapAttr =
3031 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3035 {getMapAttrStrName()});
3041 static LogicalResult
3044 MemRefType memrefType,
unsigned numIndexOperands) {
3047 return op->
emitOpError(
"affine map num results must equal memref rank");
3049 return op->
emitOpError(
"expects as many subscripts as affine map inputs");
3052 for (
auto idx : mapOperands) {
3053 if (!idx.getType().isIndex())
3054 return op->
emitOpError(
"index to load must have 'index' type");
3057 "index must be a valid dimension or symbol identifier");
3065 if (
getType() != memrefType.getElementType())
3066 return emitOpError(
"result type must match element type of memref");
3070 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3071 getMapOperands(), memrefType,
3072 getNumOperands() - 1)))
3080 results.
add<SimplifyAffineOp<AffineLoadOp>>(context);
3089 auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3096 auto global = dyn_cast_or_null<memref::GlobalOp>(
3103 llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3107 if (
auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
3108 return splatAttr.getSplatValue<
Attribute>();
3110 if (!getAffineMap().isConstant())
3112 auto indices = llvm::to_vector<4>(
3113 llvm::map_range(getAffineMap().getConstantResults(),
3114 [](int64_t v) -> uint64_t {
return v; }));
3115 return cstAttr.getValues<
Attribute>()[indices];
3125 assert(map.
getNumInputs() == mapOperands.size() &&
"inconsistent index info");
3136 auto memrefType = llvm::cast<MemRefType>(memref.
getType());
3137 int64_t rank = memrefType.getRank();
3142 build(builder, result, valueToStore, memref, map, indices);
3151 AffineMapAttr mapAttr;
3156 mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3167 p <<
" " << getValueToStore();
3168 p <<
", " << getMemRef() <<
'[';
3169 if (AffineMapAttr mapAttr =
3170 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3174 {getMapAttrStrName()});
3181 if (getValueToStore().
getType() != memrefType.getElementType())
3183 "value to store must have the same type as memref element type");
3187 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3188 getMapOperands(), memrefType,
3189 getNumOperands() - 2)))
3197 results.
add<SimplifyAffineOp<AffineStoreOp>>(context);
3200 LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3210 template <
typename T>
3214 op.getMap().getNumDims() + op.getMap().getNumSymbols())
3216 "operand count and affine map dimension and symbol count must match");
3219 return op.
emitOpError(
"affine map expect at least one result");
3223 template <
typename T>
3225 p <<
' ' << op->
getAttr(T::getMapAttrStrName());
3227 unsigned numDims = op.getMap().getNumDims();
3228 p <<
'(' << operands.take_front(numDims) <<
')';
3230 if (operands.size() != numDims)
3231 p <<
'[' << operands.drop_front(numDims) <<
']';
3233 {T::getMapAttrStrName()});
3236 template <
typename T>
3243 AffineMapAttr mapAttr;
3259 template <
typename T>
3261 static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3262 "expected affine min or max op");
3268 auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3270 if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3274 if (results.empty()) {
3276 if (foldedMap == op.getMap())
3283 auto resultIt = std::is_same<T, AffineMinOp>::value
3284 ? llvm::min_element(results)
3285 : llvm::max_element(results);
3286 if (resultIt == results.end())
3292 template <
typename T>
3298 AffineMap oldMap = affineOp.getAffineMap();
3304 if (!llvm::is_contained(newExprs, expr))
3305 newExprs.push_back(expr);
3335 template <
typename T>
3341 AffineMap oldMap = affineOp.getAffineMap();
3343 affineOp.getMapOperands().take_front(oldMap.
getNumDims());
3345 affineOp.getMapOperands().take_back(oldMap.
getNumSymbols());
3347 auto newDimOperands = llvm::to_vector<8>(dimOperands);
3348 auto newSymOperands = llvm::to_vector<8>(symOperands);
3356 if (
auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
3357 Value symValue = symOperands[symExpr.getPosition()];
3359 producerOps.push_back(producerOp);
3362 }
else if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
3363 Value dimValue = dimOperands[dimExpr.getPosition()];
3365 producerOps.push_back(producerOp);
3372 newExprs.push_back(expr);
3375 if (producerOps.empty())
3382 for (T producerOp : producerOps) {
3383 AffineMap producerMap = producerOp.getAffineMap();
3384 unsigned numProducerDims = producerMap.
getNumDims();
3389 producerOp.getMapOperands().take_front(numProducerDims);
3391 producerOp.getMapOperands().take_back(numProducerSyms);
3392 newDimOperands.append(dimValues.begin(), dimValues.end());
3393 newSymOperands.append(symValues.begin(), symValues.end());
3397 newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3398 .shiftSymbols(numProducerSyms, numUsedSyms));
3401 numUsedDims += numProducerDims;
3402 numUsedSyms += numProducerSyms;
3408 llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3427 if (!resultExpr.isPureAffine())
3432 if (failed(flattenResult))
3445 if (llvm::is_sorted(flattenedExprs))
3450 llvm::to_vector(llvm::seq<unsigned>(0, map.
getNumResults()));
3451 llvm::sort(resultPermutation, [&](
unsigned lhs,
unsigned rhs) {
3452 return flattenedExprs[lhs] < flattenedExprs[rhs];
3455 for (
unsigned idx : resultPermutation)
3476 template <
typename T>
3482 AffineMap map = affineOp.getAffineMap();
3490 template <
typename T>
3496 if (affineOp.getMap().getNumResults() != 1)
3499 affineOp.getOperands());
3527 return parseAffineMinMaxOp<AffineMinOp>(parser, result);
3555 return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
3574 IntegerAttr hintInfo;
3576 StringRef readOrWrite, cacheType;
3578 AffineMapAttr mapAttr;
3582 AffinePrefetchOp::getMapAttrStrName(),
3588 AffinePrefetchOp::getLocalityHintAttrStrName(),
3598 if (readOrWrite !=
"read" && readOrWrite !=
"write")
3600 "rw specifier has to be 'read' or 'write'");
3601 result.
addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
3604 if (cacheType !=
"data" && cacheType !=
"instr")
3606 "cache type has to be 'data' or 'instr'");
3608 result.
addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
3615 p <<
" " << getMemref() <<
'[';
3616 AffineMapAttr mapAttr =
3617 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3620 p <<
']' <<
", " << (getIsWrite() ?
"write" :
"read") <<
", "
3621 <<
"locality<" << getLocalityHint() <<
">, "
3622 << (getIsDataCache() ?
"data" :
"instr");
3624 (*this)->getAttrs(),
3625 {getMapAttrStrName(), getLocalityHintAttrStrName(),
3626 getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3631 auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3635 return emitOpError(
"affine.prefetch affine map num results must equal"
3638 return emitOpError(
"too few operands");
3640 if (getNumOperands() != 1)
3641 return emitOpError(
"too few operands");
3645 for (
auto idx : getMapOperands()) {
3648 "index must be a valid dimension or symbol identifier");
3656 results.
add<SimplifyAffineOp<AffinePrefetchOp>>(context);
3659 LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
3674 auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3678 build(builder, result, resultTypes, reductions, lbs, {}, ubs,
3688 assert(llvm::all_of(lbMaps,
3690 return m.getNumDims() == lbMaps[0].getNumDims() &&
3691 m.getNumSymbols() == lbMaps[0].getNumSymbols();
3693 "expected all lower bounds maps to have the same number of dimensions "
3695 assert(llvm::all_of(ubMaps,
3697 return m.getNumDims() == ubMaps[0].getNumDims() &&
3698 m.getNumSymbols() == ubMaps[0].getNumSymbols();
3700 "expected all upper bounds maps to have the same number of dimensions "
3702 assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3703 "expected lower bound maps to have as many inputs as lower bound "
3705 assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3706 "expected upper bound maps to have as many inputs as upper bound "
3714 for (arith::AtomicRMWKind reduction : reductions)
3715 reductionAttrs.push_back(
3727 groups.reserve(groups.size() + maps.size());
3728 exprs.reserve(maps.size());
3730 llvm::append_range(exprs, m.getResults());
3731 groups.push_back(m.getNumResults());
3733 return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3739 AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3740 AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
3758 for (
unsigned i = 0, e = steps.size(); i < e; ++i)
3760 if (resultTypes.empty())
3761 ensureTerminator(*bodyRegion, builder, result.
location);
3765 return {&getRegion()};
3768 unsigned AffineParallelOp::getNumDims() {
return getSteps().size(); }
3770 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
3771 return getOperands().take_front(getLowerBoundsMap().getNumInputs());
3774 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
3775 return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
3778 AffineMap AffineParallelOp::getLowerBoundMap(
unsigned pos) {
3779 auto values = getLowerBoundsGroups().getValues<int32_t>();
3781 for (
unsigned i = 0; i < pos; ++i)
3783 return getLowerBoundsMap().getSliceMap(start, values[pos]);
3786 AffineMap AffineParallelOp::getUpperBoundMap(
unsigned pos) {
3787 auto values = getUpperBoundsGroups().getValues<int32_t>();
3789 for (
unsigned i = 0; i < pos; ++i)
3791 return getUpperBoundsMap().getSliceMap(start, values[pos]);
3795 return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
3799 return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
3802 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3803 if (hasMinMaxBounds())
3804 return std::nullopt;
3809 AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
3812 for (
unsigned i = 0, e = rangesValueMap.
getNumResults(); i < e; ++i) {
3813 auto expr = rangesValueMap.
getResult(i);
3814 auto cst = dyn_cast<AffineConstantExpr>(expr);
3816 return std::nullopt;
3817 out.push_back(cst.getValue());
3822 Block *AffineParallelOp::getBody() {
return &getRegion().
front(); }
3824 OpBuilder AffineParallelOp::getBodyBuilder() {
3825 return OpBuilder(getBody(), std::prev(getBody()->end()));
3830 "operands to map must match number of inputs");
3832 auto ubOperands = getUpperBoundsOperands();
3835 newOperands.append(ubOperands.begin(), ubOperands.end());
3836 (*this)->setOperands(newOperands);
3843 "operands to map must match number of inputs");
3846 newOperands.append(ubOperands.begin(), ubOperands.end());
3847 (*this)->setOperands(newOperands);
3853 setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
3858 arith::AtomicRMWKind op) {
3860 case arith::AtomicRMWKind::addf:
3861 return isa<FloatType>(resultType);
3862 case arith::AtomicRMWKind::addi:
3863 return isa<IntegerType>(resultType);
3864 case arith::AtomicRMWKind::assign:
3866 case arith::AtomicRMWKind::mulf:
3867 return isa<FloatType>(resultType);
3868 case arith::AtomicRMWKind::muli:
3869 return isa<IntegerType>(resultType);
3870 case arith::AtomicRMWKind::maximumf:
3871 return isa<FloatType>(resultType);
3872 case arith::AtomicRMWKind::minimumf:
3873 return isa<FloatType>(resultType);
3874 case arith::AtomicRMWKind::maxs: {
3875 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3876 return intType && intType.isSigned();
3878 case arith::AtomicRMWKind::mins: {
3879 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3880 return intType && intType.isSigned();
3882 case arith::AtomicRMWKind::maxu: {
3883 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3884 return intType && intType.isUnsigned();
3886 case arith::AtomicRMWKind::minu: {
3887 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3888 return intType && intType.isUnsigned();
3890 case arith::AtomicRMWKind::ori:
3891 return isa<IntegerType>(resultType);
3892 case arith::AtomicRMWKind::andi:
3893 return isa<IntegerType>(resultType);
3900 auto numDims = getNumDims();
3903 getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
3904 return emitOpError() <<
"the number of region arguments ("
3905 << getBody()->getNumArguments()
3906 <<
") and the number of map groups for lower ("
3907 << getLowerBoundsGroups().getNumElements()
3908 <<
") and upper bound ("
3909 << getUpperBoundsGroups().getNumElements()
3910 <<
"), and the number of steps (" << getSteps().size()
3911 <<
") must all match";
3914 unsigned expectedNumLBResults = 0;
3915 for (APInt v : getLowerBoundsGroups())
3916 expectedNumLBResults += v.getZExtValue();
3917 if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
3918 return emitOpError() <<
"expected lower bounds map to have "
3919 << expectedNumLBResults <<
" results";
3920 unsigned expectedNumUBResults = 0;
3921 for (APInt v : getUpperBoundsGroups())
3922 expectedNumUBResults += v.getZExtValue();
3923 if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
3924 return emitOpError() <<
"expected upper bounds map to have "
3925 << expectedNumUBResults <<
" results";
3927 if (getReductions().size() != getNumResults())
3928 return emitOpError(
"a reduction must be specified for each output");
3934 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
3935 if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
3936 return emitOpError(
"invalid reduction attribute");
3937 auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
3939 return emitOpError(
"result type cannot match reduction attribute");
3945 getLowerBoundsMap().getNumDims())))
3949 getUpperBoundsMap().getNumDims())))
3954 LogicalResult AffineValueMap::canonicalize() {
3956 auto newMap = getAffineMap();
3958 if (newMap == getAffineMap() && newOperands == operands)
3960 reset(newMap, newOperands);
3973 if (!lbCanonicalized && !ubCanonicalized)
3976 if (lbCanonicalized)
3978 if (ubCanonicalized)
3984 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
3996 StringRef keyword) {
3999 ValueRange dimOperands = operands.take_front(numDims);
4000 ValueRange symOperands = operands.drop_front(numDims);
4002 for (llvm::APInt groupSize : group) {
4006 unsigned size = groupSize.getZExtValue();
4011 p << keyword <<
'(';
4021 p <<
" (" << getBody()->getArguments() <<
") = (";
4023 getLowerBoundsOperands(),
"max");
4026 getUpperBoundsOperands(),
"min");
4029 bool elideSteps = llvm::all_of(steps, [](int64_t step) {
return step == 1; });
4032 llvm::interleaveComma(steps, p);
4035 if (getNumResults()) {
4037 llvm::interleaveComma(getReductions(), p, [&](
auto &attr) {
4038 arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4039 llvm::cast<IntegerAttr>(attr).getInt());
4040 p <<
"\"" << arith::stringifyAtomicRMWKind(sym) <<
"\"";
4042 p <<
") -> (" << getResultTypes() <<
")";
4049 (*this)->getAttrs(),
4050 {AffineParallelOp::getReductionsAttrStrName(),
4051 AffineParallelOp::getLowerBoundsMapAttrStrName(),
4052 AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4053 AffineParallelOp::getUpperBoundsMapAttrStrName(),
4054 AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4055 AffineParallelOp::getStepsAttrStrName()});
4068 "expected operands to be dim or symbol expression");
4071 for (
const auto &list : operands) {
4075 for (
Value operand : valueOperands) {
4076 unsigned pos = std::distance(uniqueOperands.begin(),
4077 llvm::find(uniqueOperands, operand));
4078 if (pos == uniqueOperands.size())
4079 uniqueOperands.push_back(operand);
4080 replacements.push_back(
4090 enum class MinMaxKind { Min, Max };
4114 const llvm::StringLiteral tmpAttrStrName =
"__pseudo_bound_map";
4116 StringRef mapName = kind == MinMaxKind::Min
4117 ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4118 : AffineParallelOp::getLowerBoundsMapAttrStrName();
4119 StringRef groupsName =
4120 kind == MinMaxKind::Min
4121 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4122 : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4139 auto parseOperands = [&]() {
4141 kind == MinMaxKind::Min ?
"min" :
"max"))) {
4142 mapOperands.clear();
4149 llvm::append_range(flatExprs, map.getValue().getResults());
4151 auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4153 auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4155 flatDimOperands.append(map.getValue().getNumResults(), dims);
4156 flatSymOperands.append(map.getValue().getNumResults(), syms);
4157 numMapsPerGroup.push_back(map.getValue().getNumResults());
4160 flatSymOperands.emplace_back(),
4161 flatExprs.emplace_back())))
4163 numMapsPerGroup.push_back(1);
4170 unsigned totalNumDims = 0;
4171 unsigned totalNumSyms = 0;
4172 for (
unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4173 unsigned numDims = flatDimOperands[i].size();
4174 unsigned numSyms = flatSymOperands[i].size();
4175 flatExprs[i] = flatExprs[i]
4176 .shiftDims(numDims, totalNumDims)
4177 .shiftSymbols(numSyms, totalNumSyms);
4178 totalNumDims += numDims;
4179 totalNumSyms += numSyms;
4191 result.
operands.append(dimOperands.begin(), dimOperands.end());
4192 result.
operands.append(symOperands.begin(), symOperands.end());
4195 auto flatMap =
AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4197 flatMap = flatMap.replaceDimsAndSymbols(
4198 dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4222 AffineMapAttr stepsMapAttr;
4227 result.
addAttribute(AffineParallelOp::getStepsAttrStrName(),
4231 AffineParallelOp::getStepsAttrStrName(),
4238 auto stepsMap = stepsMapAttr.getValue();
4239 for (
const auto &result : stepsMap.getResults()) {
4240 auto constExpr = dyn_cast<AffineConstantExpr>(result);
4243 "steps must be constant integers");
4244 steps.push_back(constExpr.getValue());
4246 result.
addAttribute(AffineParallelOp::getStepsAttrStrName(),
4256 auto parseAttributes = [&]() -> ParseResult {
4266 std::optional<arith::AtomicRMWKind> reduction =
4267 arith::symbolizeAtomicRMWKind(attrVal.getValue());
4269 return parser.
emitError(loc,
"invalid reduction value: ") << attrVal;
4270 reductions.push_back(
4278 result.
addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4287 for (
auto &iv : ivs)
4288 iv.type = indexType;
4294 AffineParallelOp::ensureTerminator(*body, builder, result.
location);
4303 auto *parentOp = (*this)->getParentOp();
4304 auto results = parentOp->getResults();
4305 auto operands = getOperands();
4307 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4308 return emitOpError() <<
"only terminates affine.if/for/parallel regions";
4309 if (parentOp->getNumResults() != getNumOperands())
4310 return emitOpError() <<
"parent of yield must have same number of "
4311 "results as the yield operands";
4312 for (
auto it : llvm::zip(results, operands)) {
4314 return emitOpError() <<
"types mismatch between yield op and its parent";
4327 assert(operands.size() == 1 + map.
getNumInputs() &&
"inconsistent operands");
4331 result.
types.push_back(resultType);
4335 VectorType resultType,
Value memref,
4337 assert(map.
getNumInputs() == mapOperands.size() &&
"inconsistent index info");
4341 result.
types.push_back(resultType);
4345 VectorType resultType,
Value memref,
4347 auto memrefType = llvm::cast<MemRefType>(memref.
getType());
4348 int64_t rank = memrefType.getRank();
4353 build(builder, result, resultType, memref, map, indices);
4356 void AffineVectorLoadOp::getCanonicalizationPatterns(
RewritePatternSet &results,
4358 results.
add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4366 MemRefType memrefType;
4367 VectorType resultType;
4369 AffineMapAttr mapAttr;
4374 AffineVectorLoadOp::getMapAttrStrName(),
4385 p <<
" " << getMemRef() <<
'[';
4386 if (AffineMapAttr mapAttr =
4387 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4391 {getMapAttrStrName()});
4397 VectorType vectorType) {
4399 if (memrefType.getElementType() != vectorType.getElementType())
4401 "requires memref and vector types of the same elemental type");
4409 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4410 getMapOperands(), memrefType,
4411 getNumOperands() - 1)))
4427 assert(map.
getNumInputs() == mapOperands.size() &&
"inconsistent index info");
4438 auto memrefType = llvm::cast<MemRefType>(memref.
getType());
4439 int64_t rank = memrefType.getRank();
4444 build(builder, result, valueToStore, memref, map, indices);
4446 void AffineVectorStoreOp::getCanonicalizationPatterns(
4448 results.
add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4455 MemRefType memrefType;
4456 VectorType resultType;
4459 AffineMapAttr mapAttr;
4465 AffineVectorStoreOp::getMapAttrStrName(),
4476 p <<
" " << getValueToStore();
4477 p <<
", " << getMemRef() <<
'[';
4478 if (AffineMapAttr mapAttr =
4479 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4483 {getMapAttrStrName()});
4484 p <<
" : " <<
getMemRefType() <<
", " << getValueToStore().getType();
4490 *
this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4491 getMapOperands(), memrefType,
4492 getNumOperands() - 2)))
4505 LogicalResult AffineDelinearizeIndexOp::inferReturnTypes(
4506 MLIRContext *context, std::optional<::mlir::Location> location,
4509 AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
4511 inferredReturnTypes.assign(adaptor.getBasis().size(),
4524 if (staticDim.has_value())
4527 return llvm::dyn_cast_if_present<Value>(ofr);
4533 if (getBasis().empty())
4534 return emitOpError(
"basis should not be empty");
4535 if (getNumResults() != getBasis().size())
4536 return emitOpError(
"should return an index for each basis element");
4543 struct DropUnitExtentBasis
4547 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4550 std::optional<Value> zero = std::nullopt;
4551 Location loc = delinearizeOp->getLoc();
4554 zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
4555 return zero.value();
4561 for (
auto [index, basis] :
llvm::enumerate(delinearizeOp.getBasis())) {
4563 replacements[index] =
getZero();
4565 newOperands.push_back(basis);
4568 if (newOperands.size() == delinearizeOp.getBasis().size())
4571 if (!newOperands.empty()) {
4572 auto newDelinearizeOp = rewriter.
create<affine::AffineDelinearizeIndexOp>(
4573 loc, delinearizeOp.getLinearIndex(), newOperands);
4576 for (
auto &replacement : replacements) {
4579 replacement = newDelinearizeOp->
getResult(newIndex++);
4583 rewriter.
replaceOp(delinearizeOp, replacements);
4604 struct DropDelinearizeOfSingleLoop
4608 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4610 auto basis = delinearizeOp.getBasis();
4611 if (basis.size() != 1)
4615 auto inductionVar = dyn_cast<BlockArgument>(delinearizeOp.getLinearIndex());
4620 auto loopLikeOp = dyn_cast<LoopLikeOpInterface>(
4627 auto inductionVars = loopLikeOp.getLoopInductionVars();
4628 if (!inductionVars || inductionVars->size() != 1 ||
4629 inductionVars->front() != inductionVar) {
4631 delinearizeOp,
"`linear_index` is not loop induction variable");
4635 auto upperBounds = loopLikeOp.getLoopUpperBounds();
4636 if (!upperBounds || upperBounds->size() != 1 ||
4639 "`basis` is not upper bound");
4643 auto lowerBounds = loopLikeOp.getLoopLowerBounds();
4644 if (!lowerBounds || lowerBounds->size() != 1 ||
4647 "loop lower bound is not zero");
4651 auto steps = loopLikeOp.getLoopSteps();
4655 rewriter.
replaceOp(delinearizeOp, inductionVar);
4662 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
4664 patterns.
insert<DropDelinearizeOfSingleLoop, DropUnitExtentBasis>(context);
4671 #define GET_OP_CLASSES
4672 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds known to be constants.
static bool hasTrivialZeroTripCount(AffineForOp op)
Returns true if the affine.for has zero iterations in trivial cases.
static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands)
Composes the given affine map with the given list of operands, pulling in the maps from any affine....
static void printAffineMinMaxOp(OpAsmPrinter &p, T op)
static bool isResultTypeMatchAtomicRMWKind(Type resultType, arith::AtomicRMWKind op)
static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)
Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)
Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.
static void LLVM_ATTRIBUTE_UNUSED simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)
Simplify the map while exploiting information on the values in operands.
static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)
Fold an affine min or max operation with the given operands.
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)
Canonicalize the bounds of the given loop.
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)
Simplify expr while exploiting information from the values in operands.
static bool isValidAffineIndexOperand(Value value, Region *region)
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)
Parse a for operation loop bounds.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands)
Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...
static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms)
Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression...
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType)
Verify common invariants of affine.vector_load and affine.vector_store.
static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)
Simplify the expressions in map while making use of lower or upper bounds of its operands.
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)
Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)
Check if e is known to be: 0 <= e < k.
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, OperationState &result, MinMaxKind kind)
Parses an affine map that can contain a min/max for groups of its results, e.g., max(expr-1,...
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds that may or may not be constants.
static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)
Prints dimension and symbol list.
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)
Returns the largest known divisor of e.
static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)
Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.
static LogicalResult foldLoopBounds(AffineForOp forOp)
Fold the constant bounds of a loop.
static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)
Utility function to verify that a set of operands are valid dimension and symbol identifiers.
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)
Returns true if the result of the dim op is a valid symbol for region.
static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr &quotientTimesDiv, AffineExpr &rem)
Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.
static ParseResult deduplicateAndResolveOperands(OpAsmParser &parser, ArrayRef< SmallVector< OpAsmParser::UnresolvedOperand >> operands, SmallVectorImpl< Value > &uniqueOperands, SmallVectorImpl< AffineExpr > &replacements, AffineExprKind kind)
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult verifyAffineMinMaxOp(T op)
static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)
static LogicalResult verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)
Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine....
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)
Canonicalize the result expression order of an affine map and return success if the order changed.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
AffineExprKind getKind() const
Return the classification for this type.
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
MLIRContext * getContext() const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
MLIRContext * getContext() const
bool isFunctionOfDim(unsigned position) const
Return true if any affine expression involves AffineDimExpr position.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isFunctionOfSymbol(unsigned position) const
Return true if any affine expression involves AffineSymbolExpr position.
unsigned getNumResults() const
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
unsigned getNumInputs() const
AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
AffineExpr getResult(unsigned idx) const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
AffineMap getDimIdentityMap()
AffineMap getMultiDimIdentityMap(unsigned rank)
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
AffineMap getEmptyAffineMap()
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
MLIRContext * getContext() const
AffineMap getSymbolIdentityMap()
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense integer vector or tensor object.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
An integer set representing a conjunction of one or more affine equalities and inequalities.
unsigned getNumDims() const
static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)
MLIRContext * getContext() const
unsigned getNumInputs() const
ArrayRef< AffineExpr > getConstraints() const
ArrayRef< bool > getEqFlags() const
Returns the equality bits, which specify whether each of the constraints is an equality or inequality...
unsigned getNumSymbols() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void pop_back()
Pop last element from list.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0
Parses an affine map attribute where dims and symbols are SSA operands.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0
Parses an affine expression where dims and symbols are SSA operands.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0
Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
operand_range::iterator operand_iterator
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
std::vector< SmallVector< int64_t, 8 > > operandExprStack
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
AffineBound represents a lower or upper bound in the for operation.
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
LogicalResult canonicalize()
Attempts to canonicalize the map and operands.
ArrayRef< Value > getOperands() const
AffineExpr getResult(unsigned i)
AffineMap getAffineMap() const
unsigned getNumResults() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)
Extracts the induction variables from a list of AffineForOps and places them in the output argument i...
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
bool isTopLevelValue(Value value)
A utility function to check if a value is defined at the top level of an op with trait AffineScope or...
void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)
Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.
void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)
Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns the AffineParallelOp whose induction variables include the provided value.
Region * getAffineScope(Operation *op)
Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list.
bool isAffineParallelInductionVar(Value val)
Returns true if val is the induction variable of an AffineParallelOp.
AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0 or a ConstantIndexOp with attribute with value 0.
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t >> constLowerBounds, ArrayRef< std::optional< int64_t >> constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Canonicalize the affine map result expression order of an affine min/max operation.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Remove duplicated expressions in affine min/max ops.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Merge an affine min/max op to its consumers if its consumer is also an affine min/max op.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.