#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"

using llvm::divideCeilSigned;
using llvm::divideFloorSigned;

#define DEBUG_TYPE "affine-ops"

#include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
if (auto arg = llvm::dyn_cast<BlockArgument>(value))
  return arg.getParentRegion() == region;

if (llvm::isa<BlockArgument>(value))
  return legalityCheck(mapping.lookup(value), dest);

bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());

return llvm::all_of(values, [&](Value v) {

template <typename OpTy>

static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
                              AffineWriteOpInterface>::value,
              "only ops with affine read/write interface are supported");

dimOperands, src, dest, mapping,
symbolOperands, src, dest, mapping,
op.getMapOperands(), src, dest, mapping,
op.getMapOperands(), src, dest, mapping,
if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))

if (!llvm::hasSingleElement(*src))

if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
  if (iface.hasNoEffect())

.Case<AffineApplyOp, AffineReadOpInterface,
      AffineWriteOpInterface>([&](auto op) {

isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);

bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
void AffineDialect::initialize() {

#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"

  addInterfaces<AffineInlinerInterface>();
  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,

if (auto poison = dyn_cast<ub::PoisonAttr>(value))
  return builder.create<ub::PoisonOp>(loc, type, poison);
return arith::ConstantOp::materialize(builder, value, type, loc);

if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {

while (auto *parentOp = curOp->getParentOp()) {

if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(parentOp))
auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();

if (auto applyOp = dyn_cast<AffineApplyOp>(op))
  return applyOp.isValidDim(region);

if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))

template <typename AnyMemRefDefOp>

MemRefType memRefType = memrefDefOp.getType();

if (index >= memRefType.getRank()) {

if (!memRefType.isDynamicDim(index))

unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),

if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))

if (!index.has_value())

Operation *op = dimOp.getShapedValue().getDefiningOp();
while (auto castOp = dyn_cast<memref::CastOp>(op)) {
  if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
  op = castOp.getSource().getDefiningOp();

int64_t i = index.value();

.Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
.Default([](Operation *) { return false; });
if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
      return affine::isValidSymbol(operand, region);

if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))

printer << '(' << operands.take_front(numDims) << ')';
if (operands.size() > numDims)
  printer << '[' << operands.drop_front(numDims) << ']';

numDims = opInfos.size();

template <typename OpTy>

for (auto operand : operands) {
  if (opIt++ < numDims) {
      return op.emitOpError("operand cannot be used as a dimension id");
    return op.emitOpError("operand cannot be used as a symbol");
return AffineValueMap(getAffineMap(), getOperands(), getResult());

AffineMapAttr mapAttr;

auto map = mapAttr.getValue();

if (map.getNumDims() != numDims ||
    numDims + map.getNumSymbols() != result.operands.size()) {
                    "dimension or symbol index mismatch");

result.types.append(map.getNumResults(), indexTy);

p << " " << getMapAttr();
                          getAffineMap().getNumDims(), p);

"operand count and affine map dimension and symbol count must match");

return emitOpError("mapping must produce one value");

for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
    return emitError("dimensional operand cannot be used as a symbol");
return llvm::all_of(getOperands(),

return llvm::all_of(getOperands(),

return llvm::all_of(getOperands(),

return llvm::all_of(getOperands(), [&](Value operand) {

auto map = getAffineMap();

auto expr = map.getResult(0);
if (auto dim = dyn_cast<AffineDimExpr>(expr))
  return getOperand(dim.getPosition());
if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
  return getOperand(map.getNumDims() + sym.getPosition());

bool hasPoison = false;
    map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
if (failed(foldResult))
auto dimExpr = dyn_cast<AffineDimExpr>(e);

Value operand = operands[dimExpr.getPosition()];
int64_t operandDivisor = 1;

if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
  operandDivisor = forOp.getStepAsInt();

uint64_t lbLargestKnownDivisor =
    forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());

return operandDivisor;
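// Illustrative note on the gcd reasoning above (comment added for clarity,
// not from the original source): an induction variable of the form
// lb + k * step is always a multiple of gcd(d, step) whenever every
// lower-bound expression is known to be a multiple of d. For instance, a
// lower bound that is a multiple of 4 combined with step 6 makes the IV a
// multiple of gcd(4, 6) = 2.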
if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
  int64_t constVal = constExpr.getValue();
  return constVal >= 0 && constVal < k;

auto dimExpr = dyn_cast<AffineDimExpr>(e);

Value operand = operands[dimExpr.getPosition()];

if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
    forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {

auto bin = dyn_cast<AffineBinaryOpExpr>(e);

quotientTimesDiv = llhs;

quotientTimesDiv = rlhs;
if (forOp && forOp.hasConstantLowerBound())
  return forOp.getConstantLowerBound();

if (!forOp || !forOp.hasConstantUpperBound())

if (forOp.hasConstantLowerBound()) {
  return forOp.getConstantUpperBound() - 1 -
         (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
             forOp.getStepAsInt();

return forOp.getConstantUpperBound() - 1;
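// Illustrative check of the formula above (comment added for clarity, not
// from the original source): with lb = 0, ub = 11, step = 3 the loop visits
// 0, 3, 6, 9, and ub - 1 - (ub - lb - 1) % step = 10 - (10 % 3) = 9 indeed
// yields the largest value the induction variable attains.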
constLowerBounds.reserve(operands.size());
constUpperBounds.reserve(operands.size());
for (Value operand : operands) {

if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
  return constExpr.getValue();

constLowerBounds.reserve(operands.size());
constUpperBounds.reserve(operands.size());
for (Value operand : operands) {

std::optional<int64_t> lowerBound;
if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
  lowerBound = constExpr.getValue();

constLowerBounds, constUpperBounds,
auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);

binExpr = dyn_cast<AffineBinaryOpExpr>(expr);

lhs = binExpr.getLHS();
rhs = binExpr.getRHS();
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);

int64_t rhsConstVal = rhsConst.getValue();

if (rhsConstVal <= 0)

std::optional<int64_t> lhsLbConst =
std::optional<int64_t> lhsUbConst =
if (lhsLbConst && lhsUbConst) {
  int64_t lhsLbConstVal = *lhsLbConst;
  int64_t lhsUbConstVal = *lhsUbConst;

      divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
          divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
        divideFloorSigned(lhsLbConstVal, rhsConstVal), context);

      divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
          divideCeilSigned(lhsUbConstVal, rhsConstVal)) {

      lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {

if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
  if (rhsConstVal % divisor == 0 &&
    expr = quotientTimesDiv.floorDiv(rhsConst);
  } else if (divisor % rhsConstVal == 0 &&
    expr = rem % rhsConst;
if (operands.empty())

constLowerBounds.reserve(operands.size());
constUpperBounds.reserve(operands.size());
for (Value operand : operands) {

if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
  lowerBounds.push_back(constExpr.getValue());
  upperBounds.push_back(constExpr.getValue());

lowerBounds.push_back(
    constLowerBounds, constUpperBounds,
upperBounds.push_back(
    constLowerBounds, constUpperBounds,

unsigned i = exprEn.index();

if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])

if (!upperBounds[i]) {
  irredundantExprs.push_back(e);

auto otherLowerBound = en.value();
unsigned pos = en.index();
if (pos == i || !otherLowerBound)
if (*otherLowerBound > *upperBounds[i])
if (*otherLowerBound < *upperBounds[i])
if (upperBounds[pos] && lowerBounds[i] &&
    lowerBounds[i] == upperBounds[i] &&
    otherLowerBound == *upperBounds[pos] && i < pos)

irredundantExprs.push_back(e);

if (!lowerBounds[i]) {
  irredundantExprs.push_back(e);

auto otherUpperBound = en.value();
unsigned pos = en.index();
if (pos == i || !otherUpperBound)
if (*otherUpperBound < *lowerBounds[i])
if (*otherUpperBound > *lowerBounds[i])
if (lowerBounds[pos] && upperBounds[i] &&
    lowerBounds[i] == upperBounds[i] &&
    otherUpperBound == lowerBounds[pos] && i < pos)

irredundantExprs.push_back(e);
static void LLVM_ATTRIBUTE_UNUSED

assert(map.getNumInputs() == operands.size() && "invalid operands for map");

newResults.push_back(expr);

unsigned dimOrSymbolPosition,

bool isDimReplacement = (dimOrSymbolPosition < dims.size());
unsigned pos = isDimReplacement ? dimOrSymbolPosition
                                : dimOrSymbolPosition - dims.size();
Value &v = isDimReplacement ? dims[pos] : syms[pos];

AffineMap composeMap = affineApply.getAffineMap();
assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");

affineApply.getMapOperands().end());

dims.append(composeDims.begin(), composeDims.end());
syms.append(composeSyms.begin(), composeSyms.end());
*map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)

unsigned nDims = 0, nSyms = 0;

dimReplacements.reserve(dims.size());
symReplacements.reserve(syms.size());
for (auto *container : {&dims, &syms}) {
  bool isDim = (container == &dims);
  auto &repls = isDim ? dimReplacements : symReplacements;

Value v = en.value();

"map is function of unexpected expr@pos");

operands->push_back(v);

while (llvm::any_of(*operands, [](Value v) {

return b.create<AffineApplyOp>(loc, map, valueOperands);

for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {

llvm::append_range(dims,
llvm::append_range(symbols,

operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));

assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");

AffineApplyOp applyOp =

for (unsigned i = 0, e = constOperands.size(); i != e; ++i)

if (failed(applyOp->fold(constOperands, foldResults)) ||
    foldResults.empty()) {
    listener->notifyOperationInserted(applyOp, {});
  return applyOp.getResult();

return llvm::getSingleElement(foldResults);
return llvm::map_to_vector(llvm::seq<unsigned>(0, map.getNumResults()),
  return makeComposedFoldedAffineApply(
      b, loc, map.getSubMap({i}), operands);

template <typename OpTy>

return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);

template <typename OpTy>

auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);

for (unsigned i = 0, e = constOperands.size(); i != e; ++i)

if (failed(minMaxOp->fold(constOperands, foldResults)) ||
    foldResults.empty()) {
    listener->notifyOperationInserted(minMaxOp, {});
  return minMaxOp.getResult();

return llvm::getSingleElement(foldResults);

return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);

return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
template <class MapOrSet>

if (!mapOrSet || operands->empty())

assert(mapOrSet->getNumInputs() == operands->size() &&
       "map/set inputs must match number of operands");

auto *context = mapOrSet->getContext();

resultOperands.reserve(operands->size());
remappedSymbols.reserve(operands->size());
unsigned nextDim = 0;
unsigned nextSym = 0;
unsigned oldNumSyms = mapOrSet->getNumSymbols();

for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
  if (i < mapOrSet->getNumDims()) {
      remappedSymbols.push_back((*operands)[i]);
      resultOperands.push_back((*operands)[i]);
    resultOperands.push_back((*operands)[i]);

resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
*operands = resultOperands;
*mapOrSet = mapOrSet->replaceDimsAndSymbols(
    dimRemapping, {}, nextDim, oldNumSyms + nextSym);

assert(mapOrSet->getNumInputs() == operands->size() &&
       "map/set inputs must match number of operands");
template <class MapOrSet>

if (!mapOrSet || operands.empty())

unsigned numOperands = operands.size();

assert(mapOrSet.getNumInputs() == numOperands &&
       "map/set inputs must match number of operands");

auto *context = mapOrSet.getContext();

resultOperands.reserve(numOperands);
remappedDims.reserve(numOperands);
symOperands.reserve(mapOrSet.getNumSymbols());
unsigned nextSym = 0;
unsigned nextDim = 0;
unsigned oldNumDims = mapOrSet.getNumDims();

resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);
for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
    symRemapping[i - oldNumDims] =
    remappedDims.push_back(operands[i]);
  symOperands.push_back(operands[i]);

append_range(resultOperands, remappedDims);
append_range(resultOperands, symOperands);
operands = resultOperands;
mapOrSet = mapOrSet.replaceDimsAndSymbols(
    {}, symRemapping, oldNumDims + nextDim, nextSym);

assert(mapOrSet.getNumInputs() == operands.size() &&
       "map/set inputs must match number of operands");
template <class MapOrSet>

static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
              "Argument must be either of AffineMap or IntegerSet type");

if (!mapOrSet || operands->empty())

assert(mapOrSet->getNumInputs() == operands->size() &&
       "map/set inputs must match number of operands");

canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
legalizeDemotedDims<MapOrSet>(*mapOrSet, *operands);

llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());

if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
  usedDims[dimExpr.getPosition()] = true;
else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
  usedSyms[symExpr.getPosition()] = true;

auto *context = mapOrSet->getContext();

resultOperands.reserve(operands->size());

llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
unsigned nextDim = 0;
for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
    auto it = seenDims.find((*operands)[i]);
    if (it == seenDims.end()) {
      resultOperands.push_back((*operands)[i]);
      seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
      dimRemapping[i] = it->second;

llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
unsigned nextSym = 0;
for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {

IntegerAttr operandCst;
if (matchPattern((*operands)[i + mapOrSet->getNumDims()],

auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
if (it == seenSymbols.end()) {
  resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
  seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
  symRemapping[i] = it->second;

*mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
*operands = resultOperands;

canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);

canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
template <typename AffineOpTy>

LogicalResult matchAndRewrite(AffineOpTy affineOp,

llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
                AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
                AffineVectorStoreOp, AffineVectorLoadOp>::value,
"affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "

auto map = affineOp.getAffineMap();
auto oldOperands = affineOp.getMapOperands();

if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
                                resultOperands.begin()))

replaceAffineOp(rewriter, affineOp, map, resultOperands);
void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(

void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
    prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
    prefetch.getLocalityHint(), prefetch.getIsDataCache());

void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
    store, store.getValueToStore(), store.getMemRef(), map, mapOperands);

void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
    vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,

void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
    vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,

template <typename AffineOpTy>
void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(

results.add<SimplifyAffineOp<AffineApplyOp>>(context);
p << " " << getSrcMemRef() << '[';
p << "], " << getDstMemRef() << '[';
p << "], " << getTagMemRef() << '[';
p << ", " << getStride();
p << ", " << getNumElementsPerStride();
p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
  << getTagMemRefType();

AffineMapAttr srcMapAttr;
AffineMapAttr dstMapAttr;
AffineMapAttr tagMapAttr;

getSrcMapAttrStrName(),
getDstMapAttrStrName(),
getTagMapAttrStrName(),

if (!strideInfo.empty() && strideInfo.size() != 2) {
    "expected two stride related operands");
bool isStrided = strideInfo.size() == 2;

if (types.size() != 3)

if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
    dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
    tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    "memref operand count not equal to map.numInputs");
LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
  if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
    return emitOpError("expected DMA source to be of memref type");
  if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
    return emitOpError("expected DMA destination to be of memref type");
  if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
    return emitOpError("expected DMA tag to be of memref type");

  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
                              getDstMap().getNumInputs() +
                              getTagMap().getNumInputs();
  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
      getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
    return emitOpError("incorrect number of operands");

  for (auto idx : getSrcIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("src index to dma_start must have 'index' type");
          "src index must be a valid dimension or symbol identifier");

  for (auto idx : getDstIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("dst index to dma_start must have 'index' type");
          "dst index must be a valid dimension or symbol identifier");

  for (auto idx : getTagIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("tag index to dma_start must have 'index' type");
          "tag index must be a valid dimension or symbol identifier");
void AffineDmaStartOp::getEffects(

p << " " << getTagMemRef() << '[';
p << " : " << getTagMemRef().getType();

AffineMapAttr tagMapAttr;

getTagMapAttrStrName(),

if (!llvm::isa<MemRefType>(type))
    "expected tag to be of memref type");

if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    "tag memref operand count != to map.numInputs");

LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
  if (!llvm::isa<MemRefType>(getOperand(0).getType()))
    return emitOpError("expected DMA tag to be of memref type");

  for (auto idx : getTagIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("index to dma_wait must have 'index' type");
          "index must be a valid dimension or symbol identifier");
void AffineDmaWaitOp::getEffects(

ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
  assert(((!lbMap && lbOperands.empty()) ||
         "lower bound operand count does not match the affine map");
  assert(((!ubMap && ubOperands.empty()) ||
         "upper bound operand count does not match the affine map");
  assert(step > 0 && "step has to be a positive integer constant");

getOperandSegmentSizeAttr(),
    static_cast<int32_t>(ubOperands.size()),
    static_cast<int32_t>(iterArgs.size())}));

for (Value val : iterArgs)

Value inductionVar =

for (Value val : iterArgs)
  bodyBlock->addArgument(val.getType(), val.getLoc());

if (iterArgs.empty() && !bodyBuilder) {
  ensureTerminator(*bodyRegion, builder, result.location);
} else if (bodyBuilder) {
  bodyBuilder(builder, result.location, inductionVar,

int64_t ub, int64_t step, ValueRange iterArgs,
BodyBuilderFn bodyBuilder) {

return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
LogicalResult AffineForOp::verifyRegions() {
  auto *body = getBody();
  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
    return emitOpError("expected body to have a single index argument for the "
                       "induction variable");

  if (getLowerBoundMap().getNumInputs() > 0)
        getLowerBoundMap().getNumDims())))
  if (getUpperBoundMap().getNumInputs() > 0)
        getUpperBoundMap().getNumDims())))

  if (getLowerBoundMap().getNumResults() < 1)
    return emitOpError("expected lower bound map to have at least one result");
  if (getUpperBoundMap().getNumResults() < 1)
    return emitOpError("expected upper bound map to have at least one result");

  unsigned opNumResults = getNumResults();
  if (opNumResults == 0)

  if (getNumIterOperands() != opNumResults)
        "mismatch between the number of loop-carried values and results");
  if (getNumRegionIterArgs() != opNumResults)
        "mismatch between the number of basic block args and results");
bool failedToParsedMinMax =

auto boundAttrStrName =
    isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
            : AffineForOp::getUpperBoundMapAttrName(result.name);

if (!boundOpInfos.empty()) {
  if (boundOpInfos.size() > 1)
        "expected only one loop bound operand");

if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
  unsigned currentNumOperands = result.operands.size();

  auto map = affineMapAttr.getValue();
        "dim operand count and affine map dim count must match");

  unsigned numDimAndSymbolOperands =
      result.operands.size() - currentNumOperands;
  if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
        "symbol operand count and affine map symbol count must match");

    return p.emitError(attrLoc, "lower loop bound affine map with "
                                "multiple results requires 'max' prefix");
    return p.emitError(attrLoc, "upper loop bound affine map with multiple "
                                "results requires 'min' prefix");

if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {

"expected valid affine map representation for loop bounds");
int64_t numOperands = result.operands.size();
int64_t numLbOperands = result.operands.size() - numOperands;
numOperands = result.operands.size();
int64_t numUbOperands = result.operands.size() - numOperands;

getStepAttrName(result.name),

IntegerAttr stepAttr;
    getStepAttrName(result.name).data(),

if (stepAttr.getValue().isNegative())
      "expected step to be representable as a positive signed integer");

regionArgs.push_back(inductionVariable);

for (auto argOperandType :
     llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
  Type type = std::get<2>(argOperandType);
  std::get<0>(argOperandType).type = type;

getOperandSegmentSizeAttr(),
    static_cast<int32_t>(numUbOperands),
    static_cast<int32_t>(operands.size())}));

if (regionArgs.size() != result.types.size() + 1)
      "mismatch between the number of loop-carried values and results");

AffineForOp::ensureTerminator(*body, builder, result.location);
if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
  p << constExpr.getValue();

if (dyn_cast<AffineSymbolExpr>(expr)) {

unsigned AffineForOp::getNumIterOperands() {
  AffineMap lbMap = getLowerBoundMapAttr().getValue();
  AffineMap ubMap = getUpperBoundMapAttr().getValue();

std::optional<MutableArrayRef<OpOperand>>
AffineForOp::getYieldedValuesMutable() {
  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();

if (getStepAsInt() != 1)
  p << " step " << getStepAsInt();

bool printBlockTerminators = false;
if (getNumIterOperands() > 0) {
  auto regionArgs = getRegionIterArgs();
  auto operands = getInits();

  llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
    p << std::get<0>(it) << " = " << std::get<1>(it);
  p << ") -> (" << getResultTypes() << ")";
  printBlockTerminators = true;

printBlockTerminators);

(*this)->getAttrs(),
{getLowerBoundMapAttrName(getOperation()->getName()),
 getUpperBoundMapAttrName(getOperation()->getName()),
 getStepAttrName(getOperation()->getName()),
 getOperandSegmentSizeAttr()});
auto foldLowerOrUpperBound = [&forOp](bool lower) {

  auto boundOperands =
      lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
  for (auto operand : boundOperands) {
    operandConstants.push_back(operandCst);

      lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
         "bound maps should have at least one result");

  if (failed(boundMap.constantFold(operandConstants, foldedResults)))

  assert(!foldedResults.empty() && "bounds should have at least one result");
  auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
  for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
    auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
    maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
                     : llvm::APIntOps::smin(maxOrMin, foldedResult);
  lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
        : forOp.setConstantUpperBound(maxOrMin.getSExtValue());

bool folded = false;
if (!forOp.hasConstantLowerBound())
  folded |= succeeded(foldLowerOrUpperBound(true));
if (!forOp.hasConstantUpperBound())
  folded |= succeeded(foldLowerOrUpperBound(false));
return success(folded);
auto lbMap = forOp.getLowerBoundMap();
auto ubMap = forOp.getUpperBoundMap();
auto prevLbMap = lbMap;
auto prevUbMap = ubMap;

if (lbMap == prevLbMap && ubMap == prevUbMap)

if (lbMap != prevLbMap)
  forOp.setLowerBound(lbOperands, lbMap);
if (ubMap != prevUbMap)
  forOp.setUpperBound(ubOperands, ubMap);

static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
  int64_t step = forOp.getStepAsInt();
  if (!forOp.hasConstantBounds() || step <= 0)
    return std::nullopt;
  int64_t lb = forOp.getConstantLowerBound();
  int64_t ub = forOp.getConstantUpperBound();
  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
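// Illustrative check of the trip-count formula above (comment added for
// clarity, not from the original source): lb = 0, ub = 10, step = 3 gives
// (10 - 0 + 3 - 1) / 3 = 4 iterations, matching the IV values 0, 3, 6, 9.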
LogicalResult matchAndRewrite(AffineForOp forOp,

if (!llvm::hasSingleElement(*forOp.getBody()))
if (forOp.getNumResults() == 0)

std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
if (tripCount && *tripCount == 0) {
  rewriter.replaceOp(forOp, forOp.getInits());

auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
auto iterArgs = forOp.getRegionIterArgs();
bool hasValDefinedOutsideLoop = false;
bool iterArgsNotInOrder = false;
for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
  Value val = yieldOp.getOperand(i);
  auto *iterArgIt = llvm::find(iterArgs, val);
  if (val == forOp.getInductionVar())
  if (iterArgIt == iterArgs.end()) {
    assert(forOp.isDefinedOutsideOfLoop(val) &&
           "must be defined outside of the loop");
    hasValDefinedOutsideLoop = true;
    replacements.push_back(val);
    unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
      iterArgsNotInOrder = true;
    replacements.push_back(forOp.getInits()[pos]);

if (!tripCount.has_value() &&
    (hasValDefinedOutsideLoop || iterArgsNotInOrder))

if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)

rewriter.replaceOp(forOp, replacements);

results.add<AffineForEmptyLoopFolder>(context);
assert((point.isParent() || point == getRegion()) && "invalid region point");

void AffineForOp::getSuccessorRegions(
  assert((point.isParent() || point == getRegion()) && "expected loop region");

std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
if (point.isParent() && tripCount.has_value()) {
  if (tripCount.value() > 0) {
    regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
  if (tripCount.value() == 0) {

if (!point.isParent() && tripCount && *tripCount == 1) {

regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));

std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
return tripCount && *tripCount == 0;

LogicalResult AffineForOp::fold(FoldAdaptor adaptor,

results.assign(getInits().begin(), getInits().end());

return success(folded);
assert(map.getNumResults() >= 1 && "bound map has at least one result");
getLowerBoundOperandsMutable().assign(lbOperands);
setLowerBoundMap(map);

assert(map.getNumResults() >= 1 && "bound map has at least one result");
getUpperBoundOperandsMutable().assign(ubOperands);
setUpperBoundMap(map);

bool AffineForOp::hasConstantLowerBound() {
  return getLowerBoundMap().isSingleConstant();

bool AffineForOp::hasConstantUpperBound() {
  return getUpperBoundMap().isSingleConstant();

int64_t AffineForOp::getConstantLowerBound() {
  return getLowerBoundMap().getSingleConstantResult();

int64_t AffineForOp::getConstantUpperBound() {
  return getUpperBoundMap().getSingleConstantResult();

void AffineForOp::setConstantLowerBound(int64_t value) {

void AffineForOp::setConstantUpperBound(int64_t value) {

AffineForOp::operand_range AffineForOp::getControlOperands() {

bool AffineForOp::matchingBoundOperandList() {
  auto lbMap = getLowerBoundMap();
  auto ubMap = getUpperBoundMap();

  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
    if (getOperand(i) != getOperand(numOperands + i))
std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {

std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
  if (!hasConstantLowerBound())
    return std::nullopt;
      OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};

std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {

std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
  if (!hasConstantUpperBound())
      OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};

FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
    bool replaceInitOperandUsesInLoop,

auto inits = llvm::to_vector(getInits());
inits.append(newInitOperands.begin(), newInitOperands.end());
AffineForOp newLoop = rewriter.create<AffineForOp>(

auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
    newLoop.getBody()->getArguments().take_back(newInitOperands.size());

    newYieldValuesFn(rewriter, getLoc(), newIterArgs);
assert(newInitOperands.size() == newYieldedValues.size() &&
       "expected as many new yield values as new iter operands");
yieldOp.getOperandsMutable().append(newYieldedValues);

rewriter.mergeBlocks(getBody(), newLoop.getBody(),
                     newLoop.getBody()->getArguments().take_front(
                         getBody()->getNumArguments()));

if (replaceInitOperandUsesInLoop) {
  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {

newLoop->getResults().take_front(getNumResults()));
return cast<LoopLikeOpInterface>(newLoop.getOperation());
auto ivArg = llvm::dyn_cast<BlockArgument>(val);
if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
  return AffineForOp();
    ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
  return forOp.getInductionVar() == val ? forOp : AffineForOp();
return AffineForOp();

auto ivArg = llvm::dyn_cast<BlockArgument>(val);
if (!ivArg || !ivArg.getOwner())
auto parallelOp = dyn_cast<AffineParallelOp>(containingOp);
if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))

ivs->reserve(forInsts.size());
for (auto forInst : forInsts)
  ivs->push_back(forInst.getInductionVar());

ivs.reserve(affineOps.size());
if (auto forOp = dyn_cast<AffineForOp>(op))
  ivs.push_back(forOp.getInductionVar());
else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
  for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
    ivs.push_back(parallelOp.getBody()->getArgument(i));
template <typename BoundListTy, typename LoopCreatorTy>
    LoopCreatorTy &&loopCreatorFn) {
  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");

  ivs.reserve(lbs.size());
  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
      if (i == e - 1 && bodyBuilderFn) {
        bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
      nestedBuilder.create<AffineYieldOp>(nestedLoc);

    auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);

int64_t ub, int64_t step,
AffineForOp::BodyBuilderFn bodyBuilderFn) {
  return builder.create<AffineForOp>(loc, lb, ub, step,
                                     std::nullopt, bodyBuilderFn);

AffineForOp::BodyBuilderFn bodyBuilderFn) {
  if (lbConst && ubConst)
        ubConst.value(), step, bodyBuilderFn);
      std::nullopt, bodyBuilderFn);
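// A minimal usage sketch of the loop-nest builder exercised above (comment
// added for clarity; the concrete call site below is an assumption, not taken
// from this file):
//
//   affine::buildAffineLoopNest(
//       builder, loc, /*lbs=*/{0, 0}, /*ubs=*/{128, 64}, /*steps=*/{1, 1},
//       [](OpBuilder &b, Location nestedLoc, ValueRange ivs) {
//         // innermost body; an affine.yield terminator is added automatically
//       });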
LogicalResult matchAndRewrite(AffineIfOp ifOp,
  if (ifOp.getElseRegion().empty() ||
      !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())

LogicalResult matchAndRewrite(AffineIfOp op,
  auto isTriviallyFalse = [](IntegerSet iSet) {
    return iSet.isEmptyIntegerSet();

    return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
            iSet.getConstraint(0) == 0);

  IntegerSet affineIfConditions = op.getIntegerSet();
  if (isTriviallyFalse(affineIfConditions)) {
    if (op.getNumResults() == 0 && !op.hasElse()) {
    blockToMove = op.getElseBlock();
  } else if (isTriviallyTrue(affineIfConditions)) {
    blockToMove = op.getThenBlock();

  rewriter.eraseOp(blockToMoveTerminator);

void AffineIfOp::getSuccessorRegions(
  if (getElseRegion().empty()) {
    regions.push_back(getResults());
auto conditionAttr =
    (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
  return emitOpError("requires an integer set attribute named 'condition'");

IntegerSet condition = conditionAttr.getValue();
  return emitOpError("operand count and condition integer set dimension and "
                     "symbol count must match");

IntegerSetAttr conditionAttr;
    AffineIfOp::getConditionAttrStrName(),

auto set = conditionAttr.getValue();
if (set.getNumDims() != numDims)
      "dim operand count and integer set dim count must match");
if (numDims + set.getNumSymbols() != result.operands.size())
      "symbol operand count and integer set symbol count must match");

AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),

auto conditionAttr =
    (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
p << " " << conditionAttr;
            conditionAttr.getValue().getNumDims(), p);

auto &elseRegion = this->getElseRegion();
if (!elseRegion.empty()) {

getConditionAttrStrName());
    ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())

void AffineIfOp::setIntegerSet(IntegerSet newSet) {

(*this)->setOperands(operands);

bool withElseRegion) {
  assert(resultTypes.empty() || withElseRegion);

  if (resultTypes.empty())
    AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);

  if (withElseRegion) {
    if (resultTypes.empty())
      AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);

AffineIfOp::build(builder, result, {}, set, args,

if (llvm::none_of(operands,

auto set = getIntegerSet();
if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
setConditional(set, operands);

results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");

auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
result.types.push_back(memrefType.getElementType());

assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");

auto memrefType = llvm::cast<MemRefType>(memref.getType());
result.types.push_back(memrefType.getElementType());

auto memrefType = llvm::cast<MemRefType>(memref.getType());
int64_t rank = memrefType.getRank();
build(builder, result, memref, map, indices);

AffineMapAttr mapAttr;
    AffineLoadOp::getMapAttrStrName(),

p << " " << getMemRef() << '[';
if (AffineMapAttr mapAttr =
        (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    {getMapAttrStrName()});

template <typename AffineMemOpTy>
static LogicalResult
    MemRefType memrefType, unsigned numIndexOperands) {
    return op->emitOpError("affine map num results must equal memref rank");
    return op->emitOpError("expects as many subscripts as affine map inputs");

  for (auto idx : mapOperands) {
    if (!idx.getType().isIndex())
      return op->emitOpError("index to load must have 'index' type");

if (getType() != memrefType.getElementType())
  return emitOpError("result type must match element type of memref");

    *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
    getMapOperands(), memrefType,
    getNumOperands() - 1)))

results.add<SimplifyAffineOp<AffineLoadOp>>(context);
auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();

auto global = dyn_cast_or_null<memref::GlobalOp>(

llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());

if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
  return splatAttr.getSplatValue<Attribute>();

if (!getAffineMap().isConstant())
auto indices = llvm::to_vector<4>(
    llvm::map_range(getAffineMap().getConstantResults(),
                    [](int64_t v) -> uint64_t { return v; }));
return cstAttr.getValues<Attribute>()[indices];
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");

auto memrefType = llvm::cast<MemRefType>(memref.getType());
int64_t rank = memrefType.getRank();
build(builder, result, valueToStore, memref, map, indices);

AffineMapAttr mapAttr;
    mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),

p << " " << getValueToStore();
p << ", " << getMemRef() << '[';
if (AffineMapAttr mapAttr =
        (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    {getMapAttrStrName()});

if (getValueToStore().getType() != memrefType.getElementType())
    "value to store must have the same type as memref element type");

    *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
    getMapOperands(), memrefType,
    getNumOperands() - 2)))

results.add<SimplifyAffineOp<AffineStoreOp>>(context);
LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,

template <typename T>
  if (op.getNumOperands() !=
      op.getMap().getNumDims() + op.getMap().getNumSymbols())
    return op.emitOpError(
        "operand count and affine map dimension and symbol count must match");

  if (op.getMap().getNumResults() == 0)
    return op.emitOpError("affine map expect at least one result");

template <typename T>
  p << ' ' << op->getAttr(T::getMapAttrStrName());
  auto operands = op.getOperands();
  unsigned numDims = op.getMap().getNumDims();
  p << '(' << operands.take_front(numDims) << ')';
  if (operands.size() != numDims)
    p << '[' << operands.drop_front(numDims) << ']';
      {T::getMapAttrStrName()});

template <typename T>
  AffineMapAttr mapAttr;

template <typename T>
  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
                "expected affine min or max op");

  auto foldedMap = op.getMap().partialConstantFold(operands, &results);

  if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
    return op.getOperand(0);

  if (results.empty()) {
    if (foldedMap == op.getMap())
    return op.getResult();

  auto resultIt = std::is_same<T, AffineMinOp>::value
                      ? llvm::min_element(results)
                      : llvm::max_element(results);
  if (resultIt == results.end())
template <typename T>
  AffineMap oldMap = affineOp.getAffineMap();
    if (!llvm::is_contained(newExprs, expr))
      newExprs.push_back(expr);

template <typename T>
  AffineMap oldMap = affineOp.getAffineMap();
      affineOp.getMapOperands().take_front(oldMap.getNumDims());
      affineOp.getMapOperands().take_back(oldMap.getNumSymbols());

  auto newDimOperands = llvm::to_vector<8>(dimOperands);
  auto newSymOperands = llvm::to_vector<8>(symOperands);

    if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
      Value symValue = symOperands[symExpr.getPosition()];
        producerOps.push_back(producerOp);
    } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      Value dimValue = dimOperands[dimExpr.getPosition()];
        producerOps.push_back(producerOp);
    newExprs.push_back(expr);

  if (producerOps.empty())

  for (T producerOp : producerOps) {
    AffineMap producerMap = producerOp.getAffineMap();
    unsigned numProducerDims = producerMap.getNumDims();
        producerOp.getMapOperands().take_front(numProducerDims);
        producerOp.getMapOperands().take_back(numProducerSyms);
    newDimOperands.append(dimValues.begin(), dimValues.end());
    newSymOperands.append(symValues.begin(), symValues.end());
      newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
                             .shiftSymbols(numProducerSyms, numUsedSyms));
    numUsedDims += numProducerDims;
    numUsedSyms += numProducerSyms;

      llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
if (!resultExpr.isPureAffine())
if (failed(flattenResult))

if (llvm::is_sorted(flattenedExprs))

    llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
  return flattenedExprs[lhs] < flattenedExprs[rhs];
for (unsigned idx : resultPermutation)

template <typename T>
  AffineMap map = affineOp.getAffineMap();

template <typename T>
  if (affineOp.getMap().getNumResults() != 1)
                                  affineOp.getOperands());

return parseAffineMinMaxOp<AffineMinOp>(parser, result);

return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
IntegerAttr hintInfo;
StringRef readOrWrite, cacheType;
AffineMapAttr mapAttr;
    AffinePrefetchOp::getMapAttrStrName(),
    AffinePrefetchOp::getLocalityHintAttrStrName(),

if (readOrWrite != "read" && readOrWrite != "write")
      "rw specifier has to be 'read' or 'write'");
result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),

if (cacheType != "data" && cacheType != "instr")
      "cache type has to be 'data' or 'instr'");
result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),

p << " " << getMemref() << '[';
AffineMapAttr mapAttr =
    (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
  << "locality<" << getLocalityHint() << ">, "
  << (getIsDataCache() ? "data" : "instr");
    (*this)->getAttrs(),
    {getMapAttrStrName(), getLocalityHintAttrStrName(),
     getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
  return emitOpError("affine.prefetch affine map num results must equal"
  return emitOpError("too few operands");
if (getNumOperands() != 1)
  return emitOpError("too few operands");

for (auto idx : getMapOperands()) {
      "index must be a valid dimension or symbol identifier");

results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,

auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
build(builder, result, resultTypes, reductions, lbs, {}, ubs,

assert(llvm::all_of(lbMaps,
                      return m.getNumDims() == lbMaps[0].getNumDims() &&
                             m.getNumSymbols() == lbMaps[0].getNumSymbols();
       "expected all lower bounds maps to have the same number of dimensions "
assert(llvm::all_of(ubMaps,
                      return m.getNumDims() == ubMaps[0].getNumDims() &&
                             m.getNumSymbols() == ubMaps[0].getNumSymbols();
       "expected all upper bounds maps to have the same number of dimensions "
assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
       "expected lower bound maps to have as many inputs as lower bound "
assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
       "expected upper bound maps to have as many inputs as upper bound "

for (arith::AtomicRMWKind reduction : reductions)
  reductionAttrs.push_back(

groups.reserve(groups.size() + maps.size());
exprs.reserve(maps.size());
  llvm::append_range(exprs, m.getResults());
  groups.push_back(m.getNumResults());
return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,

AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);

for (unsigned i = 0, e = steps.size(); i < e; ++i)
if (resultTypes.empty())
  ensureTerminator(*bodyRegion, builder, result.location);

return {&getRegion()};
unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }

AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
  return getOperands().take_front(getLowerBoundsMap().getNumInputs());

AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());

AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
  auto values = getLowerBoundsGroups().getValues<int32_t>();
  for (unsigned i = 0; i < pos; ++i)
  return getLowerBoundsMap().getSliceMap(start, values[pos]);

AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
  auto values = getUpperBoundsGroups().getValues<int32_t>();
  for (unsigned i = 0; i < pos; ++i)
  return getUpperBoundsMap().getSliceMap(start, values[pos]);

return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());

return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());

std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
  if (hasMinMaxBounds())
    return std::nullopt;

  AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
    auto expr = rangesValueMap.getResult(i);
    auto cst = dyn_cast<AffineConstantExpr>(expr);
      return std::nullopt;
    out.push_back(cst.getValue());

Block *AffineParallelOp::getBody() { return &getRegion().front(); }

OpBuilder AffineParallelOp::getBodyBuilder() {
  return OpBuilder(getBody(), std::prev(getBody()->end()));

       "operands to map must match number of inputs");
auto ubOperands = getUpperBoundsOperands();
newOperands.append(ubOperands.begin(), ubOperands.end());
(*this)->setOperands(newOperands);

       "operands to map must match number of inputs");
newOperands.append(ubOperands.begin(), ubOperands.end());
(*this)->setOperands(newOperands);

setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
arith::AtomicRMWKind op) {
  case arith::AtomicRMWKind::addf:
    return isa<FloatType>(resultType);
  case arith::AtomicRMWKind::addi:
    return isa<IntegerType>(resultType);
  case arith::AtomicRMWKind::assign:
  case arith::AtomicRMWKind::mulf:
    return isa<FloatType>(resultType);
  case arith::AtomicRMWKind::muli:
    return isa<IntegerType>(resultType);
  case arith::AtomicRMWKind::maximumf:
    return isa<FloatType>(resultType);
  case arith::AtomicRMWKind::minimumf:
    return isa<FloatType>(resultType);
  case arith::AtomicRMWKind::maxs: {
    auto intType = llvm::dyn_cast<IntegerType>(resultType);
    return intType && intType.isSigned();
  case arith::AtomicRMWKind::mins: {
    auto intType = llvm::dyn_cast<IntegerType>(resultType);
    return intType && intType.isSigned();
  case arith::AtomicRMWKind::maxu: {
    auto intType = llvm::dyn_cast<IntegerType>(resultType);
    return intType && intType.isUnsigned();
  case arith::AtomicRMWKind::minu: {
    auto intType = llvm::dyn_cast<IntegerType>(resultType);
    return intType && intType.isUnsigned();
  case arith::AtomicRMWKind::ori:
    return isa<IntegerType>(resultType);
  case arith::AtomicRMWKind::andi:
    return isa<IntegerType>(resultType);
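// Illustrative examples of the type/reduction compatibility checked above
// (comment added for clarity, not from the original source): an "addf"
// reduction is only valid with a float result type, "maxs"/"mins" require a
// signed integer type, and "maxu"/"minu" require an unsigned integer type.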
auto numDims = getNumDims();
    getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
  return emitOpError() << "the number of region arguments ("
                       << getBody()->getNumArguments()
                       << ") and the number of map groups for lower ("
                       << getLowerBoundsGroups().getNumElements()
                       << ") and upper bound ("
                       << getUpperBoundsGroups().getNumElements()
                       << "), and the number of steps (" << getSteps().size()
                       << ") must all match";

unsigned expectedNumLBResults = 0;
for (APInt v : getLowerBoundsGroups()) {
  unsigned results = v.getZExtValue();
    return emitOpError()
           << "expected lower bound map to have at least one result";
  expectedNumLBResults += results;
if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
  return emitOpError() << "expected lower bounds map to have "
                       << expectedNumLBResults << " results";

unsigned expectedNumUBResults = 0;
for (APInt v : getUpperBoundsGroups()) {
  unsigned results = v.getZExtValue();
    return emitOpError()
           << "expected upper bound map to have at least one result";
  expectedNumUBResults += results;
if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
  return emitOpError() << "expected upper bounds map to have "
                       << expectedNumUBResults << " results";

if (getReductions().size() != getNumResults())
  return emitOpError("a reduction must be specified for each output");

auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
  return emitOpError("invalid reduction attribute");
auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
  return emitOpError("result type cannot match reduction attribute");

      getLowerBoundsMap().getNumDims())))
      getUpperBoundsMap().getNumDims())))
LogicalResult AffineValueMap::canonicalize() {
  auto newMap = getAffineMap();
  if (newMap == getAffineMap() && newOperands == operands)
  reset(newMap, newOperands);

if (!lbCanonicalized && !ubCanonicalized)
if (lbCanonicalized)
if (ubCanonicalized)

LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,

StringRef keyword) {
  ValueRange dimOperands = operands.take_front(numDims);
  ValueRange symOperands = operands.drop_front(numDims);
  for (llvm::APInt groupSize : group) {
    unsigned size = groupSize.getZExtValue();
      p << keyword << '(';

p << " (" << getBody()->getArguments() << ") = (";
                   getLowerBoundsOperands(), "max");
                   getUpperBoundsOperands(), "min");

bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
  llvm::interleaveComma(steps, p);

if (getNumResults()) {
  llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
    arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
        llvm::cast<IntegerAttr>(attr).getInt());
    p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
  p << ") -> (" << getResultTypes() << ")";

    (*this)->getAttrs(),
    {AffineParallelOp::getReductionsAttrStrName(),
     AffineParallelOp::getLowerBoundsMapAttrStrName(),
     AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
     AffineParallelOp::getUpperBoundsMapAttrStrName(),
     AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
     AffineParallelOp::getStepsAttrStrName()});
"expected operands to be dim or symbol expression");

for (const auto &list : operands) {
  for (Value operand : valueOperands) {
    unsigned pos = std::distance(uniqueOperands.begin(),
                                 llvm::find(uniqueOperands, operand));
    if (pos == uniqueOperands.size())
      uniqueOperands.push_back(operand);
    replacements.push_back(

enum class MinMaxKind { Min, Max };

const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";

StringRef mapName = kind == MinMaxKind::Min
                        ? AffineParallelOp::getUpperBoundsMapAttrStrName()
                        : AffineParallelOp::getLowerBoundsMapAttrStrName();
StringRef groupsName =
    kind == MinMaxKind::Min
        ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
        : AffineParallelOp::getLowerBoundsGroupsAttrStrName();

auto parseOperands = [&]() {
        kind == MinMaxKind::Min ? "min" : "max"))) {
    mapOperands.clear();
    llvm::append_range(flatExprs, map.getValue().getResults());
    auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
    auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
    flatDimOperands.append(map.getValue().getNumResults(), dims);
    flatSymOperands.append(map.getValue().getNumResults(), syms);
    numMapsPerGroup.push_back(map.getValue().getNumResults());
                        flatSymOperands.emplace_back(),
                        flatExprs.emplace_back())))
    numMapsPerGroup.push_back(1);

unsigned totalNumDims = 0;
unsigned totalNumSyms = 0;
for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
  unsigned numDims = flatDimOperands[i].size();
  unsigned numSyms = flatSymOperands[i].size();
  flatExprs[i] = flatExprs[i]
                     .shiftDims(numDims, totalNumDims)
                     .shiftSymbols(numSyms, totalNumSyms);
  totalNumDims += numDims;
  totalNumSyms += numSyms;

result.operands.append(dimOperands.begin(), dimOperands.end());
result.operands.append(symOperands.begin(), symOperands.end());

auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
flatMap = flatMap.replaceDimsAndSymbols(
    dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4313 AffineMapAttr stepsMapAttr;
4318 result.
addAttribute(AffineParallelOp::getStepsAttrStrName(),
4322 AffineParallelOp::getStepsAttrStrName(),
4329 auto stepsMap = stepsMapAttr.getValue();
4330 for (
const auto &result : stepsMap.getResults()) {
4331 auto constExpr = dyn_cast<AffineConstantExpr>(result);
4334 "steps must be constant integers");
4335 steps.push_back(constExpr.getValue());
4337 result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4347 auto parseAttributes = [&]() -> ParseResult {
4357 std::optional<arith::AtomicRMWKind> reduction =
4358 arith::symbolizeAtomicRMWKind(attrVal.getValue());
4360 return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4361 reductions.push_back(
4369 result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4378 for (auto &iv : ivs)
4379 iv.type = indexType;
4385 AffineParallelOp::ensureTerminator(*body, builder, result.location);
4394 auto *parentOp = (*this)->getParentOp();
4395 auto results = parentOp->getResults();
4396 auto operands = getOperands();
4398 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4399 return emitOpError() << "only terminates affine.if/for/parallel regions";
4400 if (parentOp->getNumResults() != getNumOperands())
4401 return emitOpError() << "parent of yield must have same number of "
4402 "results as the yield operands";
4403 for (auto it : llvm::zip(results, operands)) {
4405 return emitOpError() << "types mismatch between yield op and its parent";
4418 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4422 result.types.push_back(resultType);
4426 VectorType resultType, Value memref,
4428 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4432 result.types.push_back(resultType);
4436 VectorType resultType, Value memref,
4438 auto memrefType = llvm::cast<MemRefType>(memref.getType());
4439 int64_t rank = memrefType.getRank();
4444 build(builder, result, resultType, memref, map, indices);
4447 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4449 results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4457 MemRefType memrefType;
4458 VectorType resultType;
4460 AffineMapAttr mapAttr;
4465 AffineVectorLoadOp::getMapAttrStrName(),
4476 p << " " << getMemRef() << '[';
4477 if (AffineMapAttr mapAttr =
4478 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4482 {getMapAttrStrName()});
4488 VectorType vectorType) {
4490 if (memrefType.getElementType() != vectorType.getElementType())
4492 "requires memref and vector types of the same elemental type");
4499 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4500 getMapOperands(), memrefType,
4501 getNumOperands() - 1)))
4517 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4528 auto memrefType = llvm::cast<MemRefType>(memref.getType());
4529 int64_t rank = memrefType.getRank();
4534 build(builder, result, valueToStore, memref, map, indices);
4536 void AffineVectorStoreOp::getCanonicalizationPatterns(
4538 results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4545 MemRefType memrefType;
4546 VectorType resultType;
4549 AffineMapAttr mapAttr;
4555 AffineVectorStoreOp::getMapAttrStrName(),
4566 p << " " << getValueToStore();
4567 p << ", " << getMemRef() << '[';
4568 if (AffineMapAttr mapAttr =
4569 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4573 {getMapAttrStrName()});
4574 p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4580 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4581 getMapOperands(), memrefType,
4582 getNumOperands() - 2)))
4595 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4599 bool hasOuterBound) {
4601 : staticBasis.size() + 1,
4603 build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
4607 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4610 bool hasOuterBound) {
4611 if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
4612 hasOuterBound = false;
4613 basis = basis.drop_front();
4619 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4623 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4627 bool hasOuterBound) {
4628 if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
4629 hasOuterBound = false;
4630 basis = basis.drop_front();
4635 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4639 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4642 bool hasOuterBound) {
4643 build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
4648 if (getNumResults() != staticBasis.size() &&
4649 getNumResults() != staticBasis.size() + 1)
4650 return emitOpError("should return an index for each basis element and up "
4651 "to one extra index");
4653 auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
4654 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4656 "mismatch between dynamic and static basis (kDynamic marker but no "
4657 "corresponding dynamic basis entry) -- this can only happen due to an "
4658 "incorrect fold/rewrite");
4660 if (!llvm::all_of(staticBasis, [](int64_t v) {
4661 return v > 0 || ShapedType::isDynamic(v);
4663 return emitOpError("no basis element may be statically non-positive");
4672 static std::optional<SmallVector<int64_t>>
4676 uint64_t dynamicBasisIndex = 0;
4679 mutableDynamicBasis.erase(dynamicBasisIndex);
4681 ++dynamicBasisIndex;
4686 if (dynamicBasisIndex == dynamicBasis.size())
4687 return std::nullopt;
4693 staticBasis.push_back(ShapedType::kDynamic);
4695 staticBasis.push_back(*basisVal);
4702 AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
4704 std::optional<SmallVector<int64_t>> maybeStaticBasis =
4706 adaptor.getDynamicBasis());
4707 if (maybeStaticBasis) {
4708 setStaticBasis(*maybeStaticBasis);
4713 if (getNumResults() == 1) {
4714 result.push_back(getLinearIndex());
4718 if (adaptor.getLinearIndex() == nullptr)
4721 if (!adaptor.getDynamicBasis().empty())
4724 int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
4725 Type attrType = getLinearIndex().getType();
4728 if (hasOuterBound())
4729 staticBasis = staticBasis.drop_front();
4730 for (int64_t modulus : llvm::reverse(staticBasis)) {
4731 result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
4732 highPart = llvm::divideFloorSigned(highPart, modulus);
4735 std::reverse(result.begin(), result.end());
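// Illustrative standalone sketch (not part of AffineOps.cpp): the constant
// folding arithmetic used above, with a euclidean mod and floor division
// standing in for llvm::mod and llvm::divideFloorSigned. The moduli here play
// the role of the basis after the outer bound has been dropped, the leftover
// quotient becomes the outermost result, and the concrete numbers are made
// up: delinearizing 22 by inner moduli (3, 5) gives (1, 1, 2).
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<int64_t> delinearizeConstant(int64_t linear,
                                                const std::vector<int64_t> &innerModuli) {
  std::vector<int64_t> result;
  for (auto it = innerModuli.rbegin(); it != innerModuli.rend(); ++it) {
    int64_t modulus = *it;
    int64_t rem = ((linear % modulus) + modulus) % modulus; // euclidean mod
    result.push_back(rem);
    linear = (linear - rem) / modulus; // floor division for positive moduli
  }
  result.push_back(linear); // outermost result: whatever quotient remains
  std::reverse(result.begin(), result.end());
  return result;
}

int main() {
  // 22 == 1*15 + 1*5 + 2, so the results are (1, 1, 2).
  assert((delinearizeConstant(22, {3, 5}) == std::vector<int64_t>{1, 1, 2}));
}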
4741 if (hasOuterBound()) {
4742 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
4744 getDynamicBasis().drop_front(), builder);
4746 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
4750 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
4755 if (!hasOuterBound())
4763 struct DropUnitExtentBasis
4767 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4770 std::optional<Value> zero = std::nullopt;
4771 Location loc = delinearizeOp->getLoc();
4774 zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
4775 return zero.value();
4781 for (auto [index, basis] :
4783 std::optional<int64_t> basisVal =
4785 if (basisVal && *basisVal == 1)
4786 replacements[index] = getZero();
4788 newBasis.push_back(basis);
4791 if (newBasis.size() == delinearizeOp.getNumResults())
4793 "no unit basis elements");
4795 if (!newBasis.empty()) {
4797 auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
4798 loc, delinearizeOp.getLinearIndex(), newBasis);
4801 for (auto &replacement : replacements) {
4804 replacement = newDelinearizeOp->getResult(newIndex++);
4808 rewriter.replaceOp(delinearizeOp, replacements);
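// Illustrative standalone sketch (not part of AffineOps.cpp): the decision
// DropUnitExtentBasis makes, over a plain vector stand-in for the basis. A
// result whose basis entry is the constant 1 can only ever be 0, so it is
// replaced by a zero constant and its basis entry is dropped; std::nullopt
// stands in for a dynamic (non-constant) entry. The basis below is made up.
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

int main() {
  std::vector<std::optional<int64_t>> basis = {4, 1, std::nullopt, 1};
  std::vector<std::optional<int64_t>> newBasis;
  std::vector<bool> replacedByZero(basis.size(), false);
  for (size_t i = 0; i < basis.size(); ++i) {
    if (basis[i] && *basis[i] == 1)
      replacedByZero[i] = true; // this delinearize result is always 0
    else
      newBasis.push_back(basis[i]);
  }
  assert(newBasis.size() == 2 && replacedByZero[1] && replacedByZero[3]);
}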
4823 struct CancelDelinearizeOfLinearizeDisjointExactTail
4827 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4829 auto linearizeOp = delinearizeOp.getLinearIndex()
4830 .getDefiningOp<affine::AffineLinearizeIndexOp>();
4833 "index doesn't come from linearize");
4835 if (!linearizeOp.getDisjoint())
4838 ValueRange linearizeIns = linearizeOp.getMultiIndex();
4842 size_t numMatches = 0;
4843 for (auto [linSize, delinSize] : llvm::zip(
4844 llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
4845 if (linSize != delinSize)
4850 if (numMatches == 0)
4852 delinearizeOp, "final basis element doesn't match linearize");
4855 if (numMatches == linearizeBasis.size() &&
4856 numMatches == delinearizeBasis.size() &&
4857 linearizeIns.size() == delinearizeOp.getNumResults()) {
4858 rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
4862 Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
4863 linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
4865 linearizeOp.getDisjoint());
4866 auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
4867 delinearizeOp.getLoc(), newLinearize,
4869 delinearizeOp.hasOuterBound());
4871 mergedResults.append(linearizeIns.take_back(numMatches).begin(),
4872 linearizeIns.take_back(numMatches).end());
4873 rewriter.replaceOp(delinearizeOp, mergedResults);
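// Illustrative standalone sketch (not part of AffineOps.cpp): the
// tail-matching step of the pattern above, over plain integer bases. When a
// disjoint linearize feeds a delinearize and the two bases share an exact
// trailing run, the corresponding trailing linearize operands can be
// forwarded directly as delinearize results. The bases below are made up.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

static size_t matchingTailLength(const std::vector<int64_t> &linBasis,
                                 const std::vector<int64_t> &delinBasis) {
  size_t matches = 0;
  auto l = linBasis.rbegin();
  auto d = delinBasis.rbegin();
  for (; l != linBasis.rend() && d != delinBasis.rend(); ++l, ++d) {
    if (*l != *d)
      break;
    ++matches;
  }
  return matches;
}

int main() {
  // linearize (..., %a, %b) by (7, 3, 5) into delinearize by (21, 3, 5):
  // the trailing (3, 5) matches, so %a and %b feed the last two results.
  assert(matchingTailLength({7, 3, 5}, {21, 3, 5}) == 2);
}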
4891 struct SplitDelinearizeSpanningLastLinearizeArg final
4895 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4897 auto linearizeOp = delinearizeOp.getLinearIndex()
4898 .getDefiningOp<affine::AffineLinearizeIndexOp>();
4901 "index doesn't come from linearize");
4903 if (!linearizeOp.getDisjoint())
4905 "linearize isn't disjoint");
4907 int64_t target = linearizeOp.getStaticBasis().back();
4908 if (ShapedType::isDynamic(target))
4910 linearizeOp, "linearize ends with dynamic basis value");
4912 int64_t sizeToSplit = 1;
4913 size_t elemsToSplit = 0;
4915 for (int64_t basisElem : llvm::reverse(basis)) {
4916 if (ShapedType::isDynamic(basisElem))
4918 delinearizeOp, "dynamic basis element while scanning for split");
4919 sizeToSplit *= basisElem;
4922 if (sizeToSplit > target)
4924 "overshot last argument size");
4925 if (sizeToSplit == target)
4929 if (sizeToSplit < target)
4931 delinearizeOp, "product of known basis elements doesn't exceed last "
4932 "linearize argument");
4934 if (elemsToSplit < 2)
4937 "need at least two elements to form the basis product");
4939 Value linearizeWithoutBack =
4940 rewriter.create<affine::AffineLinearizeIndexOp>(
4941 linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
4942 linearizeOp.getDynamicBasis(),
4943 linearizeOp.getStaticBasis().drop_back(),
4944 linearizeOp.getDisjoint());
4945 auto delinearizeWithoutSplitPart =
4946 rewriter.create<affine::AffineDelinearizeIndexOp>(
4947 delinearizeOp.getLoc(), linearizeWithoutBack,
4948 delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
4949 delinearizeOp.hasOuterBound());
4950 auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
4951 delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
4952 basis.take_back(elemsToSplit), true);
4954 llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
4955 delinearizeBack.getResults()));
4956 rewriter.replaceOp(delinearizeOp, results);
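// Illustrative standalone sketch (not part of AffineOps.cpp): the back-scan
// of the delinearize basis performed above. It multiplies trailing basis
// elements until the product reaches the last linearize bound; if it lands
// exactly on that bound, those trailing results can be computed by
// delinearizing the last linearize operand on its own. The numbers are made
// up, and the real pattern additionally requires at least two such elements.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

static size_t elementsSpanningTarget(const std::vector<int64_t> &basis,
                                     int64_t target) {
  int64_t product = 1;
  size_t elems = 0;
  for (auto it = basis.rbegin(); it != basis.rend(); ++it) {
    product *= *it;
    ++elems;
    if (product == target)
      return elems;
    if (product > target)
      return 0; // overshot: the split is not possible
  }
  return 0; // never reached the target
}

int main() {
  // delinearize basis (4, 2, 8, 4) with a last linearize bound of 32:
  // 8 * 4 == 32, so the last two basis elements span that operand.
  assert(elementsSpanningTarget({4, 2, 8, 4}, 32) == 2);
}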
4963 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
4966 .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
4967 DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
4975 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4979 if (!basis.empty() && basis.front() == Value())
4980 basis = basis.drop_front();
4985 build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4988 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4994 basis = basis.drop_front();
4998 build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
5001 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5005 build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
5009 size_t numIndexes = getMultiIndex().size();
5010 size_t numBasisElems = getStaticBasis().size();
5011 if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
5012 return emitOpError("should be passed a basis element for each index except "
5013 "possibly the first");
5015 auto dynamicMarkersCount =
5016 llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
5017 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
5019 "mismatch between dynamic and static basis (kDynamic marker but no "
5020 "corresponding dynamic basis entry) -- this can only happen due to an "
5021 "incorrect fold/rewrite");
5026 OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
5027 std::optional<SmallVector<int64_t>> maybeStaticBasis =
5029 adaptor.getDynamicBasis());
5030 if (maybeStaticBasis) {
5031 setStaticBasis(*maybeStaticBasis);
5035 if (getMultiIndex().empty())
5039 if (getMultiIndex().size() == 1)
5040 return getMultiIndex().front();
5042 if (llvm::any_of(adaptor.getMultiIndex(),
5043 [](Attribute a) { return a == nullptr; }))
5046 if (!adaptor.getDynamicBasis().empty())
5051 for (auto [length, indexAttr] :
5052 llvm::zip_first(llvm::reverse(getStaticBasis()),
5053 llvm::reverse(adaptor.getMultiIndex()))) {
5054 result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
5055 stride = stride * length;
5058 if (!hasOuterBound())
5061 cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;
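// Illustrative standalone sketch (not part of AffineOps.cpp): the stride
// accumulation the fold above performs on constant operands, walking indices
// and basis elements from the back. The numbers are made up: indices
// (1, 1, 2) with inner basis (3, 5) and no outer bound linearize to 22.
#include <cassert>
#include <cstdint>
#include <vector>

static int64_t linearizeConstant(const std::vector<int64_t> &indices,
                                 const std::vector<int64_t> &innerBasis) {
  int64_t result = 0, stride = 1;
  auto idx = indices.rbegin();
  for (auto b = innerBasis.rbegin(); b != innerBasis.rend(); ++b, ++idx) {
    result += *idx * stride;
    stride *= *b;
  }
  if (idx != indices.rend())
    result += *idx * stride; // front index has no basis entry without an outer bound
  return result;
}

int main() {
  assert(linearizeConstant({1, 1, 2}, {3, 5}) == 22);
}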
5068 if (hasOuterBound()) {
5069 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5071 getDynamicBasis().drop_front(), builder);
5073 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
5077 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
5082 if (!hasOuterBound())
5098 struct DropLinearizeUnitComponentsIfDisjointOrZero final
5102 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5105 size_t numIndices = multiIndex.size();
5107 newIndices.reserve(numIndices);
5109 newBasis.reserve(numIndices);
5111 if (!op.hasOuterBound()) {
5112 newIndices.push_back(multiIndex.front());
5113 multiIndex = multiIndex.drop_front();
5117 for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
5119 if (!basisEntry || *basisEntry != 1) {
5120 newIndices.push_back(index);
5121 newBasis.push_back(basisElem);
5126 if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
5127 newIndices.push_back(index);
5128 newBasis.push_back(basisElem);
5132 if (newIndices.size() == numIndices)
5134 "no unit basis entries to replace");
5136 if (newIndices.size() == 0) {
5141 op, newIndices, newBasis, op.getDisjoint());
5148 int64_t nDynamic = 0;
5158 dynamicPart.push_back(cast<Value>(term));
5162 if (auto constant = dyn_cast<AffineConstantExpr>(result))
5164 return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
5194 struct CancelLinearizeOfDelinearizePortion final
5205 unsigned linStart = 0;
5206 unsigned delinStart = 0;
5207 unsigned length = 0;
5211 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
5218 ValueRange multiIndex = linearizeOp.getMultiIndex();
5219 unsigned numLinArgs = multiIndex.size();
5220 unsigned linArgIdx = 0;
5224 while (linArgIdx < numLinArgs) {
5225 auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
5231 auto delinearizeOp =
5232 dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
5233 if (!delinearizeOp) {
5250 unsigned delinArgIdx = asResult.getResultNumber();
5252 OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
5254 bool boundsMatch = firstDelinBound == firstLinBound;
5255 bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
5256 bool knownByDisjoint =
5257 linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
5258 if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
5264 unsigned numDelinOuts = delinearizeOp.getNumResults();
5265 for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
5267 if (multiIndex[linArgIdx + j] !=
5268 delinearizeOp.getResult(delinArgIdx + j))
5270 if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
5276 if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
5280 matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
5284 if (matches.empty())
5286 linearizeOp, "no run of delinearize outputs to deal with");
5294 newIndex.reserve(numLinArgs);
5296 newBasis.reserve(numLinArgs);
5297 unsigned prevMatchEnd = 0;
5298 for (Match m : matches) {
5299 unsigned gap = m.linStart - prevMatchEnd;
5300 llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
5301 llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
5303 prevMatchEnd = m.linStart + m.length;
5305 PatternRewriter::InsertionGuard g(rewriter);
5309 linBasisRef.slice(m.linStart, m.length);
5316 if (m.length == m.delinearize.getNumResults()) {
5317 newIndex.push_back(m.delinearize.getLinearIndex());
5318 newBasis.push_back(newSize);
5326 newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
5327 newDelinBasis.begin() + m.delinStart + m.length);
5328 newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
5329 auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
5330 m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
5336 Value combinedElem = newDelinearize.getResult(m.delinStart);
5337 auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
5338 m.delinearize.getLoc(), combinedElem, basisToMerge);
5343 llvm::append_range(newDelinResults,
5344 newDelinearize.getResults().take_front(m.delinStart));
5345 llvm::append_range(newDelinResults, residualDelinearize.getResults());
5348 newDelinearize.getResults().drop_front(m.delinStart + 1));
5350 delinearizeReplacements.push_back(newDelinResults);
5351 newIndex.push_back(combinedElem);
5352 newBasis.push_back(newSize);
5354 llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
5355 llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
5357 linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
5359 for (auto [m, newResults] :
5360 llvm::zip_equal(matches, delinearizeReplacements)) {
5361 if (newResults.empty())
5363 rewriter.replaceOp(m.delinearize, newResults);
5374 struct DropLinearizeLeadingZero final
5378 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5380 Value leadingIdx = op.getMultiIndex().front();
5384 if (op.getMultiIndex().size() == 1) {
5391 if (op.hasOuterBound())
5392 newMixedBasis = newMixedBasis.drop_front();
5395 op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
5401 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
5403 patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
5404 DropLinearizeUnitComponentsIfDisjointOrZero>(context);
5411 #define GET_OP_CLASSES
5412 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"