23 #include "llvm/ADT/ScopeExit.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 #include "llvm/ADT/SmallVectorExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/MathExtras.h"
35 using llvm::divideCeilSigned;
36 using llvm::divideFloorSigned;
39 #define DEBUG_TYPE "affine-ops"
41 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
if (auto arg = llvm::dyn_cast<BlockArgument>(value))
  return arg.getParentRegion() == region;
if (llvm::isa<BlockArgument>(value))
  return legalityCheck(mapping.lookup(value), dest);
bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
return llvm::all_of(values, [&](Value v) {
template <typename OpTy>
static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
                              AffineWriteOpInterface>::value,
              "only ops with affine read/write interface are supported");
110 dimOperands, src, dest, mapping,
114 symbolOperands, src, dest, mapping,
131 op.getMapOperands(), src, dest, mapping,
136 op.getMapOperands(), src, dest, mapping,
163 if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
168 if (!llvm::hasSingleElement(*src))
if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
  if (iface.hasNoEffect())
.Case<AffineApplyOp, AffineReadOpInterface,
      AffineWriteOpInterface>([&](auto op) {
211 isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
223 void AffineDialect::initialize() {
226 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
228 addInterfaces<AffineInlinerInterface>();
229 declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
  return builder.create<ub::PoisonOp>(loc, type, poison);
return arith::ConstantOp::materialize(builder, value, type, loc);
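// Editor's note (not in the original source): per the two paths above, the
// dialect materializes ub::PoisonAttr fold results as ub.poison ops and
// delegates everything else to arith::ConstantOp::materialize.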
if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
while (auto *parentOp = curOp->getParentOp()) {
287 auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
289 isa<AffineForOp, AffineParallelOp>(parentOp));
auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
return isa<AffineForOp, AffineParallelOp>(parentOp);
if (auto applyOp = dyn_cast<AffineApplyOp>(op))
  return applyOp.isValidDim(region);
if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
template <typename AnyMemRefDefOp>
330 MemRefType memRefType = memrefDefOp.getType();
333 if (index >= memRefType.getRank()) {
338 if (!memRefType.isDynamicDim(index))
341 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
342 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
354 if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
362 if (!index.has_value())
366 Operation *op = dimOp.getShapedValue().getDefiningOp();
while (auto castOp = dyn_cast<memref::CastOp>(op)) {
  if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
  op = castOp.getSource().getDefiningOp();
376 int64_t i = index.value();
378 .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
.Default([](Operation *) { return false; });
if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
  return applyOp.isValidSymbol(region);
if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
printer << '(' << operands.take_front(numDims) << ')';
if (operands.size() > numDims)
  printer << '[' << operands.drop_front(numDims) << ']';
486 numDims = opInfos.size();
template <typename OpTy>
for (auto operand : operands) {
  if (opIt++ < numDims) {
      return op.emitOpError("operand cannot be used as a dimension id");
    return op.emitOpError("operand cannot be used as a symbol");
521 return AffineValueMap(getAffineMap(), getOperands(), getResult());
528 AffineMapAttr mapAttr;
534 auto map = mapAttr.getValue();
if (map.getNumDims() != numDims ||
    numDims + map.getNumSymbols() != result.operands.size()) {
      "dimension or symbol index mismatch");
result.types.append(map.getNumResults(), indexTy);
p << " " << getMapAttr();
                          getAffineMap().getNumDims(), p);
560 "operand count and affine map dimension and symbol count must match");
return emitOpError("mapping must produce one value");
572 return llvm::all_of(getOperands(),
580 return llvm::all_of(getOperands(),
587 return llvm::all_of(getOperands(),
return llvm::all_of(getOperands(), [&](Value operand) {
600 auto map = getAffineMap();
603 auto expr = map.getResult(0);
if (auto dim = dyn_cast<AffineDimExpr>(expr))
  return getOperand(dim.getPosition());
if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
  return getOperand(map.getNumDims() + sym.getPosition());
bool hasPoison = false;
613 map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
616 if (failed(foldResult))
633 auto dimExpr = dyn_cast<AffineDimExpr>(e);
643 Value operand = operands[dimExpr.getPosition()];
644 int64_t operandDivisor = 1;
648 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
649 operandDivisor = forOp.getStepAsInt();
651 uint64_t lbLargestKnownDivisor =
652 forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
653 operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
656 return operandDivisor;
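// Editor's note (summary of the logic above): when the loop's constant lower
// bound is 0, the operand's divisor is the step itself; otherwise it is
// std::gcd(largest known divisor of the lower bound map results, step).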
if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
664 int64_t constVal = constExpr.getValue();
665 return constVal >= 0 && constVal < k;
667 auto dimExpr = dyn_cast<AffineDimExpr>(e);
670 Value operand = operands[dimExpr.getPosition()];
674 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
675 forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
691 auto bin = dyn_cast<AffineBinaryOpExpr>(e);
699 quotientTimesDiv = llhs;
705 quotientTimesDiv = rlhs;
715 if (forOp && forOp.hasConstantLowerBound())
716 return forOp.getConstantLowerBound();
723 if (!forOp || !forOp.hasConstantUpperBound())
728 if (forOp.hasConstantLowerBound()) {
729 return forOp.getConstantUpperBound() - 1 -
730 (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
731 forOp.getStepAsInt();
733 return forOp.getConstantUpperBound() - 1;
744 constLowerBounds.reserve(operands.size());
745 constUpperBounds.reserve(operands.size());
for (Value operand : operands) {
if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
  return constExpr.getValue();
767 constLowerBounds.reserve(operands.size());
768 constUpperBounds.reserve(operands.size());
for (Value operand : operands) {
std::optional<int64_t> lowerBound;
if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
  lowerBound = constExpr.getValue();
                         constLowerBounds, constUpperBounds,
790 auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
801 binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
809 lhs = binExpr.getLHS();
810 rhs = binExpr.getRHS();
811 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
815 int64_t rhsConstVal = rhsConst.getValue();
817 if (rhsConstVal <= 0)
822 std::optional<int64_t> lhsLbConst =
824 std::optional<int64_t> lhsUbConst =
826 if (lhsLbConst && lhsUbConst) {
827 int64_t lhsLbConstVal = *lhsLbConst;
828 int64_t lhsUbConstVal = *lhsUbConst;
832 divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
833 divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
835 divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
841 divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
842 divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
849 lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
  if (rhsConstVal % divisor == 0 &&
    expr = quotientTimesDiv.floorDiv(rhsConst);
  } else if (divisor % rhsConstVal == 0 &&
    expr = rem % rhsConst;
893 if (operands.empty())
899 constLowerBounds.reserve(operands.size());
900 constUpperBounds.reserve(operands.size());
for (Value operand : operands) {
if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
916 lowerBounds.push_back(constExpr.getValue());
917 upperBounds.push_back(constExpr.getValue());
919 lowerBounds.push_back(
921 constLowerBounds, constUpperBounds,
923 upperBounds.push_back(
925 constLowerBounds, constUpperBounds,
934 unsigned i = exprEn.index();
936 if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
941 if (!upperBounds[i]) {
942 irredundantExprs.push_back(e);
948 auto otherLowerBound = en.value();
949 unsigned pos = en.index();
950 if (pos == i || !otherLowerBound)
952 if (*otherLowerBound > *upperBounds[i])
954 if (*otherLowerBound < *upperBounds[i])
959 if (upperBounds[pos] && lowerBounds[i] &&
960 lowerBounds[i] == upperBounds[i] &&
961 otherLowerBound == *upperBounds[pos] && i < pos)
965 irredundantExprs.push_back(e);
967 if (!lowerBounds[i]) {
968 irredundantExprs.push_back(e);
973 auto otherUpperBound = en.value();
974 unsigned pos = en.index();
975 if (pos == i || !otherUpperBound)
977 if (*otherUpperBound < *lowerBounds[i])
979 if (*otherUpperBound > *lowerBounds[i])
981 if (lowerBounds[pos] && upperBounds[i] &&
982 lowerBounds[i] == upperBounds[i] &&
983 otherUpperBound == lowerBounds[pos] && i < pos)
987 irredundantExprs.push_back(e);
999 static void LLVM_ATTRIBUTE_UNUSED
assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1007 newResults.push_back(expr);
1024 unsigned dimOrSymbolPosition,
1028 bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1029 unsigned pos = isDimReplacement ? dimOrSymbolPosition
1030 : dimOrSymbolPosition - dims.size();
1031 Value &v = isDimReplacement ? dims[pos] : syms[pos];
1044 AffineMap composeMap = affineApply.getAffineMap();
assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
1047 affineApply.getMapOperands().end());
1061 dims.append(composeDims.begin(), composeDims.end());
1062 syms.append(composeSyms.begin(), composeSyms.end());
*map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
bool changed = false;
for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1104 unsigned nDims = 0, nSyms = 0;
1106 dimReplacements.reserve(dims.size());
1107 symReplacements.reserve(syms.size());
for (auto *container : {&dims, &syms}) {
1109 bool isDim = (container == &dims);
1110 auto &repls = isDim ? dimReplacements : symReplacements;
1112 Value v = en.value();
1116 "map is function of unexpected expr@pos");
1122 operands->push_back(v);
while (llvm::any_of(*operands, [](Value v) {
return b.create<AffineApplyOp>(loc, map, valueOperands);
for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
1178 llvm::append_range(dims,
1180 llvm::append_range(symbols,
1187 operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1206 AffineApplyOp applyOp =
for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1216 if (failed(applyOp->fold(constOperands, foldResults)) ||
1217 foldResults.empty()) {
1219 listener->notifyOperationInserted(applyOp, {});
1220 return applyOp.getResult();
assert(foldResults.size() == 1 && "expected 1 folded result");
return foldResults.front();
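// Editor's note: makeComposedFoldedAffineApply yields an OpFoldResult -- the
// constant Attribute when folding succeeds, or the materialized
// affine.apply's result Value when it does not (the two returns above).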
return llvm::map_to_vector(llvm::seq<unsigned>(0, map.getNumResults()),
    return makeComposedFoldedAffineApply(
        b, loc, map.getSubMap({i}), operands);
template <typename OpTy>
return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
template <typename OpTy>
auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1286 if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1287 foldResults.empty()) {
1289 listener->notifyOperationInserted(minMaxOp, {});
1290 return minMaxOp.getResult();
assert(foldResults.size() == 1 && "expected 1 folded result");
return foldResults.front();
1302 return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1309 return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
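// Editor's note: AffineMinOp and AffineMaxOp share the same compose-and-fold
// helper; only the template argument passed to makeComposedFoldedMinMax
// differs between the two wrappers above.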
template <class MapOrSet>
if (!mapOrSet || operands->empty())
1320 assert(mapOrSet->getNumInputs() == operands->size() &&
1321 "map/set inputs must match number of operands");
1323 auto *context = mapOrSet->getContext();
1325 resultOperands.reserve(operands->size());
1327 remappedSymbols.reserve(operands->size());
1328 unsigned nextDim = 0;
1329 unsigned nextSym = 0;
1330 unsigned oldNumSyms = mapOrSet->getNumSymbols();
for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1333 if (i < mapOrSet->getNumDims()) {
1337 remappedSymbols.push_back((*operands)[i]);
1340 resultOperands.push_back((*operands)[i]);
1343 resultOperands.push_back((*operands)[i]);
1347 resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1348 *operands = resultOperands;
1349 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
1350 oldNumSyms + nextSym);
1352 assert(mapOrSet->getNumInputs() == operands->size() &&
1353 "map/set inputs must match number of operands");
template <class MapOrSet>
static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
              "Argument must be either of AffineMap or IntegerSet type");
1363 if (!mapOrSet || operands->empty())
1366 assert(mapOrSet->getNumInputs() == operands->size() &&
1367 "map/set inputs must match number of operands");
1369 canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1372 llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1373 llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
  usedDims[dimExpr.getPosition()] = true;
else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
  usedSyms[symExpr.getPosition()] = true;
1381 auto *context = mapOrSet->getContext();
1384 resultOperands.reserve(operands->size());
1386 llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1388 unsigned nextDim = 0;
for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1392 auto it = seenDims.find((*operands)[i]);
1393 if (it == seenDims.end()) {
1395 resultOperands.push_back((*operands)[i]);
1396 seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1398 dimRemapping[i] = it->second;
1402 llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1404 unsigned nextSym = 0;
for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1411 IntegerAttr operandCst;
if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
1419 auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1420 if (it == seenSymbols.end()) {
1422 resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1423 seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1426 symRemapping[i] = it->second;
1429 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1431 *operands = resultOperands;
1436 canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1441 canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
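// Hedged usage sketch (not from the original file): callers typically
// canonicalize a map in place before rebuilding an op, e.g.
//   AffineMap map = loadOp.getAffineMap();            // loadOp is hypothetical
//   SmallVector<Value> operands(loadOp.getMapOperands());
//   canonicalizeMapAndOperands(&map, &operands);
// which, per the template above, deduplicates operands, promotes operands that
// are valid symbols, and drops unused dims/symbols.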
template <typename AffineOpTy>
1457 LogicalResult matchAndRewrite(AffineOpTy affineOp,
1460 llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1461 AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1462 AffineVectorStoreOp, AffineVectorLoadOp>::value,
1463 "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1465 auto map = affineOp.getAffineMap();
1467 auto oldOperands = affineOp.getMapOperands();
1472 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1473 resultOperands.begin()))
1476 replaceAffineOp(rewriter, affineOp, map, resultOperands);
1484 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1491 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1495 prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1496 prefetch.getLocalityHint(), prefetch.getIsDataCache());
1499 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1503 store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1506 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1510 vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1514 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1518 vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
template <typename AffineOpTy>
void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
results.add<SimplifyAffineOp<AffineApplyOp>>(context);
p << " " << getSrcMemRef() << '[';
p << "], " << getDstMemRef() << '[';
p << "], " << getTagMemRef() << '[';
p << ", " << getStride();
p << ", " << getNumElementsPerStride();
p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
  << getTagMemRefType();
1588 AffineMapAttr srcMapAttr;
1591 AffineMapAttr dstMapAttr;
1594 AffineMapAttr tagMapAttr;
1609 getSrcMapAttrStrName(),
1613 getDstMapAttrStrName(),
1617 getTagMapAttrStrName(),
1626 if (!strideInfo.empty() && strideInfo.size() != 2) {
1628 "expected two stride related operands");
1630 bool isStrided = strideInfo.size() == 2;
1635 if (types.size() != 3)
1653 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1654 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1655 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1657 "memref operand count not equal to map.numInputs");
LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
  if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
    return emitOpError("expected DMA source to be of memref type");
  if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
    return emitOpError("expected DMA destination to be of memref type");
  if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
    return emitOpError("expected DMA tag to be of memref type");
1669 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1670 getDstMap().getNumInputs() +
1671 getTagMap().getNumInputs();
1672 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1673 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
return emitOpError("incorrect number of operands");
for (auto idx : getSrcIndices()) {
  if (!idx.getType().isIndex())
    return emitOpError("src index to dma_start must have 'index' type");
      "src index must be a valid dimension or symbol identifier");
for (auto idx : getDstIndices()) {
  if (!idx.getType().isIndex())
    return emitOpError("dst index to dma_start must have 'index' type");
      "dst index must be a valid dimension or symbol identifier");
for (auto idx : getTagIndices()) {
  if (!idx.getType().isIndex())
    return emitOpError("tag index to dma_start must have 'index' type");
      "tag index must be a valid dimension or symbol identifier");
1708 void AffineDmaStartOp::getEffects(
p << " " << getTagMemRef() << '[';
p << " : " << getTagMemRef().getType();
1750 AffineMapAttr tagMapAttr;
1759 getTagMapAttrStrName(),
1768 if (!llvm::isa<MemRefType>(type))
1770 "expected tag to be of memref type");
1772 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1774 "tag memref operand count != to map.numInputs");
LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
  if (!llvm::isa<MemRefType>(getOperand(0).getType()))
    return emitOpError("expected DMA tag to be of memref type");
  for (auto idx : getTagIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("index to dma_wait must have 'index' type");
        "index must be a valid dimension or symbol identifier");
1798 void AffineDmaWaitOp::getEffects(
1814 ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1815 assert(((!lbMap && lbOperands.empty()) ||
1817 "lower bound operand count does not match the affine map");
1818 assert(((!ubMap && ubOperands.empty()) ||
1820 "upper bound operand count does not match the affine map");
assert(step > 0 && "step has to be a positive integer constant");
1827 getOperandSegmentSizeAttr(),
1829 static_cast<int32_t>(ubOperands.size()),
1830 static_cast<int32_t>(iterArgs.size())}));
for (Value val : iterArgs)
Value inductionVar =
for (Value val : iterArgs)
  bodyBlock->addArgument(val.getType(), val.getLoc());
if (iterArgs.empty() && !bodyBuilder) {
  ensureTerminator(*bodyRegion, builder, result.location);
} else if (bodyBuilder) {
  bodyBuilder(builder, result.location, inductionVar,
                         int64_t ub, int64_t step, ValueRange iterArgs,
                         BodyBuilderFn bodyBuilder) {
1877 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1881 LogicalResult AffineForOp::verifyRegions() {
1884 auto *body = getBody();
1885 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
return emitOpError("expected body to have a single index argument for the "
                   "induction variable");
1891 if (getLowerBoundMap().getNumInputs() > 0)
1893 getLowerBoundMap().getNumDims())))
1896 if (getUpperBoundMap().getNumInputs() > 0)
1898 getUpperBoundMap().getNumDims())))
1901 unsigned opNumResults = getNumResults();
1902 if (opNumResults == 0)
1908 if (getNumIterOperands() != opNumResults)
1910 "mismatch between the number of loop-carried values and results");
1911 if (getNumRegionIterArgs() != opNumResults)
1913 "mismatch between the number of basic block args and results");
bool failedToParsedMinMax =
auto boundAttrStrName =
    isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
            : AffineForOp::getUpperBoundMapAttrName(result.name);
1936 if (!boundOpInfos.empty()) {
1938 if (boundOpInfos.size() > 1)
1940 "expected only one loop bound operand");
if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
  unsigned currentNumOperands = result.operands.size();
1971 auto map = affineMapAttr.getValue();
1975 "dim operand count and affine map dim count must match");
unsigned numDimAndSymbolOperands =
    result.operands.size() - currentNumOperands;
if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
      "symbol operand count and affine map symbol count must match");
return p.emitError(attrLoc, "lower loop bound affine map with "
                            "multiple results requires 'max' prefix");
return p.emitError(attrLoc, "upper loop bound affine map with multiple "
                            "results requires 'min' prefix");
if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
    "expected valid affine map representation for loop bounds");
int64_t numOperands = result.operands.size();
int64_t numLbOperands = result.operands.size() - numOperands;
numOperands = result.operands.size();
int64_t numUbOperands = result.operands.size() - numOperands;
                  getStepAttrName(result.name),
IntegerAttr stepAttr;
                  getStepAttrName(result.name).data(),
2044 if (stepAttr.getValue().isNegative())
2047 "expected step to be representable as a positive signed integer");
2055 regionArgs.push_back(inductionVariable);
for (auto argOperandType :
     llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2065 Type type = std::get<2>(argOperandType);
2066 std::get<0>(argOperandType).type = type;
2074 getOperandSegmentSizeAttr(),
2076 static_cast<int32_t>(numUbOperands),
2077 static_cast<int32_t>(operands.size())}));
if (regionArgs.size() != result.types.size() + 1)
    "mismatch between the number of loop-carried values and results");
AffineForOp::ensureTerminator(*body, builder, result.location);
if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
  p << constExpr.getValue();
2119 if (dyn_cast<AffineSymbolExpr>(expr)) {
2135 unsigned AffineForOp::getNumIterOperands() {
2136 AffineMap lbMap = getLowerBoundMapAttr().getValue();
2137 AffineMap ubMap = getUpperBoundMapAttr().getValue();
2142 std::optional<MutableArrayRef<OpOperand>>
2143 AffineForOp::getYieldedValuesMutable() {
2144 return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
if (getStepAsInt() != 1)
  p << " step " << getStepAsInt();
bool printBlockTerminators = false;
2160 if (getNumIterOperands() > 0) {
2162 auto regionArgs = getRegionIterArgs();
2163 auto operands = getInits();
llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
  p << std::get<0>(it) << " = " << std::get<1>(it);
p << ") -> (" << getResultTypes() << ")";
printBlockTerminators = true;
2174 printBlockTerminators);
2176 (*this)->getAttrs(),
2177 {getLowerBoundMapAttrName(getOperation()->getName()),
2178 getUpperBoundMapAttrName(getOperation()->getName()),
2179 getStepAttrName(getOperation()->getName()),
2180 getOperandSegmentSizeAttr()});
auto foldLowerOrUpperBound = [&forOp](bool lower) {
auto boundOperands =
    lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
for (auto operand : boundOperands) {
2194 operandConstants.push_back(operandCst);
2198 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2200 "bound maps should have at least one result");
if (failed(boundMap.constantFold(operandConstants, foldedResults)))
assert(!foldedResults.empty() && "bounds should have at least one result");
2207 auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2209 auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2210 maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2211 : llvm::APIntOps::smin(maxOrMin, foldedResult);
2213 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2214 : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
bool folded = false;
if (!forOp.hasConstantLowerBound())
  folded |= succeeded(foldLowerOrUpperBound(true));
if (!forOp.hasConstantUpperBound())
  folded |= succeeded(foldLowerOrUpperBound(false));
return success(folded);
2234 auto lbMap = forOp.getLowerBoundMap();
2235 auto ubMap = forOp.getUpperBoundMap();
2236 auto prevLbMap = lbMap;
2237 auto prevUbMap = ubMap;
2250 if (lbMap == prevLbMap && ubMap == prevUbMap)
2253 if (lbMap != prevLbMap)
2254 forOp.setLowerBound(lbOperands, lbMap);
2255 if (ubMap != prevUbMap)
2256 forOp.setUpperBound(ubOperands, ubMap);
2262 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2263 int64_t step = forOp.getStepAsInt();
2264 if (!forOp.hasConstantBounds() || step <= 0)
2265 return std::nullopt;
2266 int64_t lb = forOp.getConstantLowerBound();
2267 int64_t ub = forOp.getConstantUpperBound();
2268 return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
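// Editor's note: the expression above computes ceil((ub - lb) / step); e.g.
// lb = 0, ub = 10, step = 3 gives (10 - 0 + 3 - 1) / 3 = 4 iterations.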
2276 LogicalResult matchAndRewrite(AffineForOp forOp,
2279 if (!llvm::hasSingleElement(*forOp.getBody()))
2281 if (forOp.getNumResults() == 0)
2283 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2284 if (tripCount && *tripCount == 0) {
rewriter.replaceOp(forOp, forOp.getInits());
2291 auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2292 auto iterArgs = forOp.getRegionIterArgs();
bool hasValDefinedOutsideLoop = false;
bool iterArgsNotInOrder = false;
for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2296 Value val = yieldOp.getOperand(i);
2297 auto *iterArgIt = llvm::find(iterArgs, val);
2298 if (iterArgIt == iterArgs.end()) {
2300 assert(forOp.isDefinedOutsideOfLoop(val) &&
2301 "must be defined outside of the loop");
hasValDefinedOutsideLoop = true;
replacements.push_back(val);
unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
iterArgsNotInOrder = true;
replacements.push_back(forOp.getInits()[pos]);
2313 if (!tripCount.has_value() &&
2314 (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2318 if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
rewriter.replaceOp(forOp, replacements);
results.add<AffineForEmptyLoopFolder>(context);
assert((point.isParent() || point == getRegion()) && "invalid region point");
void AffineForOp::getSuccessorRegions(
  assert((point.isParent() || point == getRegion()) && "expected loop region");
  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
  if (point.isParent() && tripCount.has_value()) {
    if (tripCount.value() > 0) {
      regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
    if (tripCount.value() == 0) {
  if (!point.isParent() && tripCount && *tripCount == 1) {
  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2373 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
2374 return tripCount && *tripCount == 0;
2377 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2387 results.assign(getInits().begin(), getInits().end());
2390 return success(folded);
assert(map.getNumResults() >= 1 && "bound map has at least one result");
getLowerBoundOperandsMutable().assign(lbOperands);
setLowerBoundMap(map);
assert(map.getNumResults() >= 1 && "bound map has at least one result");
getUpperBoundOperandsMutable().assign(ubOperands);
setUpperBoundMap(map);
2415 bool AffineForOp::hasConstantLowerBound() {
2416 return getLowerBoundMap().isSingleConstant();
2419 bool AffineForOp::hasConstantUpperBound() {
2420 return getUpperBoundMap().isSingleConstant();
2423 int64_t AffineForOp::getConstantLowerBound() {
2424 return getLowerBoundMap().getSingleConstantResult();
2427 int64_t AffineForOp::getConstantUpperBound() {
2428 return getUpperBoundMap().getSingleConstantResult();
2431 void AffineForOp::setConstantLowerBound(int64_t value) {
2435 void AffineForOp::setConstantUpperBound(int64_t value) {
2439 AffineForOp::operand_range AffineForOp::getControlOperands() {
2444 bool AffineForOp::matchingBoundOperandList() {
2445 auto lbMap = getLowerBoundMap();
2446 auto ubMap = getUpperBoundMap();
for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
  if (getOperand(i) != getOperand(numOperands + i))
2462 std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2466 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2467 if (!hasConstantLowerBound())
2468 return std::nullopt;
2471 OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2474 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2480 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2481 if (!hasConstantUpperBound())
2485 OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
    bool replaceInitOperandUsesInLoop,
auto inits = llvm::to_vector(getInits());
inits.append(newInitOperands.begin(), newInitOperands.end());
AffineForOp newLoop = rewriter.create<AffineForOp>(
2502 auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2504 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2509 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2510 assert(newInitOperands.size() == newYieldedValues.size() &&
2511 "expected as many new yield values as new iter operands");
2513 yieldOp.getOperandsMutable().append(newYieldedValues);
rewriter.mergeBlocks(getBody(), newLoop.getBody(),
                     newLoop.getBody()->getArguments().take_front(
                         getBody()->getNumArguments()));
2522 if (replaceInitOperandUsesInLoop) {
for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
2536 newLoop->getResults().take_front(getNumResults()));
2537 return cast<LoopLikeOpInterface>(newLoop.getOperation());
2565 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2566 if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2567 return AffineForOp();
2569 ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2571 return forOp.getInductionVar() == val ? forOp : AffineForOp();
2572 return AffineForOp();
2576 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2577 if (!ivArg || !ivArg.getOwner())
2580 auto parallelOp = dyn_cast<AffineParallelOp>(containingOp);
2581 if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2590 ivs->reserve(forInsts.size());
for (auto forInst : forInsts)
  ivs->push_back(forInst.getInductionVar());
2597 ivs.reserve(affineOps.size());
if (auto forOp = dyn_cast<AffineForOp>(op))
  ivs.push_back(forOp.getInductionVar());
else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
  for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
    ivs.push_back(parallelOp.getBody()->getArgument(i));
template <typename BoundListTy, typename LoopCreatorTy>
                          LoopCreatorTy &&loopCreatorFn) {
assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2629 ivs.reserve(lbs.size());
for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
if (i == e - 1 && bodyBuilderFn) {
  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
  nestedBuilder.create<AffineYieldOp>(nestedLoc);
2645 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
2653 int64_t ub, int64_t step,
2654 AffineForOp::BodyBuilderFn bodyBuilderFn) {
return builder.create<AffineForOp>(loc, lb, ub, step,
                                   std::nullopt, bodyBuilderFn);
2663 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2666 if (lbConst && ubConst)
2668 ubConst.value(), step, bodyBuilderFn);
2671 std::nullopt, bodyBuilderFn);
2699 LogicalResult matchAndRewrite(AffineIfOp ifOp,
2701 if (ifOp.getElseRegion().empty() ||
2702 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2717 LogicalResult matchAndRewrite(AffineIfOp op,
auto isTriviallyFalse = [](IntegerSet iSet) {
  return iSet.isEmptyIntegerSet();
2725 return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2726 iSet.getConstraint(0) == 0);
2729 IntegerSet affineIfConditions = op.getIntegerSet();
2731 if (isTriviallyFalse(affineIfConditions)) {
2735 if (op.getNumResults() == 0 && !op.hasElse()) {
2741 blockToMove = op.getElseBlock();
} else if (isTriviallyTrue(affineIfConditions)) {
  blockToMove = op.getThenBlock();
rewriter.eraseOp(blockToMoveTerminator);
2769 void AffineIfOp::getSuccessorRegions(
2778 if (getElseRegion().empty()) {
2779 regions.push_back(getResults());
2795 auto conditionAttr =
2796 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
return emitOpError("requires an integer set attribute named 'condition'");
IntegerSet condition = conditionAttr.getValue();
return emitOpError("operand count and condition integer set dimension and "
                   "symbol count must match");
2816 IntegerSetAttr conditionAttr;
2819 AffineIfOp::getConditionAttrStrName(),
2825 auto set = conditionAttr.getValue();
2826 if (set.getNumDims() != numDims)
2829 "dim operand count and integer set dim count must match");
if (numDims + set.getNumSymbols() != result.operands.size())
    "symbol operand count and integer set symbol count must match");
AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
2866 auto conditionAttr =
2867 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
p << " " << conditionAttr;
                          conditionAttr.getValue().getNumDims(), p);
2877 auto &elseRegion = this->getElseRegion();
if (!elseRegion.empty()) {
2887 getConditionAttrStrName());
2892 ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
void AffineIfOp::setIntegerSet(IntegerSet newSet) {
2902 (*this)->setOperands(operands);
2907 bool withElseRegion) {
2908 assert(resultTypes.empty() || withElseRegion);
2917 if (resultTypes.empty())
AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
2921 if (withElseRegion) {
2923 if (resultTypes.empty())
AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
2930 AffineIfOp::build(builder, result, {}, set, args,
2945 if (llvm::none_of(operands,
2956 auto set = getIntegerSet();
2962 if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
2965 setConditional(set, operands);
results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
result.types.push_back(memrefType.getElementType());
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
auto memrefType = llvm::cast<MemRefType>(memref.getType());
result.types.push_back(memrefType.getElementType());
auto memrefType = llvm::cast<MemRefType>(memref.getType());
int64_t rank = memrefType.getRank();
build(builder, result, memref, map, indices);
3015 AffineMapAttr mapAttr;
3020 AffineLoadOp::getMapAttrStrName(),
p << " " << getMemRef() << '[';
if (AffineMapAttr mapAttr =
        (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    {getMapAttrStrName()});
static LogicalResult
                             MemRefType memrefType, unsigned numIndexOperands) {
    return op->emitOpError("affine map num results must equal memref rank");
    return op->emitOpError("expects as many subscripts as affine map inputs");
  for (auto idx : mapOperands) {
    if (!idx.getType().isIndex())
      return op->emitOpError("index to load must have 'index' type");
        "index must be a valid dimension or symbol identifier");
if (getType() != memrefType.getElementType())
  return emitOpError("result type must match element type of memref");
    (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
    getMapOperands(), memrefType,
    getNumOperands() - 1)))
results.add<SimplifyAffineOp<AffineLoadOp>>(context);
3090 auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3097 auto global = dyn_cast_or_null<memref::GlobalOp>(
3104 llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
  return splatAttr.getSplatValue<Attribute>();
if (!getAffineMap().isConstant())
auto indices = llvm::to_vector<4>(
    llvm::map_range(getAffineMap().getConstantResults(),
                    [](int64_t v) -> uint64_t { return v; }));
return cstAttr.getValues<Attribute>()[indices];
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
auto memrefType = llvm::cast<MemRefType>(memref.getType());
int64_t rank = memrefType.getRank();
build(builder, result, valueToStore, memref, map, indices);
3152 AffineMapAttr mapAttr;
3157 mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
p << " " << getValueToStore();
p << ", " << getMemRef() << '[';
if (AffineMapAttr mapAttr =
        (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    {getMapAttrStrName()});
if (getValueToStore().getType() != memrefType.getElementType())
    "value to store must have the same type as memref element type");
    (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
    getMapOperands(), memrefType,
    getNumOperands() - 2)))
results.add<SimplifyAffineOp<AffineStoreOp>>(context);
3201 LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
template <typename T>
if (op.getNumOperands() !=
    op.getMap().getNumDims() + op.getMap().getNumSymbols())
  return op.emitOpError(
      "operand count and affine map dimension and symbol count must match");
if (op.getMap().getNumResults() == 0)
  return op.emitOpError("affine map expect at least one result");
template <typename T>
p << ' ' << op->getAttr(T::getMapAttrStrName());
auto operands = op.getOperands();
unsigned numDims = op.getMap().getNumDims();
p << '(' << operands.take_front(numDims) << ')';
if (operands.size() != numDims)
  p << '[' << operands.drop_front(numDims) << ']';
    {T::getMapAttrStrName()});
template <typename T>
AffineMapAttr mapAttr;
template <typename T>
static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
              "expected affine min or max op");
3269 auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3271 if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3272 return op.getOperand(0);
3275 if (results.empty()) {
3277 if (foldedMap == op.getMap())
3280 return op.getResult();
3284 auto resultIt = std::is_same<T, AffineMinOp>::value
3285 ? llvm::min_element(results)
3286 : llvm::max_element(results);
3287 if (resultIt == results.end())
template <typename T>
AffineMap oldMap = affineOp.getAffineMap();
if (!llvm::is_contained(newExprs, expr))
  newExprs.push_back(expr);
template <typename T>
AffineMap oldMap = affineOp.getAffineMap();
    affineOp.getMapOperands().take_front(oldMap.getNumDims());
    affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3348 auto newDimOperands = llvm::to_vector<8>(dimOperands);
3349 auto newSymOperands = llvm::to_vector<8>(symOperands);
if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
  Value symValue = symOperands[symExpr.getPosition()];
  producerOps.push_back(producerOp);
} else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
  Value dimValue = dimOperands[dimExpr.getPosition()];
  producerOps.push_back(producerOp);
newExprs.push_back(expr);
3376 if (producerOps.empty())
3383 for (T producerOp : producerOps) {
3384 AffineMap producerMap = producerOp.getAffineMap();
unsigned numProducerDims = producerMap.getNumDims();
3390 producerOp.getMapOperands().take_front(numProducerDims);
3392 producerOp.getMapOperands().take_back(numProducerSyms);
3393 newDimOperands.append(dimValues.begin(), dimValues.end());
3394 newSymOperands.append(symValues.begin(), symValues.end());
3398 newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3399 .shiftSymbols(numProducerSyms, numUsedSyms));
3402 numUsedDims += numProducerDims;
3403 numUsedSyms += numProducerSyms;
3409 llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3428 if (!resultExpr.isPureAffine())
3433 if (failed(flattenResult))
3446 if (llvm::is_sorted(flattenedExprs))
    llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
  return flattenedExprs[lhs] < flattenedExprs[rhs];
for (unsigned idx : resultPermutation)
template <typename T>
AffineMap map = affineOp.getAffineMap();
template <typename T>
if (affineOp.getMap().getNumResults() != 1)
3500 affineOp.getOperands());
3528 return parseAffineMinMaxOp<AffineMinOp>(parser, result);
3556 return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
3575 IntegerAttr hintInfo;
3577 StringRef readOrWrite, cacheType;
3579 AffineMapAttr mapAttr;
3583 AffinePrefetchOp::getMapAttrStrName(),
3589 AffinePrefetchOp::getLocalityHintAttrStrName(),
if (readOrWrite != "read" && readOrWrite != "write")
      "rw specifier has to be 'read' or 'write'");
result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
if (cacheType != "data" && cacheType != "instr")
      "cache type has to be 'data' or 'instr'");
result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
p << " " << getMemref() << '[';
AffineMapAttr mapAttr =
    (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
  << "locality<" << getLocalityHint() << ">, "
  << (getIsDataCache() ? "data" : "instr");
    (*this)->getAttrs(),
    {getMapAttrStrName(), getLocalityHintAttrStrName(),
     getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3632 auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
return emitOpError("affine.prefetch affine map num results must equal"
return emitOpError("too few operands");
if (getNumOperands() != 1)
  return emitOpError("too few operands");
for (auto idx : getMapOperands()) {
      "index must be a valid dimension or symbol identifier");
results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
3660 LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
3675 auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3679 build(builder, result, resultTypes, reductions, lbs, {}, ubs,
3689 assert(llvm::all_of(lbMaps,
3691 return m.getNumDims() == lbMaps[0].getNumDims() &&
3692 m.getNumSymbols() == lbMaps[0].getNumSymbols();
3694 "expected all lower bounds maps to have the same number of dimensions "
3696 assert(llvm::all_of(ubMaps,
3698 return m.getNumDims() == ubMaps[0].getNumDims() &&
3699 m.getNumSymbols() == ubMaps[0].getNumSymbols();
3701 "expected all upper bounds maps to have the same number of dimensions "
3703 assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3704 "expected lower bound maps to have as many inputs as lower bound "
3706 assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3707 "expected upper bound maps to have as many inputs as upper bound "
3715 for (arith::AtomicRMWKind reduction : reductions)
3716 reductionAttrs.push_back(
3728 groups.reserve(groups.size() + maps.size());
3729 exprs.reserve(maps.size());
3731 llvm::append_range(exprs, m.getResults());
3732 groups.push_back(m.getNumResults());
3734 return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3740 AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3741 AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
for (unsigned i = 0, e = steps.size(); i < e; ++i)
if (resultTypes.empty())
  ensureTerminator(*bodyRegion, builder, result.location);
3766 return {&getRegion()};
unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3771 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
3772 return getOperands().take_front(getLowerBoundsMap().getNumInputs());
3775 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
3776 return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
  auto values = getLowerBoundsGroups().getValues<int32_t>();
  for (unsigned i = 0; i < pos; ++i)
  return getLowerBoundsMap().getSliceMap(start, values[pos]);
AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
  auto values = getUpperBoundsGroups().getValues<int32_t>();
  for (unsigned i = 0; i < pos; ++i)
  return getUpperBoundsMap().getSliceMap(start, values[pos]);
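// Editor's note: the *BoundsGroups attributes record how many results of the
// combined bounds map belong to each induction variable, so the slice for
// position `pos` starts at the sum of the preceding group sizes.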
3796 return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
3800 return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
3803 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3804 if (hasMinMaxBounds())
3805 return std::nullopt;
3810 AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
  auto expr = rangesValueMap.getResult(i);
  auto cst = dyn_cast<AffineConstantExpr>(expr);
    return std::nullopt;
  out.push_back(cst.getValue());
Block *AffineParallelOp::getBody() { return &getRegion().front(); }
3825 OpBuilder AffineParallelOp::getBodyBuilder() {
3826 return OpBuilder(getBody(), std::prev(getBody()->end()));
3831 "operands to map must match number of inputs");
3833 auto ubOperands = getUpperBoundsOperands();
3836 newOperands.append(ubOperands.begin(), ubOperands.end());
3837 (*this)->setOperands(newOperands);
3844 "operands to map must match number of inputs");
3847 newOperands.append(ubOperands.begin(), ubOperands.end());
3848 (*this)->setOperands(newOperands);
3854 setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
3859 arith::AtomicRMWKind op) {
3861 case arith::AtomicRMWKind::addf:
3862 return isa<FloatType>(resultType);
3863 case arith::AtomicRMWKind::addi:
3864 return isa<IntegerType>(resultType);
3865 case arith::AtomicRMWKind::assign:
3867 case arith::AtomicRMWKind::mulf:
3868 return isa<FloatType>(resultType);
3869 case arith::AtomicRMWKind::muli:
3870 return isa<IntegerType>(resultType);
3871 case arith::AtomicRMWKind::maximumf:
3872 return isa<FloatType>(resultType);
3873 case arith::AtomicRMWKind::minimumf:
3874 return isa<FloatType>(resultType);
3875 case arith::AtomicRMWKind::maxs: {
3876 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3877 return intType && intType.isSigned();
3879 case arith::AtomicRMWKind::mins: {
3880 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3881 return intType && intType.isSigned();
3883 case arith::AtomicRMWKind::maxu: {
3884 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3885 return intType && intType.isUnsigned();
3887 case arith::AtomicRMWKind::minu: {
3888 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3889 return intType && intType.isUnsigned();
3891 case arith::AtomicRMWKind::ori:
3892 return isa<IntegerType>(resultType);
3893 case arith::AtomicRMWKind::andi:
3894 return isa<IntegerType>(resultType);
auto numDims = getNumDims();
    getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
  return emitOpError() << "the number of region arguments ("
                       << getBody()->getNumArguments()
                       << ") and the number of map groups for lower ("
                       << getLowerBoundsGroups().getNumElements()
                       << ") and upper bound ("
                       << getUpperBoundsGroups().getNumElements()
                       << "), and the number of steps (" << getSteps().size()
                       << ") must all match";
unsigned expectedNumLBResults = 0;
for (APInt v : getLowerBoundsGroups())
  expectedNumLBResults += v.getZExtValue();
if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
  return emitOpError() << "expected lower bounds map to have "
                       << expectedNumLBResults << " results";
unsigned expectedNumUBResults = 0;
for (APInt v : getUpperBoundsGroups())
  expectedNumUBResults += v.getZExtValue();
if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
  return emitOpError() << "expected upper bounds map to have "
                       << expectedNumUBResults << " results";
if (getReductions().size() != getNumResults())
  return emitOpError("a reduction must be specified for each output");
3935 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
3936 if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
return emitOpError("invalid reduction attribute");
auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
return emitOpError("result type cannot match reduction attribute");
3946 getLowerBoundsMap().getNumDims())))
3950 getUpperBoundsMap().getNumDims())))
3955 LogicalResult AffineValueMap::canonicalize() {
3957 auto newMap = getAffineMap();
3959 if (newMap == getAffineMap() && newOperands == operands)
3961 reset(newMap, newOperands);
3974 if (!lbCanonicalized && !ubCanonicalized)
3977 if (lbCanonicalized)
3979 if (ubCanonicalized)
3985 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
3997 StringRef keyword) {
4000 ValueRange dimOperands = operands.take_front(numDims);
4001 ValueRange symOperands = operands.drop_front(numDims);
4003 for (llvm::APInt groupSize : group) {
4007 unsigned size = groupSize.getZExtValue();
p << keyword << '(';
p << " (" << getBody()->getArguments() << ") = (";
                       getLowerBoundsOperands(), "max");
                       getUpperBoundsOperands(), "min");
bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4033 llvm::interleaveComma(steps, p);
4036 if (getNumResults()) {
llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
  arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
      llvm::cast<IntegerAttr>(attr).getInt());
  p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
p << ") -> (" << getResultTypes() << ")";
    (*this)->getAttrs(),
    {AffineParallelOp::getReductionsAttrStrName(),
     AffineParallelOp::getLowerBoundsMapAttrStrName(),
     AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
     AffineParallelOp::getUpperBoundsMapAttrStrName(),
     AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
     AffineParallelOp::getStepsAttrStrName()});
4069 "expected operands to be dim or symbol expression");
for (const auto &list : operands) {
for (Value operand : valueOperands) {
4077 unsigned pos = std::distance(uniqueOperands.begin(),
4078 llvm::find(uniqueOperands, operand));
4079 if (pos == uniqueOperands.size())
4080 uniqueOperands.push_back(operand);
4081 replacements.push_back(
4091 enum class MinMaxKind { Min, Max };
const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4117 StringRef mapName = kind == MinMaxKind::Min
4118 ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4119 : AffineParallelOp::getLowerBoundsMapAttrStrName();
4120 StringRef groupsName =
4121 kind == MinMaxKind::Min
4122 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4123 : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
auto parseOperands = [&]() {
        kind == MinMaxKind::Min ? "min" : "max"))) {
  mapOperands.clear();
4150 llvm::append_range(flatExprs, map.getValue().getResults());
4152 auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4154 auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4156 flatDimOperands.append(map.getValue().getNumResults(), dims);
4157 flatSymOperands.append(map.getValue().getNumResults(), syms);
4158 numMapsPerGroup.push_back(map.getValue().getNumResults());
4161 flatSymOperands.emplace_back(),
4162 flatExprs.emplace_back())))
4164 numMapsPerGroup.push_back(1);
4171 unsigned totalNumDims = 0;
4172 unsigned totalNumSyms = 0;
for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4174 unsigned numDims = flatDimOperands[i].size();
4175 unsigned numSyms = flatSymOperands[i].size();
4176 flatExprs[i] = flatExprs[i]
4177 .shiftDims(numDims, totalNumDims)
4178 .shiftSymbols(numSyms, totalNumSyms);
4179 totalNumDims += numDims;
4180 totalNumSyms += numSyms;
result.operands.append(dimOperands.begin(), dimOperands.end());
result.operands.append(symOperands.begin(), symOperands.end());
auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
flatMap = flatMap.replaceDimsAndSymbols(
    dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
AffineMapAttr stepsMapAttr;
result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
                    AffineParallelOp::getStepsAttrStrName(),
auto stepsMap = stepsMapAttr.getValue();
for (const auto &result : stepsMap.getResults()) {
  auto constExpr = dyn_cast<AffineConstantExpr>(result);
      "steps must be constant integers");
  steps.push_back(constExpr.getValue());
result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4257 auto parseAttributes = [&]() -> ParseResult {
4267 std::optional<arith::AtomicRMWKind> reduction =
4268 arith::symbolizeAtomicRMWKind(attrVal.getValue());
return parser.emitError(loc, "invalid reduction value: ") << attrVal;
reductions.push_back(
result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
for (auto &iv : ivs)
  iv.type = indexType;
AffineParallelOp::ensureTerminator(*body, builder, result.location);
4304 auto *parentOp = (*this)->getParentOp();
4305 auto results = parentOp->getResults();
4306 auto operands = getOperands();
if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
  return emitOpError() << "only terminates affine.if/for/parallel regions";
if (parentOp->getNumResults() != getNumOperands())
  return emitOpError() << "parent of yield must have same number of "
                          "results as the yield operands";
for (auto it : llvm::zip(results, operands)) {
    return emitOpError() << "types mismatch between yield op and its parent";
assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
result.types.push_back(resultType);
                               VectorType resultType, Value memref,
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
result.types.push_back(resultType);
                               VectorType resultType, Value memref,
auto memrefType = llvm::cast<MemRefType>(memref.getType());
int64_t rank = memrefType.getRank();
build(builder, result, resultType, memref, map, indices);
4357 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4359 results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4367 MemRefType memrefType;
4368 VectorType resultType;
4370 AffineMapAttr mapAttr;
4375 AffineVectorLoadOp::getMapAttrStrName(),
4386 p << " " << getMemRef() << '[';
4387 if (AffineMapAttr mapAttr =
4388 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4392 {getMapAttrStrName()});
4398 VectorType vectorType) {
4400 if (memrefType.getElementType() != vectorType.getElementType())
4402 "requires memref and vector types of the same elemental type");
4410 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4411 getMapOperands(), memrefType,
4412 getNumOperands() - 1)))
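// Note (editorial): verification combines verifyVectorMemoryOp (memref and
// vector element types must match) with verifyMemoryOpIndexing over the map
// operands, i.e. all operands except the memref (getNumOperands() - 1).
// Illustrative textual form:
//   %v = affine.vector_load %mem[%i + 3, %j + 7] : memref<100x100xf32>, vector<8xf32>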
4428 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4439 auto memrefType = llvm::cast<MemRefType>(memref.getType());
4440 int64_t rank = memrefType.getRank();
4445 build(builder, result, valueToStore, memref, map, indices);
4447 void AffineVectorStoreOp::getCanonicalizationPatterns(
4449 results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4456 MemRefType memrefType;
4457 VectorType resultType;
4460 AffineMapAttr mapAttr;
4466 AffineVectorStoreOp::getMapAttrStrName(),
4477 p << " " << getValueToStore();
4478 p << ", " << getMemRef() << '[';
4479 if (AffineMapAttr mapAttr =
4480 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4484 {getMapAttrStrName()});
4485 p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4491 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4492 getMapOperands(), memrefType,
4493 getNumOperands() - 2)))
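// Note (editorial): same checks as affine.vector_load, except two operands
// (the stored value and the memref) are excluded from the indexing operands,
// hence getNumOperands() - 2 above. Illustrative textual form:
//   affine.vector_store %v, %mem[%i, %j] : memref<100x100xf32>, vector<8xf32>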
4506 LogicalResult AffineDelinearizeIndexOp::inferReturnTypes(
4507 MLIRContext *context, std::optional<::mlir::Location> location,
4510 AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
4512 inferredReturnTypes.assign(adaptor.getStaticBasis().size(),
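// Note (editorial): result types are inferred as one index value per static
// basis entry, which is what the verifier below re-checks against
// getNumResults().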
4517 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4524 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
4527 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4534 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
4537 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4541 build(odsBuilder, odsState, linearIndex, ValueRange{}, basis);
4545 if (getStaticBasis().empty())
4546 return emitOpError("basis should not be empty");
4547 if (getNumResults() != getStaticBasis().size())
4548 return emitOpError("should return an index for each basis element");
4549 auto dynamicMarkersCount =
4550 llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
4551 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4553 "mismatch between dynamic and static basis (kDynamic marker but no "
4554 "corresponding dynamic basis entry) -- this can only happen due to an "
4555 "incorrect fold/rewrite");
4562 struct DropUnitExtentBasis
4566 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4569 std::optional<Value> zero = std::nullopt;
4570 Location loc = delinearizeOp->getLoc();
4573 zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
4574 return zero.value();
4580 for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) {
4582 if (basisVal && *basisVal == 1)
4583 replacements[index] = getZero();
4585 newOperands.push_back(basis);
4588 if (newOperands.size() == delinearizeOp.getStaticBasis().size())
4591 if (!newOperands.empty()) {
4592 auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
4593 loc, delinearizeOp.getLinearIndex(), newOperands);
4596 for (auto &replacement : replacements) {
4599 replacement = newDelinearizeOp->getResult(newIndex++);
4603 rewriter.replaceOp(delinearizeOp, replacements);
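// Note (editorial): this pattern drops basis entries that are statically 1,
// since the corresponding delinearized component is always 0; dropped
// positions are replaced by a lazily created arith.constant 0, the remaining
// components by the results of a smaller delinearize op. Sketch of the
// rewrite (illustrative):
//   %r:3 = affine.delinearize_index %n into (%a, %c1, %b) : index, index, index
//   ==>
//   %q:2 = affine.delinearize_index %n into (%a, %b) : index, index
//   with uses of %r#1 replaced by the constant 0.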
4615 struct DropDelinearizeOneBasisElement
4619 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4621 if (delinearizeOp.getStaticBasis().size() != 1)
4623 rewriter.replaceOp(delinearizeOp, delinearizeOp.getLinearIndex());
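// Note (editorial): with a single basis element there is nothing to split,
// so the op is folded away and its linear index is used directly.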
4630 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
4632 patterns.insert<DropDelinearizeOneBasisElement, DropUnitExtentBasis>(context);
4639 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4647 build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4650 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4658 build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4661 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4665 build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
4669 if (getStaticBasis().empty())
4670 return emitOpError("basis should not be empty");
4672 if (getMultiIndex().size() != getStaticBasis().size())
4673 return emitOpError("should be passed an index for each basis element");
4675 auto dynamicMarkersCount =
4676 llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
4677 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4679 "mismatch between dynamic and static basis (kDynamic marker but no "
4680 "corresponding dynamic basis entry) -- this can only happen due to an "
4681 "incorrect fold/rewrite");
4697 struct DropLinearizeUnitComponentsIfDisjointOrZero final
4701 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
4703 size_t numIndices = op.getMultiIndex().size();
4705 newIndices.reserve(numIndices);
4707 newBasis.reserve(numIndices);
4710 for (auto [index, basisElem] : llvm::zip_equal(op.getMultiIndex(), basis)) {
4712 if (!basisEntry || *basisEntry != 1) {
4713 newIndices.push_back(index);
4714 newBasis.push_back(basisElem);
4719 if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
4720 newIndices.push_back(index);
4721 newBasis.push_back(basisElem);
4725 if (newIndices.size() == numIndices)
4728 if (newIndices.size() == 0) {
4733 op, newIndices, newBasis, op.getDisjoint());
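// Note (editorial): a unit basis entry can only be dropped when its paired
// index is provably irrelevant: either the op is marked disjoint (so an
// in-bounds index against a size-1 basis must be 0) or the index itself
// folds to the constant 0; otherwise the component is kept, as the checks
// above show. The surviving components are re-linearized with the reduced
// basis.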
4739 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
4741 patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero>(context);
4748 #define GET_OP_CLASSES
4749 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"