23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 #include "llvm/ADT/SmallVectorExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/MathExtras.h"
36 using llvm::divideCeilSigned;
37 using llvm::divideFloorSigned;
40 #define DEBUG_TYPE "affine-ops"
42 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
49 if (auto arg = llvm::dyn_cast<BlockArgument>(value))
50 return arg.getParentRegion() == region;
73 if (llvm::isa<BlockArgument>(value))
74 return legalityCheck(mapping.lookup(value), dest);
81 bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
92 return llvm::all_of(values, [&](Value v) {
99 template <typename OpTy>
102 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
103 AffineWriteOpInterface>::value,
104 "only ops with affine read/write interface are supported");
111 dimOperands, src, dest, mapping,
115 symbolOperands, src, dest, mapping,
132 op.getMapOperands(), src, dest, mapping,
137 op.getMapOperands(), src, dest, mapping,
164 if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
169 if (!llvm::hasSingleElement(*src))
177 if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
178 if (iface.hasNoEffect())
186 .Case<AffineApplyOp, AffineReadOpInterface,
187 AffineWriteOpInterface>([&](auto op) {
212 isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
216 bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
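// Dialect initialization: registers the generated op list, the inliner
// interface, and the promised ValueBoundsOpInterface implementations below.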
224 void AffineDialect::initialize() {
227 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
229 addInterfaces<AffineInlinerInterface>();
230 declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
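// Fold results are materialized either as ub.poison (for poison attributes)
// or as arith.constant ops.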
239 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
240 return builder.create<ub::PoisonOp>(loc, type, poison);
241 return arith::ConstantOp::materialize(builder, value, type, loc);
249 if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
265 while (auto *parentOp = curOp->getParentOp()) {
288 auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
290 isa<AffineForOp, AffineParallelOp>(parentOp));
311 auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
312 return isa<AffineForOp, AffineParallelOp>(parentOp);
316 if (auto applyOp = dyn_cast<AffineApplyOp>(op))
317 return applyOp.isValidDim(region);
320 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
328 template <typename AnyMemRefDefOp>
331 MemRefType memRefType = memrefDefOp.getType();
334 if (index >= memRefType.getRank()) {
339 if (!memRefType.isDynamicDim(index))
342 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
343 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
355 if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
363 if (!index.has_value())
367 Operation *op = dimOp.getShapedValue().getDefiningOp();
368 while (auto castOp = dyn_cast<memref::CastOp>(op)) {
370 if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
372 op = castOp.getSource().getDefiningOp();
377 int64_t i = index.value();
379 .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
381 .Default([](Operation *) { return false; });
447 if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
448 return applyOp.isValidSymbol(region);
451 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
475 printer << '(' << operands.take_front(numDims) << ')';
476 if (operands.size() > numDims)
477 printer << '[' << operands.drop_front(numDims) << ']';
487 numDims = opInfos.size();
501 template <typename OpTy>
506 for (auto operand : operands) {
507 if (opIt++ < numDims) {
509 return op.emitOpError("operand cannot be used as a dimension id");
511 return op.emitOpError("operand cannot be used as a symbol");
522 return AffineValueMap(getAffineMap(), getOperands(), getResult());
529 AffineMapAttr mapAttr;
535 auto map = mapAttr.getValue();
537 if (map.getNumDims() != numDims ||
538 numDims + map.getNumSymbols() != result.operands.size()) {
540 "dimension or symbol index mismatch");
543 result.types.append(map.getNumResults(), indexTy);
548 p << " " << getMapAttr();
550 getAffineMap().getNumDims(), p);
561 "operand count and affine map dimension and symbol count must match");
565 return emitOpError("mapping must produce one value");
573 return llvm::all_of(getOperands(),
581 return llvm::all_of(getOperands(),
588 return llvm::all_of(getOperands(),
595 return llvm::all_of(getOperands(), [&](Value operand) {
601 auto map = getAffineMap();
604 auto expr = map.getResult(0);
605 if (auto dim = dyn_cast<AffineDimExpr>(expr))
606 return getOperand(dim.getPosition());
607 if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
608 return getOperand(map.getNumDims() + sym.getPosition());
612 bool hasPoison = false;
614 map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
617 if (failed(foldResult))
634 auto dimExpr = dyn_cast<AffineDimExpr>(e);
644 Value operand = operands[dimExpr.getPosition()];
645 int64_t operandDivisor = 1;
649 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
650 operandDivisor = forOp.getStepAsInt();
652 uint64_t lbLargestKnownDivisor =
653 forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
654 operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
657 return operandDivisor;
664 if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
665 int64_t constVal = constExpr.getValue();
666 return constVal >= 0 && constVal < k;
668 auto dimExpr = dyn_cast<AffineDimExpr>(e);
671 Value operand = operands[dimExpr.getPosition()];
675 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
676 forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
692 auto bin = dyn_cast<AffineBinaryOpExpr>(e);
700 quotientTimesDiv = llhs;
706 quotientTimesDiv = rlhs;
716 if (forOp && forOp.hasConstantLowerBound())
717 return forOp.getConstantLowerBound();
724 if (!forOp || !forOp.hasConstantUpperBound())
729 if (forOp.hasConstantLowerBound()) {
730 return forOp.getConstantUpperBound() - 1 -
731 (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
732 forOp.getStepAsInt();
734 return forOp.getConstantUpperBound() - 1;
745 constLowerBounds.reserve(operands.size());
746 constUpperBounds.reserve(operands.size());
747 for (Value operand : operands) {
752 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
753 return constExpr.getValue();
768 constLowerBounds.reserve(operands.size());
769 constUpperBounds.reserve(operands.size());
770 for (Value operand : operands) {
775 std::optional<int64_t> lowerBound;
776 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
777 lowerBound = constExpr.getValue();
780 constLowerBounds, constUpperBounds,
791 auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
802 binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
810 lhs = binExpr.getLHS();
811 rhs = binExpr.getRHS();
812 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
816 int64_t rhsConstVal = rhsConst.getValue();
818 if (rhsConstVal <= 0)
823 std::optional<int64_t> lhsLbConst =
825 std::optional<int64_t> lhsUbConst =
827 if (lhsLbConst && lhsUbConst) {
828 int64_t lhsLbConstVal = *lhsLbConst;
829 int64_t lhsUbConstVal = *lhsUbConst;
833 divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
834 divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
836 divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
842 divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
843 divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
850 lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
862 if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
863 if (rhsConstVal % divisor == 0 &&
865 expr = quotientTimesDiv.floorDiv(rhsConst);
866 } else if (divisor % rhsConstVal == 0 &&
868 expr = rem % rhsConst;
894 if (operands.empty())
900 constLowerBounds.reserve(operands.size());
901 constUpperBounds.reserve(operands.size());
902 for (Value operand : operands) {
916 if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
917 lowerBounds.push_back(constExpr.getValue());
918 upperBounds.push_back(constExpr.getValue());
920 lowerBounds.push_back(
922 constLowerBounds, constUpperBounds,
924 upperBounds.push_back(
926 constLowerBounds, constUpperBounds,
935 unsigned i = exprEn.index();
937 if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
942 if (!upperBounds[i]) {
943 irredundantExprs.push_back(e);
949 auto otherLowerBound = en.value();
950 unsigned pos = en.index();
951 if (pos == i || !otherLowerBound)
953 if (*otherLowerBound > *upperBounds[i])
955 if (*otherLowerBound < *upperBounds[i])
960 if (upperBounds[pos] && lowerBounds[i] &&
961 lowerBounds[i] == upperBounds[i] &&
962 otherLowerBound == *upperBounds[pos] && i < pos)
966 irredundantExprs.push_back(e);
968 if (!lowerBounds[i]) {
969 irredundantExprs.push_back(e);
974 auto otherUpperBound = en.value();
975 unsigned pos = en.index();
976 if (pos == i || !otherUpperBound)
978 if (*otherUpperBound < *lowerBounds[i])
980 if (*otherUpperBound > *lowerBounds[i])
982 if (lowerBounds[pos] && upperBounds[i] &&
983 lowerBounds[i] == upperBounds[i] &&
984 otherUpperBound == lowerBounds[pos] && i < pos)
988 irredundantExprs.push_back(e);
1000 static void LLVM_ATTRIBUTE_UNUSED
1002 assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1008 newResults.push_back(expr);
1025 unsigned dimOrSymbolPosition,
1029 bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1030 unsigned pos = isDimReplacement ? dimOrSymbolPosition
1031 : dimOrSymbolPosition - dims.size();
1032 Value &v = isDimReplacement ? dims[pos] : syms[pos];
1045 AffineMap composeMap = affineApply.getAffineMap();
1046 assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
1048 affineApply.getMapOperands().end());
1062 dims.append(composeDims.begin(), composeDims.end());
1063 syms.append(composeSyms.begin(), composeSyms.end());
1064 *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
1093 for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1105 unsigned nDims = 0, nSyms = 0;
1107 dimReplacements.reserve(dims.size());
1108 symReplacements.reserve(syms.size());
1109 for (auto *container : {&dims, &syms}) {
1110 bool isDim = (container == &dims);
1111 auto &repls = isDim ? dimReplacements : symReplacements;
1113 Value v = en.value();
1117 "map is function of unexpected expr@pos");
1123 operands->push_back(v);
1136 while (llvm::any_of(*operands, [](Value v) {
1150 return b.create<AffineApplyOp>(loc, map, valueOperands);
1172 for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
1179 llvm::append_range(dims,
1181 llvm::append_range(symbols,
1188 operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1197 assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1207 AffineApplyOp applyOp =
1212 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1217 if (failed(applyOp->fold(constOperands, foldResults)) ||
1218 foldResults.empty()) {
1220 listener->notifyOperationInserted(applyOp, {});
1221 return applyOp.getResult();
1225 assert(foldResults.size() == 1 && "expected 1 folded result");
1226 return foldResults.front();
1244 return llvm::map_to_vector(llvm::seq<unsigned>(0, map.getNumResults()),
1246 return makeComposedFoldedAffineApply(
1247 b, loc, map.getSubMap({i}), operands);
1251 template <typename OpTy>
1263 return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
1266 template <typename OpTy>
1278 auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1282 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1287 if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1288 foldResults.empty()) {
1290 listener->notifyOperationInserted(minMaxOp, {});
1291 return minMaxOp.getResult();
1295 assert(foldResults.size() == 1 && "expected 1 folded result");
1296 return foldResults.front();
1303 return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1310 return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
1315 template <class MapOrSet>
1318 if (!mapOrSet || operands->empty())
1321 assert(mapOrSet->getNumInputs() == operands->size() &&
1322 "map/set inputs must match number of operands");
1324 auto *context = mapOrSet->getContext();
1326 resultOperands.reserve(operands->size());
1328 remappedSymbols.reserve(operands->size());
1329 unsigned nextDim = 0;
1330 unsigned nextSym = 0;
1331 unsigned oldNumSyms = mapOrSet->getNumSymbols();
1333 for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1334 if (i < mapOrSet->getNumDims()) {
1338 remappedSymbols.push_back((*operands)[i]);
1341 resultOperands.push_back((*operands)[i]);
1344 resultOperands.push_back((*operands)[i]);
1348 resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1349 *operands = resultOperands;
1350 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
1351 oldNumSyms + nextSym);
1353 assert(mapOrSet->getNumInputs() == operands->size() &&
1354 "map/set inputs must match number of operands");
1358 template <class MapOrSet>
1361 static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1362 "Argument must be either of AffineMap or IntegerSet type");
1364 if (!mapOrSet || operands->empty())
1367 assert(mapOrSet->getNumInputs() == operands->size() &&
1368 "map/set inputs must match number of operands");
1370 canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1373 llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1374 llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1376 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
1377 usedDims[dimExpr.getPosition()] = true;
1378 else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
1379 usedSyms[symExpr.getPosition()] = true;
1382 auto *context = mapOrSet->getContext();
1385 resultOperands.reserve(operands->size());
1387 llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1389 unsigned nextDim = 0;
1390 for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1393 auto it = seenDims.find((*operands)[i]);
1394 if (it == seenDims.end()) {
1396 resultOperands.push_back((*operands)[i]);
1397 seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1399 dimRemapping[i] = it->second;
1403 llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1405 unsigned nextSym = 0;
1406 for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1412 IntegerAttr operandCst;
1413 if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
1420 auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1421 if (it == seenSymbols.end()) {
1423 resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1424 seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1427 symRemapping[i] = it->second;
1430 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1432 *operands = resultOperands;
1437 canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1442 canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
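// SimplifyAffineOp composes and canonicalizes the map and operands of an
// affine op; the per-op replaceAffineOp specializations below rebuild the op
// with the simplified map.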
1449 template <typename AffineOpTy>
1458 LogicalResult matchAndRewrite(AffineOpTy affineOp,
1461 llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1462 AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1463 AffineVectorStoreOp, AffineVectorLoadOp>::value,
1464 "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1466 auto map = affineOp.getAffineMap();
1468 auto oldOperands = affineOp.getMapOperands();
1473 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1474 resultOperands.begin()))
1477 replaceAffineOp(rewriter, affineOp, map, resultOperands);
1485 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1492 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1496 prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1497 prefetch.getLocalityHint(), prefetch.getIsDataCache());
1500 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1504 store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1507 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1511 vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1515 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1519 vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1524 template <typename AffineOpTy>
1525 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1534 results.add<SimplifyAffineOp<AffineApplyOp>>(context);
1565 p << " " << getSrcMemRef() << '[';
1567 p << "], " << getDstMemRef() << '[';
1569 p << "], " << getTagMemRef() << '[';
1573 p << ", " << getStride();
1574 p << ", " << getNumElementsPerStride();
1576 p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1577 << getTagMemRefType();
1589 AffineMapAttr srcMapAttr;
1592 AffineMapAttr dstMapAttr;
1595 AffineMapAttr tagMapAttr;
1610 getSrcMapAttrStrName(),
1614 getDstMapAttrStrName(),
1618 getTagMapAttrStrName(),
1627 if (!strideInfo.empty() && strideInfo.size() != 2) {
1629 "expected two stride related operands");
1631 bool isStrided = strideInfo.size() == 2;
1636 if (types.size() != 3)
1654 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1655 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1656 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1658 "memref operand count not equal to map.numInputs");
1662 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1663 if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1664 return emitOpError("expected DMA source to be of memref type");
1665 if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1666 return emitOpError("expected DMA destination to be of memref type");
1667 if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1668 return emitOpError("expected DMA tag to be of memref type");
1670 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1671 getDstMap().getNumInputs() +
1672 getTagMap().getNumInputs();
1673 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1674 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1675 return emitOpError("incorrect number of operands");
1679 for (auto idx : getSrcIndices()) {
1680 if (!idx.getType().isIndex())
1681 return emitOpError("src index to dma_start must have 'index' type");
1684 "src index must be a valid dimension or symbol identifier");
1686 for (auto idx : getDstIndices()) {
1687 if (!idx.getType().isIndex())
1688 return emitOpError("dst index to dma_start must have 'index' type");
1691 "dst index must be a valid dimension or symbol identifier");
1693 for (auto idx : getTagIndices()) {
1694 if (!idx.getType().isIndex())
1695 return emitOpError("tag index to dma_start must have 'index' type");
1698 "tag index must be a valid dimension or symbol identifier");
1709 void AffineDmaStartOp::getEffects(
1735 p << " " << getTagMemRef() << '[';
1740 p << " : " << getTagMemRef().getType();
1751 AffineMapAttr tagMapAttr;
1760 getTagMapAttrStrName(),
1769 if (!llvm::isa<MemRefType>(type))
1771 "expected tag to be of memref type");
1773 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1775 "tag memref operand count != to map.numInputs");
1779 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1780 if (!llvm::isa<MemRefType>(getOperand(0).getType()))
1781 return emitOpError("expected DMA tag to be of memref type");
1783 for (auto idx : getTagIndices()) {
1784 if (!idx.getType().isIndex())
1785 return emitOpError("index to dma_wait must have 'index' type");
1788 "index must be a valid dimension or symbol identifier");
1799 void AffineDmaWaitOp::getEffects(
1815 ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1816 assert(((!lbMap && lbOperands.empty()) ||
1818 "lower bound operand count does not match the affine map");
1819 assert(((!ubMap && ubOperands.empty()) ||
1821 "upper bound operand count does not match the affine map");
1822 assert(step > 0 && "step has to be a positive integer constant");
1828 getOperandSegmentSizeAttr(),
1830 static_cast<int32_t>(ubOperands.size()),
1831 static_cast<int32_t>(iterArgs.size())}));
1833 for (Value val : iterArgs)
1855 Value inductionVar =
1857 for (Value val : iterArgs)
1858 bodyBlock->addArgument(val.getType(), val.getLoc());
1863 if (iterArgs.empty() && !bodyBuilder) {
1864 ensureTerminator(*bodyRegion, builder, result.location);
1865 } else if (bodyBuilder) {
1868 bodyBuilder(builder, result.location, inductionVar,
1874 int64_t ub, int64_t step, ValueRange iterArgs,
1875 BodyBuilderFn bodyBuilder) {
1878 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
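// Region verification for affine.for: the body needs an index-typed
// induction variable, bound operands must be valid dims/symbols, and the
// iter_args must line up with the op's results.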
1882 LogicalResult AffineForOp::verifyRegions() {
1885 auto *body = getBody();
1886 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1887 return emitOpError("expected body to have a single index argument for the "
1888 "induction variable");
1892 if (getLowerBoundMap().getNumInputs() > 0)
1894 getLowerBoundMap().getNumDims())))
1897 if (getUpperBoundMap().getNumInputs() > 0)
1899 getUpperBoundMap().getNumDims())))
1902 unsigned opNumResults = getNumResults();
1903 if (opNumResults == 0)
1909 if (getNumIterOperands() != opNumResults)
1911 "mismatch between the number of loop-carried values and results");
1912 if (getNumRegionIterArgs() != opNumResults)
1914 "mismatch between the number of basic block args and results");
1924 bool failedToParsedMinMax =
1928 auto boundAttrStrName =
1929 isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
1930 : AffineForOp::getUpperBoundMapAttrName(result.name);
1937 if (!boundOpInfos.empty()) {
1939 if (boundOpInfos.size() > 1)
1941 "expected only one loop bound operand");
1966 if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
1967 unsigned currentNumOperands = result.operands.size();
1972 auto map = affineMapAttr.getValue();
1976 "dim operand count and affine map dim count must match");
1978 unsigned numDimAndSymbolOperands =
1979 result.operands.size() - currentNumOperands;
1980 if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
1983 "symbol operand count and affine map symbol count must match");
1989 return p.emitError(attrLoc, "lower loop bound affine map with "
1990 "multiple results requires 'max' prefix");
1992 return p.emitError(attrLoc, "upper loop bound affine map with multiple "
1993 "results requires 'min' prefix");
1999 if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
2009 "expected valid affine map representation for loop bounds");
2021 int64_t numOperands = result.operands.size();
2024 int64_t numLbOperands = result.operands.size() - numOperands;
2027 numOperands = result.operands.size();
2030 int64_t numUbOperands = result.operands.size() - numOperands;
2035 getStepAttrName(result.name),
2039 IntegerAttr stepAttr;
2041 getStepAttrName(result.name).data(),
2045 if (stepAttr.getValue().isNegative())
2048 "expected step to be representable as a positive signed integer");
2056 regionArgs.push_back(inductionVariable);
2064 for (auto argOperandType :
2065 llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2066 Type type = std::get<2>(argOperandType);
2067 std::get<0>(argOperandType).type = type;
2075 getOperandSegmentSizeAttr(),
2077 static_cast<int32_t>(numUbOperands),
2078 static_cast<int32_t>(operands.size())}));
2082 if (regionArgs.size() != result.types.size() + 1)
2085 "mismatch between the number of loop-carried values and results");
2089 AffineForOp::ensureTerminator(*body, builder, result.location);
2111 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2112 p << constExpr.getValue();
2120 if (dyn_cast<AffineSymbolExpr>(expr)) {
2136 unsigned AffineForOp::getNumIterOperands() {
2137 AffineMap lbMap = getLowerBoundMapAttr().getValue();
2138 AffineMap ubMap = getUpperBoundMapAttr().getValue();
2143 std::optional<MutableArrayRef<OpOperand>>
2144 AffineForOp::getYieldedValuesMutable() {
2145 return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2157 if (getStepAsInt() != 1)
2158 p << " step " << getStepAsInt();
2160 bool printBlockTerminators = false;
2161 if (getNumIterOperands() > 0) {
2163 auto regionArgs = getRegionIterArgs();
2164 auto operands = getInits();
2166 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2167 p << std::get<0>(it) << " = " << std::get<1>(it);
2169 p << ") -> (" << getResultTypes() << ")";
2170 printBlockTerminators = true;
2175 printBlockTerminators);
2177 (*this)->getAttrs(),
2178 {getLowerBoundMapAttrName(getOperation()->getName()),
2179 getUpperBoundMapAttrName(getOperation()->getName()),
2180 getStepAttrName(getOperation()->getName()),
2181 getOperandSegmentSizeAttr()});
2186 auto foldLowerOrUpperBound = [&forOp](bool lower) {
2190 auto boundOperands =
2191 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2192 for (auto operand : boundOperands) {
2195 operandConstants.push_back(operandCst);
2199 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2201 "bound maps should have at least one result");
2203 if (failed(boundMap.constantFold(operandConstants, foldedResults)))
2207 assert(!foldedResults.empty() && "bounds should have at least one result");
2208 auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2209 for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2210 auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2211 maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2212 : llvm::APIntOps::smin(maxOrMin, foldedResult);
2214 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2215 : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2220 bool folded = false;
2221 if (!forOp.hasConstantLowerBound())
2222 folded |= succeeded(foldLowerOrUpperBound(true));
2225 if (!forOp.hasConstantUpperBound())
2226 folded |= succeeded(foldLowerOrUpperBound(false));
2227 return success(folded);
2235 auto lbMap = forOp.getLowerBoundMap();
2236 auto ubMap = forOp.getUpperBoundMap();
2237 auto prevLbMap = lbMap;
2238 auto prevUbMap = ubMap;
2251 if (lbMap == prevLbMap && ubMap == prevUbMap)
2254 if (lbMap != prevLbMap)
2255 forOp.setLowerBound(lbOperands, lbMap);
2256 if (ubMap != prevUbMap)
2257 forOp.setUpperBound(ubOperands, ubMap);
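// The trip count is only derived here when both bounds are constant and the
// step is positive.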
2263 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2264 int64_t step = forOp.getStepAsInt();
2265 if (!forOp.hasConstantBounds() || step <= 0)
2266 return std::nullopt;
2267 int64_t lb = forOp.getConstantLowerBound();
2268 int64_t ub = forOp.getConstantUpperBound();
2269 return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
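// Folds an affine.for whose body is just its terminator: a zero-trip loop is
// replaced by its inits, otherwise by the values the yield forwards (when
// that is safe for the known trip count).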
2277 LogicalResult matchAndRewrite(AffineForOp forOp,
2280 if (!llvm::hasSingleElement(*forOp.getBody()))
2282 if (forOp.getNumResults() == 0)
2284 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2285 if (tripCount && *tripCount == 0) {
2288 rewriter.replaceOp(forOp, forOp.getInits());
2292 auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2293 auto iterArgs = forOp.getRegionIterArgs();
2294 bool hasValDefinedOutsideLoop = false;
2295 bool iterArgsNotInOrder = false;
2296 for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2297 Value val = yieldOp.getOperand(i);
2298 auto *iterArgIt = llvm::find(iterArgs, val);
2299 if (iterArgIt == iterArgs.end()) {
2301 assert(forOp.isDefinedOutsideOfLoop(val) &&
2302 "must be defined outside of the loop");
2303 hasValDefinedOutsideLoop = true;
2304 replacements.push_back(val);
2306 unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2308 iterArgsNotInOrder = true;
2309 replacements.push_back(forOp.getInits()[pos]);
2314 if (!tripCount.has_value() &&
2315 (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2319 if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2321 rewriter.replaceOp(forOp, replacements);
2329 results.add<AffineForEmptyLoopFolder>(context);
2333 assert((point.isParent() || point == getRegion()) && "invalid region point");
2340 void AffineForOp::getSuccessorRegions(
2342 assert((point.isParent() || point == getRegion()) && "expected loop region");
2347 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2348 if (point.isParent() && tripCount.has_value()) {
2349 if (tripCount.value() > 0) {
2350 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2353 if (tripCount.value() == 0) {
2361 if (!point.isParent() && tripCount && *tripCount == 1) {
2368 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2374 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
2375 return tripCount && *tripCount == 0;
2378 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2388 results.assign(getInits().begin(), getInits().end());
2391 return success(folded);
2404 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2405 getLowerBoundOperandsMutable().assign(lbOperands);
2406 setLowerBoundMap(map);
2411 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2412 getUpperBoundOperandsMutable().assign(ubOperands);
2413 setUpperBoundMap(map);
2416 bool AffineForOp::hasConstantLowerBound() {
2417 return getLowerBoundMap().isSingleConstant();
2420 bool AffineForOp::hasConstantUpperBound() {
2421 return getUpperBoundMap().isSingleConstant();
2424 int64_t AffineForOp::getConstantLowerBound() {
2425 return getLowerBoundMap().getSingleConstantResult();
2428 int64_t AffineForOp::getConstantUpperBound() {
2429 return getUpperBoundMap().getSingleConstantResult();
2432 void AffineForOp::setConstantLowerBound(int64_t value) {
2436 void AffineForOp::setConstantUpperBound(int64_t value) {
2440 AffineForOp::operand_range AffineForOp::getControlOperands() {
2445 bool AffineForOp::matchingBoundOperandList() {
2446 auto lbMap = getLowerBoundMap();
2447 auto ubMap = getUpperBoundMap();
2453 for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2455 if (getOperand(i) != getOperand(numOperands + i))
2463 std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2467 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2468 if (!hasConstantLowerBound())
2469 return std::nullopt;
2472 OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2475 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2481 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2482 if (!hasConstantUpperBound())
2486 OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2489 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2491 bool replaceInitOperandUsesInLoop,
2496 auto inits = llvm::to_vector(getInits());
2497 inits.append(newInitOperands.begin(), newInitOperands.end());
2498 AffineForOp newLoop = rewriter.create<AffineForOp>(
2503 auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2505 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2510 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2511 assert(newInitOperands.size() == newYieldedValues.size() &&
2512 "expected as many new yield values as new iter operands");
2514 yieldOp.getOperandsMutable().append(newYieldedValues);
2519 rewriter.mergeBlocks(getBody(), newLoop.getBody(),
2520 newLoop.getBody()->getArguments().take_front(
2521 getBody()->getNumArguments()));
2523 if (replaceInitOperandUsesInLoop) {
2526 for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
2537 newLoop->getResults().take_front(getNumResults()));
2538 return cast<LoopLikeOpInterface>(newLoop.getOperation());
2566 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2567 if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2568 return AffineForOp();
2570 ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2572 return forOp.getInductionVar() == val ? forOp : AffineForOp();
2573 return AffineForOp();
2577 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2578 if (!ivArg || !ivArg.getOwner())
2581 auto parallelOp = dyn_cast<AffineParallelOp>(containingOp);
2582 if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2591 ivs->reserve(forInsts.size());
2592 for (auto forInst : forInsts)
2593 ivs->push_back(forInst.getInductionVar());
2598 ivs.reserve(affineOps.size());
2601 if (auto forOp = dyn_cast<AffineForOp>(op))
2602 ivs.push_back(forOp.getInductionVar());
2603 else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2604 for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2605 ivs.push_back(parallelOp.getBody()->getArgument(i));
2611 template <typename BoundListTy, typename LoopCreatorTy>
2616 LoopCreatorTy &&loopCreatorFn) {
2617 assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
2618 assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2630 ivs.reserve(lbs.size());
2631 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
2637 if (i == e - 1 && bodyBuilderFn) {
2639 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2641 nestedBuilder.create<AffineYieldOp>(nestedLoc);
2646 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
2654 int64_t ub, int64_t step,
2655 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2656 return builder.create<AffineForOp>(loc, lb, ub, step,
2657 std::nullopt, bodyBuilderFn);
2664 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2667 if (lbConst && ubConst)
2669 ubConst.value(), step, bodyBuilderFn);
2672 std::nullopt, bodyBuilderFn);
2700 LogicalResult matchAndRewrite(AffineIfOp ifOp,
2702 if (ifOp.getElseRegion().empty() ||
2703 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2718 LogicalResult matchAndRewrite(AffineIfOp op,
2721 auto isTriviallyFalse = [](IntegerSet iSet) {
2722 return iSet.isEmptyIntegerSet();
2726 return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2727 iSet.getConstraint(0) == 0);
2730 IntegerSet affineIfConditions = op.getIntegerSet();
2732 if (isTriviallyFalse(affineIfConditions)) {
2736 if (op.getNumResults() == 0 && !op.hasElse()) {
2742 blockToMove = op.getElseBlock();
2743 } else if (isTriviallyTrue(affineIfConditions)) {
2744 blockToMove = op.getThenBlock();
2762 rewriter.eraseOp(blockToMoveTerminator);
2770 void AffineIfOp::getSuccessorRegions(
2779 if (getElseRegion().empty()) {
2780 regions.push_back(getResults());
2796 auto conditionAttr =
2797 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2799 return emitOpError("requires an integer set attribute named 'condition'");
2802 IntegerSet condition = conditionAttr.getValue();
2804 return emitOpError("operand count and condition integer set dimension and "
2805 "symbol count must match");
2817 IntegerSetAttr conditionAttr;
2820 AffineIfOp::getConditionAttrStrName(),
2826 auto set = conditionAttr.getValue();
2827 if (set.getNumDims() != numDims)
2830 "dim operand count and integer set dim count must match");
2831 if (numDims + set.getNumSymbols() != result.operands.size())
2834 "symbol operand count and integer set symbol count must match");
2848 AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
2855 AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
2867 auto conditionAttr =
2868 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2869 p << " " << conditionAttr;
2871 conditionAttr.getValue().getNumDims(), p);
2878 auto &elseRegion = this->getElseRegion();
2879 if (!elseRegion.empty()) {
2888 getConditionAttrStrName());
2893 ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
2897 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
2903 (*this)->setOperands(operands);
2908 bool withElseRegion) {
2909 assert(resultTypes.empty() || withElseRegion);
2918 if (resultTypes.empty())
2919 AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
2922 if (withElseRegion) {
2924 if (resultTypes.empty())
2925 AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
2931 AffineIfOp::build(builder, result, {}, set, args,
2946 if (llvm::none_of(operands,
2957 auto set = getIntegerSet();
2963 if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
2966 setConditional(set, operands);
2972 results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
2981 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
2985 auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
2986 result.types.push_back(memrefType.getElementType());
2991 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2994 auto memrefType = llvm::cast<MemRefType>(memref.getType());
2996 result.types.push_back(memrefType.getElementType());
3001 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3002 int64_t rank = memrefType.getRank();
3007 build(builder, result, memref, map, indices);
3016 AffineMapAttr mapAttr;
3021 AffineLoadOp::getMapAttrStrName(),
3031 p << " " << getMemRef() << '[';
3032 if (AffineMapAttr mapAttr =
3033 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3037 {getMapAttrStrName()});
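// Shared verifier for affine memory ops: the map's result count must equal
// the memref rank, the subscript count must match the map's inputs, and each
// index must be an 'index'-typed valid dim/symbol.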
3043 static LogicalResult
3046 MemRefType memrefType,
unsigned numIndexOperands) {
3049 return op->emitOpError("affine map num results must equal memref rank");
3051 return op->emitOpError("expects as many subscripts as affine map inputs");
3054 for (auto idx : mapOperands) {
3055 if (!idx.getType().isIndex())
3056 return op->emitOpError("index to load must have 'index' type");
3059 "index must be a valid dimension or symbol identifier");
3067 if (getType() != memrefType.getElementType())
3068 return emitOpError("result type must match element type of memref");
3072 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3073 getMapOperands(), memrefType,
3074 getNumOperands() - 1)))
3082 results.add<SimplifyAffineOp<AffineLoadOp>>(context);
3091 auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3098 auto global = dyn_cast_or_null<memref::GlobalOp>(
3105 llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3109 if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
3110 return splatAttr.getSplatValue<Attribute>();
3112 if (!getAffineMap().isConstant())
3114 auto indices = llvm::to_vector<4>(
3115 llvm::map_range(getAffineMap().getConstantResults(),
3116 [](int64_t v) -> uint64_t { return v; }));
3117 return cstAttr.getValues<Attribute>()[indices];
3127 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3138 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3139 int64_t rank = memrefType.getRank();
3144 build(builder, result, valueToStore, memref, map, indices);
3153 AffineMapAttr mapAttr;
3158 mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3169 p << " " << getValueToStore();
3170 p << ", " << getMemRef() << '[';
3171 if (AffineMapAttr mapAttr =
3172 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3176 {getMapAttrStrName()});
3183 if (getValueToStore().getType() != memrefType.getElementType())
3185 "value to store must have the same type as memref element type");
3189 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3190 getMapOperands(), memrefType,
3191 getNumOperands() - 2)))
3199 results.add<SimplifyAffineOp<AffineStoreOp>>(context);
3202 LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3212 template <typename T>
3215 if (op.getNumOperands() !=
3216 op.getMap().getNumDims() + op.getMap().getNumSymbols())
3217 return op.emitOpError(
3218 "operand count and affine map dimension and symbol count must match");
3220 if (op.getMap().getNumResults() == 0)
3221 return op.emitOpError("affine map expect at least one result");
3225 template <typename T>
3227 p << ' ' << op->getAttr(T::getMapAttrStrName());
3228 auto operands = op.getOperands();
3229 unsigned numDims = op.getMap().getNumDims();
3230 p << '(' << operands.take_front(numDims) << ')';
3232 if (operands.size() != numDims)
3233 p << '[' << operands.drop_front(numDims) << ']';
3235 {T::getMapAttrStrName()});
3238 template <typename T>
3245 AffineMapAttr mapAttr;
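// Common folder for affine.min/max: partially constant-folds the map, then
// picks the smallest/largest constant result once every result folds.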
3261 template <typename T>
3263 static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3264 "expected affine min or max op");
3270 auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3272 if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3273 return op.getOperand(0);
3276 if (results.empty()) {
3278 if (foldedMap == op.getMap())
3281 return op.getResult();
3285 auto resultIt = std::is_same<T, AffineMinOp>::value
3286 ? llvm::min_element(results)
3287 : llvm::max_element(results);
3288 if (resultIt == results.end())
3294 template <typename T>
3300 AffineMap oldMap = affineOp.getAffineMap();
3306 if (!llvm::is_contained(newExprs, expr))
3307 newExprs.push_back(expr);
3337 template <typename T>
3343 AffineMap oldMap = affineOp.getAffineMap();
3345 affineOp.getMapOperands().take_front(oldMap.getNumDims());
3347 affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3349 auto newDimOperands = llvm::to_vector<8>(dimOperands);
3350 auto newSymOperands = llvm::to_vector<8>(symOperands);
3358 if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
3359 Value symValue = symOperands[symExpr.getPosition()];
3361 producerOps.push_back(producerOp);
3364 } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
3365 Value dimValue = dimOperands[dimExpr.getPosition()];
3367 producerOps.push_back(producerOp);
3374 newExprs.push_back(expr);
3377 if (producerOps.empty())
3384 for (T producerOp : producerOps) {
3385 AffineMap producerMap = producerOp.getAffineMap();
3386 unsigned numProducerDims = producerMap.getNumDims();
3391 producerOp.getMapOperands().take_front(numProducerDims);
3393 producerOp.getMapOperands().take_back(numProducerSyms);
3394 newDimOperands.append(dimValues.begin(), dimValues.end());
3395 newSymOperands.append(symValues.begin(), symValues.end());
3399 newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3400 .shiftSymbols(numProducerSyms, numUsedSyms));
3403 numUsedDims += numProducerDims;
3404 numUsedSyms += numProducerSyms;
3410 llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3429 if (!resultExpr.isPureAffine())
3434 if (failed(flattenResult))
3447 if (llvm::is_sorted(flattenedExprs))
3452 llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
3453 llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
3454 return flattenedExprs[lhs] < flattenedExprs[rhs];
3457 for (unsigned idx : resultPermutation)
3478 template <typename T>
3484 AffineMap map = affineOp.getAffineMap();
3492 template <typename T>
3498 if (affineOp.getMap().getNumResults() != 1)
3501 affineOp.getOperands());
3529 return parseAffineMinMaxOp<AffineMinOp>(parser, result);
3557 return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
3576 IntegerAttr hintInfo;
3578 StringRef readOrWrite, cacheType;
3580 AffineMapAttr mapAttr;
3584 AffinePrefetchOp::getMapAttrStrName(),
3590 AffinePrefetchOp::getLocalityHintAttrStrName(),
3600 if (readOrWrite != "read" && readOrWrite != "write")
3602 "rw specifier has to be 'read' or 'write'");
3603 result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
3606 if (cacheType != "data" && cacheType != "instr")
3608 "cache type has to be 'data' or 'instr'");
3610 result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
3617 p << " " << getMemref() << '[';
3618 AffineMapAttr mapAttr =
3619 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3622 p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
3623 << "locality<" << getLocalityHint() << ">, "
3624 << (getIsDataCache() ? "data" : "instr");
3626 (*this)->getAttrs(),
3627 {getMapAttrStrName(), getLocalityHintAttrStrName(),
3628 getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3633 auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3637 return emitOpError("affine.prefetch affine map num results must equal"
3640 return emitOpError("too few operands");
3642 if (getNumOperands() != 1)
3643 return emitOpError("too few operands");
3647 for (auto idx : getMapOperands()) {
3650 "index must be a valid dimension or symbol identifier");
3658 results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
3661 LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
3676 auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3680 build(builder, result, resultTypes, reductions, lbs, {}, ubs,
3690 assert(llvm::all_of(lbMaps,
3692 return m.getNumDims() == lbMaps[0].getNumDims() &&
3693 m.getNumSymbols() == lbMaps[0].getNumSymbols();
3695 "expected all lower bounds maps to have the same number of dimensions "
3697 assert(llvm::all_of(ubMaps,
3699 return m.getNumDims() == ubMaps[0].getNumDims() &&
3700 m.getNumSymbols() == ubMaps[0].getNumSymbols();
3702 "expected all upper bounds maps to have the same number of dimensions "
3704 assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3705 "expected lower bound maps to have as many inputs as lower bound "
3707 assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3708 "expected upper bound maps to have as many inputs as upper bound "
3716 for (arith::AtomicRMWKind reduction : reductions)
3717 reductionAttrs.push_back(
3729 groups.reserve(groups.size() + maps.size());
3730 exprs.reserve(maps.size());
3732 llvm::append_range(exprs, m.getResults());
3733 groups.push_back(m.getNumResults());
3735 return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3741 AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3742 AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
3760 for (unsigned i = 0, e = steps.size(); i < e; ++i)
3762 if (resultTypes.empty())
3763 ensureTerminator(*bodyRegion, builder, result.location);
3767 return {&getRegion()};
3770 unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3772 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
3773 return getOperands().take_front(getLowerBoundsMap().getNumInputs());
3776 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
3777 return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
3780 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
3781 auto values = getLowerBoundsGroups().getValues<int32_t>();
3783 for (unsigned i = 0; i < pos; ++i)
3785 return getLowerBoundsMap().getSliceMap(start, values[pos]);
3788 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
3789 auto values = getUpperBoundsGroups().getValues<int32_t>();
3791 for (unsigned i = 0; i < pos; ++i)
3793 return getUpperBoundsMap().getSliceMap(start, values[pos]);
3797 return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
3801 return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
3804 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3805 if (hasMinMaxBounds())
3806 return std::nullopt;
3811 AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
3814 for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
3815 auto expr = rangesValueMap.getResult(i);
3816 auto cst = dyn_cast<AffineConstantExpr>(expr);
3818 return std::nullopt;
3819 out.push_back(cst.getValue());
3824 Block *AffineParallelOp::getBody() { return &getRegion().front(); }
3826 OpBuilder AffineParallelOp::getBodyBuilder() {
3827 return OpBuilder(getBody(), std::prev(getBody()->end()));
3832 "operands to map must match number of inputs");
3834 auto ubOperands = getUpperBoundsOperands();
3837 newOperands.append(ubOperands.begin(), ubOperands.end());
3838 (*this)->setOperands(newOperands);
3845 "operands to map must match number of inputs");
3848 newOperands.append(ubOperands.begin(), ubOperands.end());
3849 (*this)->setOperands(newOperands);
3855 setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
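// Maps each reduction kind to the result types it applies to (float kinds to
// FloatType, signed/unsigned integer kinds to suitably-signed IntegerType).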
3860 arith::AtomicRMWKind op) {
3862 case arith::AtomicRMWKind::addf:
3863 return isa<FloatType>(resultType);
3864 case arith::AtomicRMWKind::addi:
3865 return isa<IntegerType>(resultType);
3866 case arith::AtomicRMWKind::assign:
3868 case arith::AtomicRMWKind::mulf:
3869 return isa<FloatType>(resultType);
3870 case arith::AtomicRMWKind::muli:
3871 return isa<IntegerType>(resultType);
3872 case arith::AtomicRMWKind::maximumf:
3873 return isa<FloatType>(resultType);
3874 case arith::AtomicRMWKind::minimumf:
3875 return isa<FloatType>(resultType);
3876 case arith::AtomicRMWKind::maxs: {
3877 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3878 return intType && intType.isSigned();
3880 case arith::AtomicRMWKind::mins: {
3881 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3882 return intType && intType.isSigned();
3884 case arith::AtomicRMWKind::maxu: {
3885 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3886 return intType && intType.isUnsigned();
3888 case arith::AtomicRMWKind::minu: {
3889 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3890 return intType && intType.isUnsigned();
3892 case arith::AtomicRMWKind::ori:
3893 return isa<IntegerType>(resultType);
3894 case arith::AtomicRMWKind::andi:
3895 return isa<IntegerType>(resultType);
3902 auto numDims = getNumDims();
3905 getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
3906 return emitOpError() << "the number of region arguments ("
3907 << getBody()->getNumArguments()
3908 << ") and the number of map groups for lower ("
3909 << getLowerBoundsGroups().getNumElements()
3910 << ") and upper bound ("
3911 << getUpperBoundsGroups().getNumElements()
3912 << "), and the number of steps (" << getSteps().size()
3913 << ") must all match";
3916 unsigned expectedNumLBResults = 0;
3917 for (APInt v : getLowerBoundsGroups())
3918 expectedNumLBResults += v.getZExtValue();
3919 if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
3920 return emitOpError() << "expected lower bounds map to have "
3921 << expectedNumLBResults << " results";
3922 unsigned expectedNumUBResults = 0;
3923 for (APInt v : getUpperBoundsGroups())
3924 expectedNumUBResults += v.getZExtValue();
3925 if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
3926 return emitOpError() << "expected upper bounds map to have "
3927 << expectedNumUBResults << " results";
3929 if (getReductions().size() != getNumResults())
3930 return emitOpError("a reduction must be specified for each output");
3936 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
3937 if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
3938 return emitOpError("invalid reduction attribute");
3939 auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
3941 return emitOpError("result type cannot match reduction attribute");
3947 getLowerBoundsMap().getNumDims())))
3951 getUpperBoundsMap().getNumDims())))
3956 LogicalResult AffineValueMap::canonicalize() {
3958 auto newMap = getAffineMap();
3960 if (newMap == getAffineMap() && newOperands == operands)
3962 reset(newMap, newOperands);
3975 if (!lbCanonicalized && !ubCanonicalized)
3978 if (lbCanonicalized)
3980 if (ubCanonicalized)
3986 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
3998 StringRef keyword) {
4001 ValueRange dimOperands = operands.take_front(numDims);
4002 ValueRange symOperands = operands.drop_front(numDims);
4004 for (llvm::APInt groupSize : group) {
4008 unsigned size = groupSize.getZExtValue();
4013 p << keyword << '(';
4023 p << " (" << getBody()->getArguments() << ") = (";
4025 getLowerBoundsOperands(), "max");
4028 getUpperBoundsOperands(), "min");
4031 bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4034 llvm::interleaveComma(steps, p);
4037 if (getNumResults()) {
4039 llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4040 arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4041 llvm::cast<IntegerAttr>(attr).getInt());
4042 p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4044 p << ") -> (" << getResultTypes() << ")";
4051 (*this)->getAttrs(),
4052 {AffineParallelOp::getReductionsAttrStrName(),
4053 AffineParallelOp::getLowerBoundsMapAttrStrName(),
4054 AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4055 AffineParallelOp::getUpperBoundsMapAttrStrName(),
4056 AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4057 AffineParallelOp::getStepsAttrStrName()});
4070 "expected operands to be dim or symbol expression");
4073 for (const auto &list : operands) {
4077 for (Value operand : valueOperands) {
4078 unsigned pos = std::distance(uniqueOperands.begin(),
4079 llvm::find(uniqueOperands, operand));
4080 if (pos == uniqueOperands.size())
4081 uniqueOperands.push_back(operand);
4082 replacements.push_back(
4092 enum class MinMaxKind { Min, Max };
4116 const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4118 StringRef mapName = kind == MinMaxKind::Min
4119 ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4120 : AffineParallelOp::getLowerBoundsMapAttrStrName();
4121 StringRef groupsName =
4122 kind == MinMaxKind::Min
4123 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4124 : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4141 auto parseOperands = [&]() {
4143 kind == MinMaxKind::Min ? "min" : "max"))) {
4144 mapOperands.clear();
4151 llvm::append_range(flatExprs, map.getValue().getResults());
4153 auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4155 auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4157 flatDimOperands.append(map.getValue().getNumResults(), dims);
4158 flatSymOperands.append(map.getValue().getNumResults(), syms);
4159 numMapsPerGroup.push_back(map.getValue().getNumResults());
4162 flatSymOperands.emplace_back(),
4163 flatExprs.emplace_back())))
4165 numMapsPerGroup.push_back(1);
4172 unsigned totalNumDims = 0;
4173 unsigned totalNumSyms = 0;
4174 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4175 unsigned numDims = flatDimOperands[i].size();
4176 unsigned numSyms = flatSymOperands[i].size();
4177 flatExprs[i] = flatExprs[i]
4178 .shiftDims(numDims, totalNumDims)
4179 .shiftSymbols(numSyms, totalNumSyms);
4180 totalNumDims += numDims;
4181 totalNumSyms += numSyms;
4193 result.operands.append(dimOperands.begin(), dimOperands.end());
4194 result.operands.append(symOperands.begin(), symOperands.end());
4197 auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4199 flatMap = flatMap.replaceDimsAndSymbols(
4200 dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4224 AffineMapAttr stepsMapAttr;
4229 result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4233 AffineParallelOp::getStepsAttrStrName(),
4240 auto stepsMap = stepsMapAttr.getValue();
4241 for (const auto &result : stepsMap.getResults()) {
4242 auto constExpr = dyn_cast<AffineConstantExpr>(result);
4245 "steps must be constant integers");
4246 steps.push_back(constExpr.getValue());
4248 result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4258 auto parseAttributes = [&]() -> ParseResult {
4268 std::optional<arith::AtomicRMWKind> reduction =
4269 arith::symbolizeAtomicRMWKind(attrVal.getValue());
4271 return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4272 reductions.push_back(
4280 result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4289 for (auto &iv : ivs)
4290 iv.type = indexType;
4296 AffineParallelOp::ensureTerminator(*body, builder, result.location);
4305 auto *parentOp = (*this)->getParentOp();
4306 auto results = parentOp->getResults();
4307 auto operands = getOperands();
4309 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4310 return emitOpError() << "only terminates affine.if/for/parallel regions";
4311 if (parentOp->getNumResults() != getNumOperands())
4312 return emitOpError() << "parent of yield must have same number of "
4313 "results as the yield operands";
4314 for (auto it : llvm::zip(results, operands)) {
4316 return emitOpError() << "types mismatch between yield op and its parent";
4329 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4333 result.types.push_back(resultType);
4337 VectorType resultType, Value memref,
4339 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4343 result.types.push_back(resultType);
4347 VectorType resultType, Value memref,
4349 auto memrefType = llvm::cast<MemRefType>(memref.getType());
4350 int64_t rank = memrefType.getRank();
4355 build(builder, result, resultType, memref, map, indices);
4358 void AffineVectorLoadOp::getCanonicalizationPatterns(
RewritePatternSet &results,
4360 results.
add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
  MemRefType memrefType;
  VectorType resultType;
  AffineMapAttr mapAttr;
                                    AffineVectorLoadOp::getMapAttrStrName(),

  p << " " << getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
                          {getMapAttrStrName()});

                                        VectorType vectorType) {
  if (memrefType.getElementType() != vectorType.getElementType())
        "requires memref and vector types of the same elemental type");

          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
          getMapOperands(), memrefType,
          getNumOperands() - 1)))

  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");

  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  int64_t rank = memrefType.getRank();
  build(builder, result, valueToStore, memref, map, indices);

void AffineVectorStoreOp::getCanonicalizationPatterns(
  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);

  MemRefType memrefType;
  VectorType resultType;
  AffineMapAttr mapAttr;
                                    AffineVectorStoreOp::getMapAttrStrName(),

  p << " " << getValueToStore();
  p << ", " << getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
                          {getMapAttrStrName()});
  p << " : " << getMemRefType() << ", " << getValueToStore().getType();

          *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
          getMapOperands(), memrefType,
          getNumOperands() - 2)))
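// Note on the operand counts in the two verifiers above: affine.vector_load
// takes the memref plus its map operands, so the indexing check covers
// getNumOperands() - 1 values, while affine.vector_store additionally carries
// the value to store, hence getNumOperands() - 2.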
void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     bool hasOuterBound) {
                                           : staticBasis.size() + 1,
  build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,

void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     bool hasOuterBound) {
  if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
    hasOuterBound = false;
    basis = basis.drop_front();
  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,

void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     bool hasOuterBound) {
  if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
    hasOuterBound = false;
    basis = basis.drop_front();
  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,

void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     bool hasOuterBound) {
  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);

  if (getNumResults() != staticBasis.size() &&
      getNumResults() != staticBasis.size() + 1)
    return emitOpError("should return an index for each basis element and up "
                       "to one extra index");
  auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
        "mismatch between dynamic and static basis (kDynamic marker but no "
        "corresponding dynamic basis entry) -- this can only happen due to an "
        "incorrect fold/rewrite");

  if (!llvm::all_of(staticBasis, [](int64_t v) {
        return v > 0 || ShapedType::isDynamic(v);
    return emitOpError("no basis element may be statically non-positive");
static std::optional<SmallVector<int64_t>>
  uint64_t dynamicBasisIndex = 0;
      mutableDynamicBasis.erase(dynamicBasisIndex);
      ++dynamicBasisIndex;

  if (dynamicBasisIndex == dynamicBasis.size())
    return std::nullopt;

      staticBasis.push_back(ShapedType::kDynamic);
      staticBasis.push_back(*basisVal);

AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
  std::optional<SmallVector<int64_t>> maybeStaticBasis =
                                  adaptor.getDynamicBasis());
  if (maybeStaticBasis) {
    setStaticBasis(*maybeStaticBasis);

  if (getNumResults() == 1) {
    result.push_back(getLinearIndex());

  if (adaptor.getLinearIndex() == nullptr)

  if (!adaptor.getDynamicBasis().empty())

  int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
  Type attrType = getLinearIndex().getType();

  if (hasOuterBound())
    staticBasis = staticBasis.drop_front();
  for (int64_t modulus : llvm::reverse(staticBasis)) {
    result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
    highPart = llvm::divideFloorSigned(highPart, modulus);
  std::reverse(result.begin(), result.end());
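  // Worked example of the constant fold above: delinearizing 22 over the
  // basis (2, 3, 5) drops the outer bound 2, then peels the fastest-varying
  // dimension first: 22 mod 5 = 2 with 22 floordiv 5 = 4, then 4 mod 3 = 1
  // with 4 floordiv 3 = 1; the remaining high part 1 becomes the outermost
  // result, giving (1, 1, 2), and indeed 1*15 + 1*5 + 2 = 22.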
  if (hasOuterBound()) {
    if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
                            getDynamicBasis().drop_front(), builder);
    return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),

  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);

  if (!hasOuterBound())

struct DropUnitExtentBasis
  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
    std::optional<Value> zero = std::nullopt;
    Location loc = delinearizeOp->getLoc();
        zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
      return zero.value();

    for (auto [index, basis] :
      std::optional<int64_t> basisVal =
      if (basisVal && *basisVal == 1)
        replacements[index] = getZero();
        newBasis.push_back(basis);

    if (newBasis.size() == delinearizeOp.getNumResults())
                                         "no unit basis elements");

    if (!newBasis.empty()) {
      auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
          loc, delinearizeOp.getLinearIndex(), newBasis);
      for (auto &replacement : replacements) {
          replacement = newDelinearizeOp->getResult(newIndex++);

    rewriter.replaceOp(delinearizeOp, replacements);
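// The pattern below cancels a delinearize whose input comes from a disjoint
// linearize when the trailing basis elements of the two ops match exactly.
// Schematically (types elided):
//   %lin = affine.linearize_index disjoint [%a, %b] by (4, 8)
//   %d:2 = affine.delinearize_index %lin into (4, 8)
// folds to just (%a, %b); when only a suffix matches, the matched tail of the
// delinearize results is forwarded and a smaller linearize/delinearize pair is
// kept for the leading part.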
struct CancelDelinearizeOfLinearizeDisjointExactTail
  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
    auto linearizeOp = delinearizeOp.getLinearIndex()
                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
                                         "index doesn't come from linearize");
    if (!linearizeOp.getDisjoint())

    ValueRange linearizeIns = linearizeOp.getMultiIndex();

    size_t numMatches = 0;
    for (auto [linSize, delinSize] : llvm::zip(
             llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
      if (linSize != delinSize)

    if (numMatches == 0)
          delinearizeOp, "final basis element doesn't match linearize");

    if (numMatches == linearizeBasis.size() &&
        numMatches == delinearizeBasis.size() &&
        linearizeIns.size() == delinearizeOp.getNumResults()) {
      rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());

    Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
        linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
        linearizeOp.getDisjoint());
    auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
        delinearizeOp.getLoc(), newLinearize,
        delinearizeOp.hasOuterBound());
    mergedResults.append(linearizeIns.take_back(numMatches).begin(),
                         linearizeIns.take_back(numMatches).end());
    rewriter.replaceOp(delinearizeOp, mergedResults);
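// The pattern below handles the case where a run of trailing delinearize
// basis elements multiplies up to exactly the last static basis element of a
// disjoint linearize feeding it. Schematically (types elided):
//   %lin = affine.linearize_index disjoint [%a, %b] by (3, 16)
//   %d:3 = affine.delinearize_index %lin into (3, 4, 4)
// can be split so that (%d#1, %d#2) are delinearized directly from %b by
// (4, 4), while %d#0 comes from delinearizing the remaining prefix.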
struct SplitDelinearizeSpanningLastLinearizeArg final
  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
    auto linearizeOp = delinearizeOp.getLinearIndex()
                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
                                         "index doesn't come from linearize");
    if (!linearizeOp.getDisjoint())
                                         "linearize isn't disjoint");

    int64_t target = linearizeOp.getStaticBasis().back();
    if (ShapedType::isDynamic(target))
          linearizeOp, "linearize ends with dynamic basis value");

    int64_t sizeToSplit = 1;
    size_t elemsToSplit = 0;
    for (int64_t basisElem : llvm::reverse(basis)) {
      if (ShapedType::isDynamic(basisElem))
            delinearizeOp, "dynamic basis element while scanning for split");
      sizeToSplit *= basisElem;

    if (sizeToSplit > target)
                                         "overshot last argument size");
    if (sizeToSplit == target)

    if (sizeToSplit < target)
          delinearizeOp,
          "product of known basis elements doesn't exceed last "
          "linearize argument");

    if (elemsToSplit < 2)
          "need at least two elements to form the basis product");

    Value linearizeWithoutBack =
        rewriter.create<affine::AffineLinearizeIndexOp>(
            linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
            linearizeOp.getDynamicBasis(),
            linearizeOp.getStaticBasis().drop_back(),
            linearizeOp.getDisjoint());
    auto delinearizeWithoutSplitPart =
        rewriter.create<affine::AffineDelinearizeIndexOp>(
            delinearizeOp.getLoc(), linearizeWithoutBack,
            delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
            delinearizeOp.hasOuterBound());
    auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
        delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
        basis.take_back(elemsToSplit), true);
        llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
                            delinearizeBack.getResults()));
    rewriter.replaceOp(delinearizeOp, results);

void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
      .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
              DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
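// As with the delinearize builders earlier, the linearize builders below
// treat a null leading basis entry (Value() / OpFoldResult()) as "no outer
// bound" and drop it before delegating to the generated builder.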
void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
  if (!basis.empty() && basis.front() == Value())
    basis = basis.drop_front();
  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);

void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
    basis = basis.drop_front();
  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);

void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
  build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);

  size_t numIndexes = getMultiIndex().size();
  size_t numBasisElems = getStaticBasis().size();
  if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
    return emitOpError("should be passed a basis element for each index except "
                       "possibly the first");

  auto dynamicMarkersCount =
      llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
        "mismatch between dynamic and static basis (kDynamic marker but no "
        "corresponding dynamic basis entry) -- this can only happen due to an "
        "incorrect fold/rewrite");

OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
  std::optional<SmallVector<int64_t>> maybeStaticBasis =
                                  adaptor.getDynamicBasis());
  if (maybeStaticBasis) {
    setStaticBasis(*maybeStaticBasis);

  if (getMultiIndex().empty())

  if (getMultiIndex().size() == 1)
    return getMultiIndex().front();

  if (llvm::any_of(adaptor.getMultiIndex(),
                   [](Attribute a) { return a == nullptr; }))

  if (!adaptor.getDynamicBasis().empty())

  for (auto [length, indexAttr] :
       llvm::zip_first(llvm::reverse(getStaticBasis()),
                       llvm::reverse(adaptor.getMultiIndex()))) {
    result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
    stride = stride * length;

  if (!hasOuterBound())
        cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;
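  // Worked example of the constant fold above: linearizing (1, 1, 2) by the
  // basis (2, 3, 5) walks the operands from fastest-varying to slowest with a
  // running stride of 1, 5, 15, accumulating 2*1 + 1*5 + 1*15 = 22, the
  // inverse of the delinearize fold shown earlier.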
  if (hasOuterBound()) {
    if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
                            getDynamicBasis().drop_front(), builder);
    return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),

  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);

  if (!hasOuterBound())

struct DropLinearizeUnitComponentsIfDisjointOrZero final
  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
    size_t numIndices = multiIndex.size();
    newIndices.reserve(numIndices);
    newBasis.reserve(numIndices);

    if (!op.hasOuterBound()) {
      newIndices.push_back(multiIndex.front());
      multiIndex = multiIndex.drop_front();

    for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
      if (!basisEntry || *basisEntry != 1) {
        newIndices.push_back(index);
        newBasis.push_back(basisElem);

      if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
        newIndices.push_back(index);
        newBasis.push_back(basisElem);

    if (newIndices.size() == numIndices)
                                         "no unit basis entries to replace");

    if (newIndices.size() == 0) {

        op, newIndices, newBasis, op.getDisjoint());
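// The (partially elided) helper below appears to fold a set of mixed
// static/dynamic terms, seemingly their product given its use for merged
// basis sizes in the pattern that follows, into a single OpFoldResult:
// constant factors stay in the affine expression, dynamic values are
// collected into dynamicPart, and an affine.apply is materialized only when
// the result is not a plain constant.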
  int64_t nDynamic = 0;
      dynamicPart.push_back(cast<Value>(term));

  if (auto constant = dyn_cast<AffineConstantExpr>(result))
  return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
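// The pattern below recognizes a contiguous run of linearize operands that
// are consecutive results of a single delinearize with matching basis
// elements and collapses that run. Schematically (types elided):
//   %d:3 = affine.delinearize_index %n into (3, 4, 5)
//   %l = affine.linearize_index [%d#0, %d#1, %d#2, %k] by (3, 4, 5, 16)
// becomes, roughly,
//   %l = affine.linearize_index [%n, %k] by (60, 16)
// since the matched run covers all of the delinearize's results; partial runs
// instead get a reshaped delinearize whose merged element feeds the new
// linearize.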
struct CancelLinearizeOfDelinearizePortion final
    unsigned linStart = 0;
    unsigned delinStart = 0;
    unsigned length = 0;

  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
    ValueRange multiIndex = linearizeOp.getMultiIndex();
    unsigned numLinArgs = multiIndex.size();
    unsigned linArgIdx = 0;

    while (linArgIdx < numLinArgs) {
      auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);

      auto delinearizeOp =
          dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
      if (!delinearizeOp) {

      unsigned delinArgIdx = asResult.getResultNumber();
      OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
      bool boundsMatch = firstDelinBound == firstLinBound;
      bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
      bool knownByDisjoint =
          linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
      if (!boundsMatch && !bothAtFront && !knownByDisjoint) {

      unsigned numDelinOuts = delinearizeOp.getNumResults();
      for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
        if (multiIndex[linArgIdx + j] !=
            delinearizeOp.getResult(delinArgIdx + j))
        if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])

      if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {

      matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});

    if (matches.empty())
          linearizeOp, "no run of delinearize outputs to deal with");

    newIndex.reserve(numLinArgs);
    newBasis.reserve(numLinArgs);
    unsigned prevMatchEnd = 0;
    for (Match m : matches) {
      unsigned gap = m.linStart - prevMatchEnd;
      llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
      llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
      prevMatchEnd = m.linStart + m.length;

      PatternRewriter::InsertionGuard g(rewriter);
          linBasisRef.slice(m.linStart, m.length);

      if (m.length == m.delinearize.getNumResults()) {
        newIndex.push_back(m.delinearize.getLinearIndex());
        newBasis.push_back(newSize);

      newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
                          newDelinBasis.begin() + m.delinStart + m.length);
      newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
      auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
          m.delinearize.getLoc(), m.delinearize.getLinearIndex(),

      Value combinedElem = newDelinearize.getResult(m.delinStart);
      auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
          m.delinearize.getLoc(), combinedElem, basisToMerge);

      llvm::append_range(newDelinResults,
                         newDelinearize.getResults().take_front(m.delinStart));
      llvm::append_range(newDelinResults, residualDelinearize.getResults());
          newDelinearize.getResults().drop_front(m.delinStart + 1));
      delinearizeReplacements.push_back(newDelinResults);
      newIndex.push_back(combinedElem);
      newBasis.push_back(newSize);

    llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
    llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
        linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());

    for (auto [m, newResults] :
         llvm::zip_equal(matches, delinearizeReplacements)) {
      if (newResults.empty())
      rewriter.replaceOp(m.delinearize, newResults);
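// The pattern below drops a known-zero leading operand: since the first basis
// entry only bounds the outermost term, a constant-zero leading index
// contributes nothing to the linearized sum, so the op is rebuilt without it,
// dropping the corresponding basis entry when the op has an outer bound.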
struct DropLinearizeLeadingZero final
  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
    Value leadingIdx = op.getMultiIndex().front();

    if (op.getMultiIndex().size() == 1) {

    if (op.hasOuterBound())
      newMixedBasis = newMixedBasis.drop_front();

        op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());

void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
  patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
               DropLinearizeUnitComponentsIfDisjointOrZero>(context);
#define GET_OP_CLASSES
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"