23 #include "llvm/ADT/ScopeExit.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 #include "llvm/ADT/SmallVectorExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/Debug.h"
34 #define DEBUG_TYPE "affine-ops"
36 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
43 if (
auto arg = llvm::dyn_cast<BlockArgument>(value))
44 return arg.getParentRegion() == region;
67 if (llvm::isa<BlockArgument>(value))
68 return legalityCheck(mapping.
lookup(value), dest);
75 bool isDimLikeOp = isa<ShapedDimOpInterface>(value.
getDefiningOp());
86 return llvm::all_of(values, [&](
Value v) {
93 template <
typename OpTy>
96 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
97 AffineWriteOpInterface>::value,
98 "only ops with affine read/write interface are supported");
105 dimOperands, src, dest, mapping,
109 symbolOperands, src, dest, mapping,
126 op.getMapOperands(), src, dest, mapping,
131 op.getMapOperands(), src, dest, mapping,
158 if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
163 if (!llvm::hasSingleElement(*src))
171 if (
auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
172 if (iface.hasNoEffect())
180 .Case<AffineApplyOp, AffineReadOpInterface,
181 AffineWriteOpInterface>([&](
auto op) {
206 isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
/// Unconditionally answers `true`; the `op` argument is not inspected.
/// Declared `final`, so no further override is possible.
// NOTE(review): presumably an inliner-interface hook (the enclosing class is
// not visible in this excerpt) - confirm against the base interface.
210 bool shouldAnalyzeRecursively(
Operation *op)
const final {
return true; }
218 void AffineDialect::initialize() {
221 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
223 addInterfaces<AffineInlinerInterface>();
224 declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
233 if (
auto poison = dyn_cast<ub::PoisonAttr>(value))
234 return builder.
create<ub::PoisonOp>(loc, type, poison);
235 return arith::ConstantOp::materialize(builder, value, type, loc);
243 if (
auto arg = llvm::dyn_cast<BlockArgument>(value)) {
259 while (
auto *parentOp = curOp->getParentOp()) {
282 auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
284 isa<AffineForOp, AffineParallelOp>(parentOp));
305 auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->
getParentOp();
306 return isa<AffineForOp, AffineParallelOp>(parentOp);
310 if (
auto applyOp = dyn_cast<AffineApplyOp>(op))
311 return applyOp.isValidDim(region);
314 if (
auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
322 template <
typename AnyMemRefDefOp>
325 MemRefType memRefType = memrefDefOp.getType();
328 if (index >= memRefType.getRank()) {
333 if (!memRefType.isDynamicDim(index))
336 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
337 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
349 if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
357 if (!index.has_value())
361 Operation *op = dimOp.getShapedValue().getDefiningOp();
362 while (
auto castOp = dyn_cast<memref::CastOp>(op)) {
364 if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
366 op = castOp.getSource().getDefiningOp();
371 int64_t i = index.value();
373 .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
375 .Default([](
Operation *) {
return false; });
441 if (
auto applyOp = dyn_cast<AffineApplyOp>(defOp))
442 return applyOp.isValidSymbol(region);
445 if (
auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
469 printer <<
'(' << operands.take_front(numDims) <<
')';
470 if (operands.size() > numDims)
471 printer <<
'[' << operands.drop_front(numDims) <<
']';
481 numDims = opInfos.size();
495 template <
typename OpTy>
500 for (
auto operand : operands) {
501 if (opIt++ < numDims) {
503 return op.
emitOpError(
"operand cannot be used as a dimension id");
505 return op.
emitOpError(
"operand cannot be used as a symbol");
516 return AffineValueMap(getAffineMap(), getOperands(), getResult());
523 AffineMapAttr mapAttr;
529 auto map = mapAttr.getValue();
531 if (map.getNumDims() != numDims ||
532 numDims + map.getNumSymbols() != result.
operands.size()) {
534 "dimension or symbol index mismatch");
537 result.
types.append(map.getNumResults(), indexTy);
542 p <<
" " << getMapAttr();
544 getAffineMap().getNumDims(), p);
555 "operand count and affine map dimension and symbol count must match");
559 return emitOpError(
"mapping must produce one value");
567 return llvm::all_of(getOperands(),
575 return llvm::all_of(getOperands(),
582 return llvm::all_of(getOperands(),
589 return llvm::all_of(getOperands(), [&](
Value operand) {
595 auto map = getAffineMap();
598 auto expr = map.getResult(0);
599 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
600 return getOperand(dim.getPosition());
601 if (
auto sym = dyn_cast<AffineSymbolExpr>(expr))
602 return getOperand(map.getNumDims() + sym.getPosition());
606 bool hasPoison =
false;
608 map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
628 auto dimExpr = dyn_cast<AffineDimExpr>(e);
638 Value operand = operands[dimExpr.getPosition()];
639 int64_t operandDivisor = 1;
643 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
644 operandDivisor = forOp.getStepAsInt();
646 uint64_t lbLargestKnownDivisor =
647 forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
648 operandDivisor =
std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
651 return operandDivisor;
658 if (
auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
659 int64_t constVal = constExpr.getValue();
660 return constVal >= 0 && constVal < k;
662 auto dimExpr = dyn_cast<AffineDimExpr>(e);
665 Value operand = operands[dimExpr.getPosition()];
669 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
670 forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
686 auto bin = dyn_cast<AffineBinaryOpExpr>(e);
694 quotientTimesDiv = llhs;
700 quotientTimesDiv = rlhs;
710 if (forOp && forOp.hasConstantLowerBound())
711 return forOp.getConstantLowerBound();
718 if (!forOp || !forOp.hasConstantUpperBound())
723 if (forOp.hasConstantLowerBound()) {
724 return forOp.getConstantUpperBound() - 1 -
725 (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
726 forOp.getStepAsInt();
728 return forOp.getConstantUpperBound() - 1;
739 constLowerBounds.reserve(operands.size());
740 constUpperBounds.reserve(operands.size());
741 for (
Value operand : operands) {
746 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr))
747 return constExpr.getValue();
762 constLowerBounds.reserve(operands.size());
763 constUpperBounds.reserve(operands.size());
764 for (
Value operand : operands) {
769 std::optional<int64_t> lowerBound;
770 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
771 lowerBound = constExpr.getValue();
774 constLowerBounds, constUpperBounds,
785 auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
796 binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
804 lhs = binExpr.getLHS();
805 rhs = binExpr.getRHS();
806 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
810 int64_t rhsConstVal = rhsConst.getValue();
812 if (rhsConstVal <= 0)
817 std::optional<int64_t> lhsLbConst =
819 std::optional<int64_t> lhsUbConst =
821 if (lhsLbConst && lhsUbConst) {
822 int64_t lhsLbConstVal = *lhsLbConst;
823 int64_t lhsUbConstVal = *lhsUbConst;
827 floorDiv(lhsLbConstVal, rhsConstVal) ==
828 floorDiv(lhsUbConstVal, rhsConstVal)) {
836 ceilDiv(lhsLbConstVal, rhsConstVal) ==
837 ceilDiv(lhsUbConstVal, rhsConstVal)) {
844 lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
856 if (
isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
857 if (rhsConstVal % divisor == 0 &&
859 expr = quotientTimesDiv.
floorDiv(rhsConst);
860 }
else if (divisor % rhsConstVal == 0 &&
862 expr = rem % rhsConst;
888 if (operands.empty())
894 constLowerBounds.reserve(operands.size());
895 constUpperBounds.reserve(operands.size());
896 for (
Value operand : operands) {
910 if (
auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
911 lowerBounds.push_back(constExpr.getValue());
912 upperBounds.push_back(constExpr.getValue());
914 lowerBounds.push_back(
916 constLowerBounds, constUpperBounds,
918 upperBounds.push_back(
920 constLowerBounds, constUpperBounds,
929 unsigned i = exprEn.index();
931 if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
936 if (!upperBounds[i]) {
937 irredundantExprs.push_back(e);
943 auto otherLowerBound = en.value();
944 unsigned pos = en.index();
945 if (pos == i || !otherLowerBound)
947 if (*otherLowerBound > *upperBounds[i])
949 if (*otherLowerBound < *upperBounds[i])
954 if (upperBounds[pos] && lowerBounds[i] &&
955 lowerBounds[i] == upperBounds[i] &&
956 otherLowerBound == *upperBounds[pos] && i < pos)
960 irredundantExprs.push_back(e);
962 if (!lowerBounds[i]) {
963 irredundantExprs.push_back(e);
968 auto otherUpperBound = en.value();
969 unsigned pos = en.index();
970 if (pos == i || !otherUpperBound)
972 if (*otherUpperBound < *lowerBounds[i])
974 if (*otherUpperBound > *lowerBounds[i])
976 if (lowerBounds[pos] && upperBounds[i] &&
977 lowerBounds[i] == upperBounds[i] &&
978 otherUpperBound == lowerBounds[pos] && i < pos)
982 irredundantExprs.push_back(e);
994 static void LLVM_ATTRIBUTE_UNUSED
996 assert(map.
getNumInputs() == operands.size() &&
"invalid operands for map");
1002 newResults.push_back(expr);
1019 unsigned dimOrSymbolPosition,
1023 bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1024 unsigned pos = isDimReplacement ? dimOrSymbolPosition
1025 : dimOrSymbolPosition - dims.size();
1026 Value &v = isDimReplacement ? dims[pos] : syms[pos];
1039 AffineMap composeMap = affineApply.getAffineMap();
1040 assert(composeMap.
getNumResults() == 1 &&
"affine.apply with >1 results");
1042 affineApply.getMapOperands().end());
1056 dims.append(composeDims.begin(), composeDims.end());
1057 syms.append(composeSyms.begin(), composeSyms.end());
1058 *map = map->
replace(toReplace, replacementExpr, dims.size(), syms.size());
1086 bool changed =
false;
1087 for (
unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1099 unsigned nDims = 0, nSyms = 0;
1101 dimReplacements.reserve(dims.size());
1102 symReplacements.reserve(syms.size());
1103 for (
auto *container : {&dims, &syms}) {
1104 bool isDim = (container == &dims);
1105 auto &repls = isDim ? dimReplacements : symReplacements;
1107 Value v = en.value();
1111 "map is function of unexpected expr@pos");
1117 operands->push_back(v);
1130 while (llvm::any_of(*operands, [](
Value v) {
1144 return b.
create<AffineApplyOp>(loc, map, valueOperands);
1166 for (
unsigned i : llvm::seq<unsigned>(0, map.
getNumResults())) {
1173 llvm::append_range(dims,
1175 llvm::append_range(symbols,
1182 operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1191 assert(map.
getNumResults() == 1 &&
"building affine.apply with !=1 result");
1201 AffineApplyOp applyOp =
1206 for (
unsigned i = 0, e = constOperands.size(); i != e; ++i)
1211 if (
failed(applyOp->fold(constOperands, foldResults)) ||
1212 foldResults.empty()) {
1214 listener->notifyOperationInserted(applyOp, {});
1215 return applyOp.getResult();
1219 assert(foldResults.size() == 1 &&
"expected 1 folded result");
1220 return foldResults.front();
1238 return llvm::map_to_vector(llvm::seq<unsigned>(0, map.
getNumResults()),
1240 return makeComposedFoldedAffineApply(
1241 b, loc, map.getSubMap({i}), operands);
1245 template <
typename OpTy>
1257 return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
1260 template <
typename OpTy>
1272 auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1276 for (
unsigned i = 0, e = constOperands.size(); i != e; ++i)
1281 if (
failed(minMaxOp->fold(constOperands, foldResults)) ||
1282 foldResults.empty()) {
1284 listener->notifyOperationInserted(minMaxOp, {});
1285 return minMaxOp.getResult();
1289 assert(foldResults.size() == 1 &&
"expected 1 folded result");
1290 return foldResults.front();
1297 return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1304 return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
1309 template <
class MapOrSet>
1312 if (!mapOrSet || operands->empty())
1315 assert(mapOrSet->getNumInputs() == operands->size() &&
1316 "map/set inputs must match number of operands");
1318 auto *context = mapOrSet->getContext();
1320 resultOperands.reserve(operands->size());
1322 remappedSymbols.reserve(operands->size());
1323 unsigned nextDim = 0;
1324 unsigned nextSym = 0;
1325 unsigned oldNumSyms = mapOrSet->getNumSymbols();
1327 for (
unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1328 if (i < mapOrSet->getNumDims()) {
1332 remappedSymbols.push_back((*operands)[i]);
1335 resultOperands.push_back((*operands)[i]);
1338 resultOperands.push_back((*operands)[i]);
1342 resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1343 *operands = resultOperands;
1344 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
1345 oldNumSyms + nextSym);
1347 assert(mapOrSet->getNumInputs() == operands->size() &&
1348 "map/set inputs must match number of operands");
1352 template <
class MapOrSet>
1355 static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1356 "Argument must be either of AffineMap or IntegerSet type");
1358 if (!mapOrSet || operands->empty())
1361 assert(mapOrSet->getNumInputs() == operands->size() &&
1362 "map/set inputs must match number of operands");
1364 canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1367 llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1368 llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1370 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr))
1371 usedDims[dimExpr.getPosition()] =
true;
1372 else if (
auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
1373 usedSyms[symExpr.getPosition()] =
true;
1376 auto *context = mapOrSet->getContext();
1379 resultOperands.reserve(operands->size());
1381 llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1383 unsigned nextDim = 0;
1384 for (
unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1387 auto it = seenDims.find((*operands)[i]);
1388 if (it == seenDims.end()) {
1390 resultOperands.push_back((*operands)[i]);
1391 seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1393 dimRemapping[i] = it->second;
1397 llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1399 unsigned nextSym = 0;
1400 for (
unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1406 IntegerAttr operandCst;
1407 if (
matchPattern((*operands)[i + mapOrSet->getNumDims()],
1414 auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1415 if (it == seenSymbols.end()) {
1417 resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1418 seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1421 symRemapping[i] = it->second;
1424 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1426 *operands = resultOperands;
1431 canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1436 canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
1443 template <
typename AffineOpTy>
1455 llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1456 AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1457 AffineVectorStoreOp, AffineVectorLoadOp>::value,
1458 "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1460 auto map = affineOp.getAffineMap();
1462 auto oldOperands = affineOp.getMapOperands();
1467 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1468 resultOperands.begin()))
1471 replaceAffineOp(rewriter, affineOp, map, resultOperands);
1479 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1486 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1490 prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1491 prefetch.getLocalityHint(), prefetch.getIsDataCache());
1494 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1498 store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1501 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1505 vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1509 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1513 vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1518 template <
typename AffineOpTy>
1519 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1528 results.
add<SimplifyAffineOp<AffineApplyOp>>(context);
1559 p <<
" " << getSrcMemRef() <<
'[';
1561 p <<
"], " << getDstMemRef() <<
'[';
1563 p <<
"], " << getTagMemRef() <<
'[';
1567 p <<
", " << getStride();
1568 p <<
", " << getNumElementsPerStride();
1570 p <<
" : " << getSrcMemRefType() <<
", " << getDstMemRefType() <<
", "
1571 << getTagMemRefType();
1583 AffineMapAttr srcMapAttr;
1586 AffineMapAttr dstMapAttr;
1589 AffineMapAttr tagMapAttr;
1604 getSrcMapAttrStrName(),
1608 getDstMapAttrStrName(),
1612 getTagMapAttrStrName(),
1621 if (!strideInfo.empty() && strideInfo.size() != 2) {
1623 "expected two stride related operands");
1625 bool isStrided = strideInfo.size() == 2;
1630 if (types.size() != 3)
1648 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1649 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1650 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1652 "memref operand count not equal to map.numInputs");
1657 if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1658 return emitOpError(
"expected DMA source to be of memref type");
1659 if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1660 return emitOpError(
"expected DMA destination to be of memref type");
1661 if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1662 return emitOpError(
"expected DMA tag to be of memref type");
1664 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1665 getDstMap().getNumInputs() +
1666 getTagMap().getNumInputs();
1667 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1668 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1669 return emitOpError(
"incorrect number of operands");
1673 for (
auto idx : getSrcIndices()) {
1674 if (!idx.getType().isIndex())
1675 return emitOpError(
"src index to dma_start must have 'index' type");
1678 "src index must be a valid dimension or symbol identifier");
1680 for (
auto idx : getDstIndices()) {
1681 if (!idx.getType().isIndex())
1682 return emitOpError(
"dst index to dma_start must have 'index' type");
1685 "dst index must be a valid dimension or symbol identifier");
1687 for (
auto idx : getTagIndices()) {
1688 if (!idx.getType().isIndex())
1689 return emitOpError(
"tag index to dma_start must have 'index' type");
1692 "tag index must be a valid dimension or symbol identifier");
1703 void AffineDmaStartOp::getEffects(
1729 p <<
" " << getTagMemRef() <<
'[';
1734 p <<
" : " << getTagMemRef().getType();
1745 AffineMapAttr tagMapAttr;
1754 getTagMapAttrStrName(),
1763 if (!llvm::isa<MemRefType>(type))
1765 "expected tag to be of memref type");
1767 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1769 "tag memref operand count != to map.numInputs");
1774 if (!llvm::isa<MemRefType>(getOperand(0).getType()))
1775 return emitOpError(
"expected DMA tag to be of memref type");
1777 for (
auto idx : getTagIndices()) {
1778 if (!idx.getType().isIndex())
1779 return emitOpError(
"index to dma_wait must have 'index' type");
1782 "index must be a valid dimension or symbol identifier");
1793 void AffineDmaWaitOp::getEffects(
1809 ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1810 assert(((!lbMap && lbOperands.empty()) ||
1812 "lower bound operand count does not match the affine map");
1813 assert(((!ubMap && ubOperands.empty()) ||
1815 "upper bound operand count does not match the affine map");
1816 assert(step > 0 &&
"step has to be a positive integer constant");
1822 getOperandSegmentSizeAttr(),
1824 static_cast<int32_t>(ubOperands.size()),
1825 static_cast<int32_t>(iterArgs.size())}));
1827 for (
Value val : iterArgs)
1849 Value inductionVar =
1851 for (
Value val : iterArgs)
1852 bodyBlock->
addArgument(val.getType(), val.getLoc());
1857 if (iterArgs.empty() && !bodyBuilder) {
1858 ensureTerminator(*bodyRegion, builder, result.
location);
1859 }
else if (bodyBuilder) {
1862 bodyBuilder(builder, result.
location, inductionVar,
1868 int64_t ub, int64_t step,
ValueRange iterArgs,
1869 BodyBuilderFn bodyBuilder) {
1872 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1879 auto *body = getBody();
1880 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1881 return emitOpError(
"expected body to have a single index argument for the "
1882 "induction variable");
1886 if (getLowerBoundMap().getNumInputs() > 0)
1888 getLowerBoundMap().getNumDims())))
1891 if (getUpperBoundMap().getNumInputs() > 0)
1893 getUpperBoundMap().getNumDims())))
1896 unsigned opNumResults = getNumResults();
1897 if (opNumResults == 0)
1903 if (getNumIterOperands() != opNumResults)
1905 "mismatch between the number of loop-carried values and results");
1906 if (getNumRegionIterArgs() != opNumResults)
1908 "mismatch between the number of basic block args and results");
1918 bool failedToParsedMinMax =
1922 auto boundAttrStrName =
1923 isLower ? AffineForOp::getLowerBoundMapAttrName(result.
name)
1924 : AffineForOp::getUpperBoundMapAttrName(result.
name);
1931 if (!boundOpInfos.empty()) {
1933 if (boundOpInfos.size() > 1)
1935 "expected only one loop bound operand");
1960 if (
auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
1961 unsigned currentNumOperands = result.
operands.size();
1966 auto map = affineMapAttr.getValue();
1970 "dim operand count and affine map dim count must match");
1972 unsigned numDimAndSymbolOperands =
1973 result.
operands.size() - currentNumOperands;
1974 if (numDims + map.
getNumSymbols() != numDimAndSymbolOperands)
1977 "symbol operand count and affine map symbol count must match");
1983 return p.
emitError(attrLoc,
"lower loop bound affine map with "
1984 "multiple results requires 'max' prefix");
1986 return p.
emitError(attrLoc,
"upper loop bound affine map with multiple "
1987 "results requires 'min' prefix");
1993 if (
auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
2003 "expected valid affine map representation for loop bounds");
2015 int64_t numOperands = result.
operands.size();
2018 int64_t numLbOperands = result.
operands.size() - numOperands;
2021 numOperands = result.
operands.size();
2024 int64_t numUbOperands = result.
operands.size() - numOperands;
2029 getStepAttrName(result.
name),
2033 IntegerAttr stepAttr;
2035 getStepAttrName(result.
name).data(),
2039 if (stepAttr.getValue().isNegative())
2042 "expected step to be representable as a positive signed integer");
2050 regionArgs.push_back(inductionVariable);
2058 for (
auto argOperandType :
2059 llvm::zip(llvm::drop_begin(regionArgs), operands, result.
types)) {
2060 Type type = std::get<2>(argOperandType);
2061 std::get<0>(argOperandType).type = type;
2069 getOperandSegmentSizeAttr(),
2071 static_cast<int32_t>(numUbOperands),
2072 static_cast<int32_t>(operands.size())}));
2076 if (regionArgs.size() != result.
types.size() + 1)
2079 "mismatch between the number of loop-carried values and results");
2083 AffineForOp::ensureTerminator(*body, builder, result.
location);
2105 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2106 p << constExpr.getValue();
2114 if (dyn_cast<AffineSymbolExpr>(expr)) {
2130 unsigned AffineForOp::getNumIterOperands() {
2131 AffineMap lbMap = getLowerBoundMapAttr().getValue();
2132 AffineMap ubMap = getUpperBoundMapAttr().getValue();
/// Returns the mutable operand list of the terminator of this loop's body.
/// The terminator is accessed through `cast<AffineYieldOp>`, i.e. it is
/// expected to be an `AffineYieldOp`; no fallback path exists here, so the
/// returned optional always holds a value.
2137 std::optional<MutableArrayRef<OpOperand>>
2138 AffineForOp::getYieldedValuesMutable() {
2139 return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2151 if (getStepAsInt() != 1)
2152 p <<
" step " << getStepAsInt();
2154 bool printBlockTerminators =
false;
2155 if (getNumIterOperands() > 0) {
2157 auto regionArgs = getRegionIterArgs();
2158 auto operands = getInits();
2160 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](
auto it) {
2161 p << std::get<0>(it) <<
" = " << std::get<1>(it);
2163 p <<
") -> (" << getResultTypes() <<
")";
2164 printBlockTerminators =
true;
2169 printBlockTerminators);
2171 (*this)->getAttrs(),
2172 {getLowerBoundMapAttrName(getOperation()->getName()),
2173 getUpperBoundMapAttrName(getOperation()->getName()),
2174 getStepAttrName(getOperation()->getName()),
2175 getOperandSegmentSizeAttr()});
2180 auto foldLowerOrUpperBound = [&forOp](
bool lower) {
2184 auto boundOperands =
2185 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2186 for (
auto operand : boundOperands) {
2189 operandConstants.push_back(operandCst);
2193 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2195 "bound maps should have at least one result");
2201 assert(!foldedResults.empty() &&
"bounds should have at least one result");
2202 auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2203 for (
unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2204 auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2205 maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2206 : llvm::APIntOps::smin(maxOrMin, foldedResult);
2208 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2209 : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2214 bool folded =
false;
2215 if (!forOp.hasConstantLowerBound())
2216 folded |=
succeeded(foldLowerOrUpperBound(
true));
2219 if (!forOp.hasConstantUpperBound())
2220 folded |=
succeeded(foldLowerOrUpperBound(
false));
2229 auto lbMap = forOp.getLowerBoundMap();
2230 auto ubMap = forOp.getUpperBoundMap();
2231 auto prevLbMap = lbMap;
2232 auto prevUbMap = ubMap;
2245 if (lbMap == prevLbMap && ubMap == prevUbMap)
2248 if (lbMap != prevLbMap)
2249 forOp.setLowerBound(lbOperands, lbMap);
2250 if (ubMap != prevUbMap)
2251 forOp.setUpperBound(ubOperands, ubMap);
/// Computes the trip count of `forOp` when it can be derived from constants
/// alone.
///
/// Returns `std::nullopt` if the loop does not have constant bounds or if its
/// step is not strictly positive. Otherwise returns 0 when the upper bound
/// does not exceed the lower bound, and ceil((ub - lb) / step) in all other
/// cases.
2257 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2258 int64_t step = forOp.getStepAsInt();
// Bail out unless both bounds are constants and the step is positive.
2259 if (!forOp.hasConstantBounds() || step <= 0)
2260 return std::nullopt;
2261 int64_t lb = forOp.getConstantLowerBound();
2262 int64_t ub = forOp.getConstantUpperBound();
// Empty range yields zero iterations; otherwise round the quotient up by
// adding `step - 1` before the integer division. The arithmetic is done in
// int64_t and converted to the uint64_t result on return.
2263 return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2274 if (!llvm::hasSingleElement(*forOp.getBody()))
2276 if (forOp.getNumResults() == 0)
2278 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2279 if (tripCount && *tripCount == 0) {
2282 rewriter.
replaceOp(forOp, forOp.getInits());
2286 auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2287 auto iterArgs = forOp.getRegionIterArgs();
2288 bool hasValDefinedOutsideLoop =
false;
2289 bool iterArgsNotInOrder =
false;
2290 for (
unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2291 Value val = yieldOp.getOperand(i);
2292 auto *iterArgIt = llvm::find(iterArgs, val);
2293 if (iterArgIt == iterArgs.end()) {
2295 assert(forOp.isDefinedOutsideOfLoop(val) &&
2296 "must be defined outside of the loop");
2297 hasValDefinedOutsideLoop =
true;
2298 replacements.push_back(val);
2300 unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2302 iterArgsNotInOrder =
true;
2303 replacements.push_back(forOp.getInits()[pos]);
2308 if (!tripCount.has_value() &&
2309 (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2313 if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2315 rewriter.
replaceOp(forOp, replacements);
2323 results.
add<AffineForEmptyLoopFolder>(context);
2327 assert((point.
isParent() || point == getRegion()) &&
"invalid region point");
2334 void AffineForOp::getSuccessorRegions(
2336 assert((point.
isParent() || point == getRegion()) &&
"expected loop region");
2341 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*
this);
2342 if (point.
isParent() && tripCount.has_value()) {
2343 if (tripCount.value() > 0) {
2344 regions.push_back(
RegionSuccessor(&getRegion(), getRegionIterArgs()));
2347 if (tripCount.value() == 0) {
2355 if (!point.
isParent() && tripCount && *tripCount == 1) {
2362 regions.push_back(
RegionSuccessor(&getRegion(), getRegionIterArgs()));
2368 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
2369 return tripCount && *tripCount == 0;
2382 results.assign(getInits().begin(), getInits().end());
2398 assert(map.
getNumResults() >= 1 &&
"bound map has at least one result");
2399 getLowerBoundOperandsMutable().assign(lbOperands);
2400 setLowerBoundMap(map);
2405 assert(map.
getNumResults() >= 1 &&
"bound map has at least one result");
2406 getUpperBoundOperandsMutable().assign(ubOperands);
2407 setUpperBoundMap(map);
/// Returns true when the lower bound map reports `isSingleConstant()`.
2410 bool AffineForOp::hasConstantLowerBound() {
2411 return getLowerBoundMap().isSingleConstant();
/// Returns true when the upper bound map reports `isSingleConstant()`.
2414 bool AffineForOp::hasConstantUpperBound() {
2415 return getUpperBoundMap().isSingleConstant();
/// Returns the single constant result of the lower bound map.
// NOTE(review): presumably only meaningful when hasConstantLowerBound() is
// true; no check is performed here - confirm the callee's precondition.
2418 int64_t AffineForOp::getConstantLowerBound() {
2419 return getLowerBoundMap().getSingleConstantResult();
/// Returns the single constant result of the upper bound map.
// NOTE(review): presumably only meaningful when hasConstantUpperBound() is
// true; no check is performed here - confirm the callee's precondition.
2422 int64_t AffineForOp::getConstantUpperBound() {
2423 return getUpperBoundMap().getSingleConstantResult();
2426 void AffineForOp::setConstantLowerBound(int64_t value) {
2430 void AffineForOp::setConstantUpperBound(int64_t value) {
2434 AffineForOp::operand_range AffineForOp::getControlOperands() {
2439 bool AffineForOp::matchingBoundOperandList() {
2440 auto lbMap = getLowerBoundMap();
2441 auto ubMap = getUpperBoundMap();
2447 for (
unsigned i = 0, e = lbMap.
getNumInputs(); i < e; i++) {
2449 if (getOperand(i) != getOperand(numOperands + i))
/// Returns this loop's induction variable (the result of getInductionVar())
/// wrapped in an optional; the optional is always populated on this path.
2457 std::optional<Value> AffineForOp::getSingleInductionVar() {
2458 return getInductionVar();
/// Returns the loop's lower bound as an `OpFoldResult` holding an I64 integer
/// attribute, or `std::nullopt` when the lower bound is not a constant.
2461 std::optional<OpFoldResult> AffineForOp::getSingleLowerBound() {
2462 if (!hasConstantLowerBound())
2463 return std::nullopt;
// NOTE(review): `b` is declared on a line not shown in this excerpt
// (presumably a builder) - confirm.
2465 return OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()));
2468 std::optional<OpFoldResult> AffineForOp::getSingleStep() {
2470 return OpFoldResult(b.getI64IntegerAttr(getStepAsInt()));
/// Returns the loop's upper bound as an `OpFoldResult` holding an I64 integer
/// attribute, or `std::nullopt` when the upper bound is not a constant.
2473 std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
2474 if (!hasConstantUpperBound())
2475 return std::nullopt;
// NOTE(review): `b` is declared on a line not shown in this excerpt
// (presumably a builder) - confirm.
2477 return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
2482 bool replaceInitOperandUsesInLoop,
2487 auto inits = llvm::to_vector(getInits());
2488 inits.append(newInitOperands.begin(), newInitOperands.end());
2489 AffineForOp newLoop = rewriter.
create<AffineForOp>(
2494 auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2496 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2501 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2502 assert(newInitOperands.size() == newYieldedValues.size() &&
2503 "expected as many new yield values as new iter operands");
2505 yieldOp.getOperandsMutable().append(newYieldedValues);
2510 rewriter.
mergeBlocks(getBody(), newLoop.getBody(),
2511 newLoop.getBody()->getArguments().take_front(
2512 getBody()->getNumArguments()));
2514 if (replaceInitOperandUsesInLoop) {
2517 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
2528 newLoop->getResults().take_front(getNumResults()));
2529 return cast<LoopLikeOpInterface>(newLoop.getOperation());
2557 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2558 if (!ivArg || !ivArg.getOwner())
2559 return AffineForOp();
2560 auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
2561 if (
auto forOp = dyn_cast<AffineForOp>(containingInst))
2563 return forOp.getInductionVar() == val ? forOp : AffineForOp();
2564 return AffineForOp();
2568 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2569 if (!ivArg || !ivArg.getOwner())
2572 auto parallelOp = dyn_cast<AffineParallelOp>(containingOp);
2573 if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2582 ivs->reserve(forInsts.size());
2583 for (
auto forInst : forInsts)
2584 ivs->push_back(forInst.getInductionVar());
2589 ivs.reserve(affineOps.size());
2592 if (
auto forOp = dyn_cast<AffineForOp>(op))
2593 ivs.push_back(forOp.getInductionVar());
2594 else if (
auto parallelOp = dyn_cast<AffineParallelOp>(op))
2595 for (
size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2596 ivs.push_back(parallelOp.getBody()->getArgument(i));
2602 template <
typename BoundListTy,
typename LoopCreatorTy>
2607 LoopCreatorTy &&loopCreatorFn) {
2608 assert(lbs.size() == ubs.size() &&
"Mismatch in number of arguments");
2609 assert(lbs.size() == steps.size() &&
"Mismatch in number of arguments");
2621 ivs.reserve(lbs.size());
2622 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
2628 if (i == e - 1 && bodyBuilderFn) {
2630 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2632 nestedBuilder.
create<AffineYieldOp>(nestedLoc);
2637 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
2645 int64_t ub, int64_t step,
2646 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2647 return builder.
create<AffineForOp>(loc, lb, ub, step,
2648 std::nullopt, bodyBuilderFn);
2655 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2658 if (lbConst && ubConst)
2660 ubConst.value(), step, bodyBuilderFn);
2663 std::nullopt, bodyBuilderFn);
2693 if (ifOp.getElseRegion().empty() ||
2694 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2712 auto isTriviallyFalse = [](
IntegerSet iSet) {
2713 return iSet.isEmptyIntegerSet();
2717 return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2718 iSet.getConstraint(0) == 0);
2721 IntegerSet affineIfConditions = op.getIntegerSet();
2723 if (isTriviallyFalse(affineIfConditions)) {
2733 blockToMove = op.getElseBlock();
2734 }
else if (isTriviallyTrue(affineIfConditions)) {
2735 blockToMove = op.getThenBlock();
2753 rewriter.
eraseOp(blockToMoveTerminator);
2761 void AffineIfOp::getSuccessorRegions(
2770 if (getElseRegion().empty()) {
2771 regions.push_back(getResults());
2787 auto conditionAttr =
2788 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2790 return emitOpError(
"requires an integer set attribute named 'condition'");
2793 IntegerSet condition = conditionAttr.getValue();
2795 return emitOpError(
"operand count and condition integer set dimension and "
2796 "symbol count must match");
2808 IntegerSetAttr conditionAttr;
2811 AffineIfOp::getConditionAttrStrName(),
2817 auto set = conditionAttr.getValue();
2818 if (set.getNumDims() != numDims)
2821 "dim operand count and integer set dim count must match");
2822 if (numDims + set.getNumSymbols() != result.
operands.size())
2825 "symbol operand count and integer set symbol count must match");
2839 AffineIfOp::ensureTerminator(*thenRegion, parser.
getBuilder(),
2846 AffineIfOp::ensureTerminator(*elseRegion, parser.
getBuilder(),
2858 auto conditionAttr =
2859 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2860 p <<
" " << conditionAttr;
2862 conditionAttr.getValue().getNumDims(), p);
2869 auto &elseRegion = this->getElseRegion();
2870 if (!elseRegion.
empty()) {
2879 getConditionAttrStrName());
2884 ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
2888 void AffineIfOp::setIntegerSet(
IntegerSet newSet) {
2894 (*this)->setOperands(operands);
2899 bool withElseRegion) {
2900 assert(resultTypes.empty() || withElseRegion);
2909 if (resultTypes.empty())
2910 AffineIfOp::ensureTerminator(*thenRegion, builder, result.
location);
2913 if (withElseRegion) {
2915 if (resultTypes.empty())
2916 AffineIfOp::ensureTerminator(*elseRegion, builder, result.
location);
2922 AffineIfOp::build(builder, result, {}, set, args,
2937 if (llvm::none_of(operands,
2948 auto set = getIntegerSet();
2954 if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
2957 setConditional(set, operands);
2963 results.
add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
2972 assert(operands.size() == 1 + map.
getNumInputs() &&
"inconsistent operands");
2976 auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
2977 result.
types.push_back(memrefType.getElementType());
2982 assert(map.
getNumInputs() == mapOperands.size() &&
"inconsistent index info");
2985 auto memrefType = llvm::cast<MemRefType>(memref.
getType());
2987 result.
types.push_back(memrefType.getElementType());
2992 auto memrefType = llvm::cast<MemRefType>(memref.
getType());
2993 int64_t rank = memrefType.getRank();
2998 build(builder, result, memref, map, indices);
3007 AffineMapAttr mapAttr;
3012 AffineLoadOp::getMapAttrStrName(),
3022 p <<
" " << getMemRef() <<
'[';
3023 if (AffineMapAttr mapAttr =
3024 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3028 {getMapAttrStrName()});
3037 MemRefType memrefType,
unsigned numIndexOperands) {
3040 return op->
emitOpError(
"affine map num results must equal memref rank");
3042 return op->
emitOpError(
"expects as many subscripts as affine map inputs");
3045 for (
auto idx : mapOperands) {
3046 if (!idx.getType().isIndex())
3047 return op->
emitOpError(
"index to load must have 'index' type");
3050 "index must be a valid dimension or symbol identifier");
3058 if (getType() != memrefType.getElementType())
3059 return emitOpError(
"result type must match element type of memref");
3063 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3064 getMapOperands(), memrefType,
3065 getNumOperands() - 1)))
3073 results.
add<SimplifyAffineOp<AffineLoadOp>>(context);
3082 auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3089 auto global = dyn_cast_or_null<memref::GlobalOp>(
3096 llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3100 if (
auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
3101 return splatAttr.getSplatValue<
Attribute>();
3103 if (!getAffineMap().isConstant())
3105 auto indices = llvm::to_vector<4>(
3106 llvm::map_range(getAffineMap().getConstantResults(),
3107 [](int64_t v) -> uint64_t {
return v; }));
3108 return cstAttr.getValues<
Attribute>()[indices];
3118 assert(map.
getNumInputs() == mapOperands.size() &&
"inconsistent index info");
3129 auto memrefType = llvm::cast<MemRefType>(memref.
getType());
3130 int64_t rank = memrefType.getRank();
3135 build(builder, result, valueToStore, memref, map, indices);
3144 AffineMapAttr mapAttr;
3149 mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3160 p <<
" " << getValueToStore();
3161 p <<
", " << getMemRef() <<
'[';
3162 if (AffineMapAttr mapAttr =
3163 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3167 {getMapAttrStrName()});
3174 if (getValueToStore().getType() != memrefType.getElementType())
3176 "value to store must have the same type as memref element type");
3180 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3181 getMapOperands(), memrefType,
3182 getNumOperands() - 2)))
3190 results.
add<SimplifyAffineOp<AffineStoreOp>>(context);
3203 template <
typename T>
3207 op.getMap().getNumDims() + op.getMap().getNumSymbols())
3209 "operand count and affine map dimension and symbol count must match");
3213 template <
typename T>
3215 p <<
' ' << op->
getAttr(T::getMapAttrStrName());
3217 unsigned numDims = op.getMap().getNumDims();
3218 p <<
'(' << operands.take_front(numDims) <<
')';
3220 if (operands.size() != numDims)
3221 p <<
'[' << operands.drop_front(numDims) <<
']';
3223 {T::getMapAttrStrName()});
3226 template <
typename T>
3233 AffineMapAttr mapAttr;
3249 template <
typename T>
3251 static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3252 "expected affine min or max op");
3258 auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3260 if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3264 if (results.empty()) {
3266 if (foldedMap == op.getMap())
3273 auto resultIt = std::is_same<T, AffineMinOp>::value
3274 ? llvm::min_element(results)
3275 : llvm::max_element(results);
3276 if (resultIt == results.end())
3282 template <
typename T>
3288 AffineMap oldMap = affineOp.getAffineMap();
3294 if (!llvm::is_contained(newExprs, expr))
3295 newExprs.push_back(expr);
3325 template <
typename T>
3331 AffineMap oldMap = affineOp.getAffineMap();
3333 affineOp.getMapOperands().take_front(oldMap.
getNumDims());
3335 affineOp.getMapOperands().take_back(oldMap.
getNumSymbols());
3337 auto newDimOperands = llvm::to_vector<8>(dimOperands);
3338 auto newSymOperands = llvm::to_vector<8>(symOperands);
3346 if (
auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
3347 Value symValue = symOperands[symExpr.getPosition()];
3349 producerOps.push_back(producerOp);
3352 }
else if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
3353 Value dimValue = dimOperands[dimExpr.getPosition()];
3355 producerOps.push_back(producerOp);
3362 newExprs.push_back(expr);
3365 if (producerOps.empty())
3372 for (T producerOp : producerOps) {
3373 AffineMap producerMap = producerOp.getAffineMap();
3374 unsigned numProducerDims = producerMap.
getNumDims();
3379 producerOp.getMapOperands().take_front(numProducerDims);
3381 producerOp.getMapOperands().take_back(numProducerSyms);
3382 newDimOperands.append(dimValues.begin(), dimValues.end());
3383 newSymOperands.append(symValues.begin(), symValues.end());
3387 newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3388 .shiftSymbols(numProducerSyms, numUsedSyms));
3391 numUsedDims += numProducerDims;
3392 numUsedSyms += numProducerSyms;
3398 llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3417 if (!resultExpr.isPureAffine())
3422 if (
failed(flattenResult))
3435 if (llvm::is_sorted(flattenedExprs))
3440 llvm::to_vector(llvm::seq<unsigned>(0, map.
getNumResults()));
3441 llvm::sort(resultPermutation, [&](
unsigned lhs,
unsigned rhs) {
3442 return flattenedExprs[lhs] < flattenedExprs[rhs];
3445 for (
unsigned idx : resultPermutation)
3466 template <
typename T>
3472 AffineMap map = affineOp.getAffineMap();
3480 template <
typename T>
3486 if (affineOp.getMap().getNumResults() != 1)
3489 affineOp.getOperands());
3517 return parseAffineMinMaxOp<AffineMinOp>(parser, result);
3545 return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
3564 IntegerAttr hintInfo;
3566 StringRef readOrWrite, cacheType;
3568 AffineMapAttr mapAttr;
3572 AffinePrefetchOp::getMapAttrStrName(),
3578 AffinePrefetchOp::getLocalityHintAttrStrName(),
3588 if (!readOrWrite.equals(
"read") && !readOrWrite.equals(
"write"))
3590 "rw specifier has to be 'read' or 'write'");
3592 AffinePrefetchOp::getIsWriteAttrStrName(),
3595 if (!cacheType.equals(
"data") && !cacheType.equals(
"instr"))
3597 "cache type has to be 'data' or 'instr'");
3600 AffinePrefetchOp::getIsDataCacheAttrStrName(),
3607 p <<
" " << getMemref() <<
'[';
3608 AffineMapAttr mapAttr =
3609 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3612 p <<
']' <<
", " << (getIsWrite() ?
"write" :
"read") <<
", "
3613 <<
"locality<" << getLocalityHint() <<
">, "
3614 << (getIsDataCache() ?
"data" :
"instr");
3616 (*this)->getAttrs(),
3617 {getMapAttrStrName(), getLocalityHintAttrStrName(),
3618 getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3623 auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3627 return emitOpError(
"affine.prefetch affine map num results must equal"
3630 return emitOpError(
"too few operands");
3632 if (getNumOperands() != 1)
3633 return emitOpError(
"too few operands");
3637 for (
auto idx : getMapOperands()) {
3640 "index must be a valid dimension or symbol identifier");
3648 results.
add<SimplifyAffineOp<AffinePrefetchOp>>(context);
3666 auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3670 build(builder, result, resultTypes, reductions, lbs, {}, ubs,
3680 assert(llvm::all_of(lbMaps,
3682 return m.
getNumDims() == lbMaps[0].getNumDims() &&
3685 "expected all lower bounds maps to have the same number of dimensions "
3687 assert(llvm::all_of(ubMaps,
3689 return m.
getNumDims() == ubMaps[0].getNumDims() &&
3692 "expected all upper bounds maps to have the same number of dimensions "
3694 assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3695 "expected lower bound maps to have as many inputs as lower bound "
3697 assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3698 "expected upper bound maps to have as many inputs as upper bound "
3706 for (arith::AtomicRMWKind reduction : reductions)
3707 reductionAttrs.push_back(
3719 groups.reserve(groups.size() + maps.size());
3720 exprs.reserve(maps.size());
3725 return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3731 AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3732 AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
3750 for (
unsigned i = 0, e = steps.size(); i < e; ++i)
3752 if (resultTypes.empty())
3753 ensureTerminator(*bodyRegion, builder, result.
location);
3757 return {&getRegion()};
3760 unsigned AffineParallelOp::getNumDims() {
return getSteps().size(); }
3762 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
3763 return getOperands().take_front(getLowerBoundsMap().getNumInputs());
3766 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
3767 return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
3770 AffineMap AffineParallelOp::getLowerBoundMap(
unsigned pos) {
3771 auto values = getLowerBoundsGroups().getValues<int32_t>();
3773 for (
unsigned i = 0; i < pos; ++i)
3775 return getLowerBoundsMap().getSliceMap(start, values[pos]);
3778 AffineMap AffineParallelOp::getUpperBoundMap(
unsigned pos) {
3779 auto values = getUpperBoundsGroups().getValues<int32_t>();
3781 for (
unsigned i = 0; i < pos; ++i)
3783 return getUpperBoundsMap().getSliceMap(start, values[pos]);
3787 return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
3791 return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
3794 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3795 if (hasMinMaxBounds())
3796 return std::nullopt;
3801 AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
3804 for (
unsigned i = 0, e = rangesValueMap.
getNumResults(); i < e; ++i) {
3805 auto expr = rangesValueMap.
getResult(i);
3806 auto cst = dyn_cast<AffineConstantExpr>(expr);
3808 return std::nullopt;
3809 out.push_back(cst.getValue());
3814 Block *AffineParallelOp::getBody() {
return &getRegion().
front(); }
3816 OpBuilder AffineParallelOp::getBodyBuilder() {
3817 return OpBuilder(getBody(), std::prev(getBody()->end()));
3822 "operands to map must match number of inputs");
3824 auto ubOperands = getUpperBoundsOperands();
3827 newOperands.append(ubOperands.begin(), ubOperands.end());
3828 (*this)->setOperands(newOperands);
3835 "operands to map must match number of inputs");
3838 newOperands.append(ubOperands.begin(), ubOperands.end());
3839 (*this)->setOperands(newOperands);
3845 setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
3850 arith::AtomicRMWKind op) {
3852 case arith::AtomicRMWKind::addf:
3853 return isa<FloatType>(resultType);
3854 case arith::AtomicRMWKind::addi:
3855 return isa<IntegerType>(resultType);
3856 case arith::AtomicRMWKind::assign:
3858 case arith::AtomicRMWKind::mulf:
3859 return isa<FloatType>(resultType);
3860 case arith::AtomicRMWKind::muli:
3861 return isa<IntegerType>(resultType);
3862 case arith::AtomicRMWKind::maximumf:
3863 return isa<FloatType>(resultType);
3864 case arith::AtomicRMWKind::minimumf:
3865 return isa<FloatType>(resultType);
3866 case arith::AtomicRMWKind::maxs: {
3867 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3868 return intType && intType.isSigned();
3870 case arith::AtomicRMWKind::mins: {
3871 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3872 return intType && intType.isSigned();
3874 case arith::AtomicRMWKind::maxu: {
3875 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3876 return intType && intType.isUnsigned();
3878 case arith::AtomicRMWKind::minu: {
3879 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3880 return intType && intType.isUnsigned();
3882 case arith::AtomicRMWKind::ori:
3883 return isa<IntegerType>(resultType);
3884 case arith::AtomicRMWKind::andi:
3885 return isa<IntegerType>(resultType);
3892 auto numDims = getNumDims();
3895 getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
3896 return emitOpError() <<
"the number of region arguments ("
3897 << getBody()->getNumArguments()
3898 <<
") and the number of map groups for lower ("
3899 << getLowerBoundsGroups().getNumElements()
3900 <<
") and upper bound ("
3901 << getUpperBoundsGroups().getNumElements()
3902 <<
"), and the number of steps (" << getSteps().size()
3903 <<
") must all match";
3906 unsigned expectedNumLBResults = 0;
3907 for (APInt v : getLowerBoundsGroups())
3908 expectedNumLBResults += v.getZExtValue();
3909 if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
3910 return emitOpError() <<
"expected lower bounds map to have "
3911 << expectedNumLBResults <<
" results";
3912 unsigned expectedNumUBResults = 0;
3913 for (APInt v : getUpperBoundsGroups())
3914 expectedNumUBResults += v.getZExtValue();
3915 if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
3916 return emitOpError() <<
"expected upper bounds map to have "
3917 << expectedNumUBResults <<
" results";
3919 if (getReductions().size() != getNumResults())
3920 return emitOpError(
"a reduction must be specified for each output");
3926 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
3927 if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
3928 return emitOpError(
"invalid reduction attribute");
3929 auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
3931 return emitOpError(
"result type cannot match reduction attribute");
3937 getLowerBoundsMap().getNumDims())))
3941 getUpperBoundsMap().getNumDims())))
3948 auto newMap = getAffineMap();
3950 if (newMap == getAffineMap() && newOperands == operands)
3952 reset(newMap, newOperands);
3965 if (!lbCanonicalized && !ubCanonicalized)
3968 if (lbCanonicalized)
3970 if (ubCanonicalized)
3988 StringRef keyword) {
3991 ValueRange dimOperands = operands.take_front(numDims);
3992 ValueRange symOperands = operands.drop_front(numDims);
3994 for (llvm::APInt groupSize : group) {
3998 unsigned size = groupSize.getZExtValue();
4003 p << keyword <<
'(';
4013 p <<
" (" << getBody()->getArguments() <<
") = (";
4015 getLowerBoundsOperands(),
"max");
4018 getUpperBoundsOperands(),
"min");
4021 bool elideSteps = llvm::all_of(steps, [](int64_t step) {
return step == 1; });
4024 llvm::interleaveComma(steps, p);
4027 if (getNumResults()) {
4029 llvm::interleaveComma(getReductions(), p, [&](
auto &attr) {
4030 arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4031 llvm::cast<IntegerAttr>(attr).getInt());
4032 p <<
"\"" << arith::stringifyAtomicRMWKind(sym) <<
"\"";
4034 p <<
") -> (" << getResultTypes() <<
")";
4041 (*this)->getAttrs(),
4042 {AffineParallelOp::getReductionsAttrStrName(),
4043 AffineParallelOp::getLowerBoundsMapAttrStrName(),
4044 AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4045 AffineParallelOp::getUpperBoundsMapAttrStrName(),
4046 AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4047 AffineParallelOp::getStepsAttrStrName()});
4060 "expected operands to be dim or symbol expression");
4063 for (
const auto &list : operands) {
4067 for (
Value operand : valueOperands) {
4068 unsigned pos = std::distance(uniqueOperands.begin(),
4069 llvm::find(uniqueOperands, operand));
4070 if (pos == uniqueOperands.size())
4071 uniqueOperands.push_back(operand);
4072 replacements.push_back(
4082 enum class MinMaxKind { Min, Max };
4106 const llvm::StringLiteral tmpAttrStrName =
"__pseudo_bound_map";
4108 StringRef mapName = kind == MinMaxKind::Min
4109 ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4110 : AffineParallelOp::getLowerBoundsMapAttrStrName();
4111 StringRef groupsName =
4112 kind == MinMaxKind::Min
4113 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4114 : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4131 auto parseOperands = [&]() {
4133 kind == MinMaxKind::Min ?
"min" :
"max"))) {
4134 mapOperands.clear();
4141 llvm::append_range(flatExprs, map.getValue().getResults());
4143 auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4146 auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4149 flatDimOperands.append(map.getValue().getNumResults(), dims);
4150 flatSymOperands.append(map.getValue().getNumResults(), syms);
4151 numMapsPerGroup.push_back(map.getValue().getNumResults());
4154 flatSymOperands.emplace_back(),
4155 flatExprs.emplace_back())))
4157 numMapsPerGroup.push_back(1);
4164 unsigned totalNumDims = 0;
4165 unsigned totalNumSyms = 0;
4166 for (
unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4167 unsigned numDims = flatDimOperands[i].size();
4168 unsigned numSyms = flatSymOperands[i].size();
4169 flatExprs[i] = flatExprs[i]
4170 .shiftDims(numDims, totalNumDims)
4171 .shiftSymbols(numSyms, totalNumSyms);
4172 totalNumDims += numDims;
4173 totalNumSyms += numSyms;
4185 result.
operands.append(dimOperands.begin(), dimOperands.end());
4186 result.
operands.append(symOperands.begin(), symOperands.end());
4189 auto flatMap =
AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4191 flatMap = flatMap.replaceDimsAndSymbols(
4192 dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4216 AffineMapAttr stepsMapAttr;
4221 result.
addAttribute(AffineParallelOp::getStepsAttrStrName(),
4225 AffineParallelOp::getStepsAttrStrName(),
4232 auto stepsMap = stepsMapAttr.getValue();
4233 for (
const auto &result : stepsMap.getResults()) {
4234 auto constExpr = dyn_cast<AffineConstantExpr>(result);
4237 "steps must be constant integers");
4238 steps.push_back(constExpr.getValue());
4240 result.
addAttribute(AffineParallelOp::getStepsAttrStrName(),
4260 std::optional<arith::AtomicRMWKind> reduction =
4261 arith::symbolizeAtomicRMWKind(attrVal.getValue());
4263 return parser.
emitError(loc,
"invalid reduction value: ") << attrVal;
4264 reductions.push_back(
4272 result.
addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4281 for (
auto &iv : ivs)
4282 iv.type = indexType;
4288 AffineParallelOp::ensureTerminator(*body, builder, result.
location);
4297 auto *parentOp = (*this)->getParentOp();
4298 auto results = parentOp->getResults();
4299 auto operands = getOperands();
4301 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4302 return emitOpError() <<
"only terminates affine.if/for/parallel regions";
4303 if (parentOp->getNumResults() != getNumOperands())
4304 return emitOpError() <<
"parent of yield must have same number of "
4305 "results as the yield operands";
4306 for (
auto it : llvm::zip(results, operands)) {
4307 if (std::get<0>(it).getType() != std::get<1>(it).getType())
4308 return emitOpError() <<
"types mismatch between yield op and its parent";
4321 assert(operands.size() == 1 + map.
getNumInputs() &&
"inconsistent operands");
4325 result.
types.push_back(resultType);
4329 VectorType resultType,
Value memref,
4331 assert(map.
getNumInputs() == mapOperands.size() &&
"inconsistent index info");
4335 result.
types.push_back(resultType);
4339 VectorType resultType,
Value memref,
4341 auto memrefType = llvm::cast<MemRefType>(memref.
getType());
4342 int64_t rank = memrefType.getRank();
4347 build(builder, result, resultType, memref, map, indices);
4350 void AffineVectorLoadOp::getCanonicalizationPatterns(
RewritePatternSet &results,
4352 results.
add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4360 MemRefType memrefType;
4361 VectorType resultType;
4363 AffineMapAttr mapAttr;
4368 AffineVectorLoadOp::getMapAttrStrName(),
4379 p <<
" " << getMemRef() <<
'[';
4380 if (AffineMapAttr mapAttr =
4381 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4385 {getMapAttrStrName()});
4391 VectorType vectorType) {
4393 if (memrefType.getElementType() != vectorType.getElementType())
4395 "requires memref and vector types of the same elemental type");
4403 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4404 getMapOperands(), memrefType,
4405 getNumOperands() - 1)))
4421 assert(map.
getNumInputs() == mapOperands.size() &&
"inconsistent index info");
4432 auto memrefType = llvm::cast<MemRefType>(memref.
getType());
4433 int64_t rank = memrefType.getRank();
4438 build(builder, result, valueToStore, memref, map, indices);
4440 void AffineVectorStoreOp::getCanonicalizationPatterns(
4442 results.
add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4449 MemRefType memrefType;
4450 VectorType resultType;
4453 AffineMapAttr mapAttr;
4459 AffineVectorStoreOp::getMapAttrStrName(),
4470 p <<
" " << getValueToStore();
4471 p <<
", " << getMemRef() <<
'[';
4472 if (AffineMapAttr mapAttr =
4473 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4477 {getMapAttrStrName()});
4478 p <<
" : " <<
getMemRefType() <<
", " << getValueToStore().getType();
4484 *
this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4485 getMapOperands(), memrefType,
4486 getNumOperands() - 2)))
4500 MLIRContext *context, std::optional<::mlir::Location> location,
4503 AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
4505 inferredReturnTypes.assign(adaptor.getBasis().size(),
4518 if (staticDim.has_value())
4521 return llvm::dyn_cast_if_present<Value>(ofr);
4527 if (getBasis().empty())
4528 return emitOpError(
"basis should not be empty");
4529 if (getNumResults() != getBasis().size())
4530 return emitOpError(
"should return an index for each basis element");
4538 #define GET_OP_CLASSES
4539 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds known to be constants.
static bool hasTrivialZeroTripCount(AffineForOp op)
Returns true if the affine.for has zero iterations in trivial cases.
static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands)
Composes the given affine map with the given list of operands, pulling in the maps from any affine.apply operations that supply the operands.
static void printAffineMinMaxOp(OpAsmPrinter &p, T op)
static bool isResultTypeMatchAtomicRMWKind(Type resultType, arith::AtomicRMWKind op)
static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)
Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)
Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.
static void LLVM_ATTRIBUTE_UNUSED simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)
Simplify the map while exploiting information on the values in operands.
static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)
Fold an affine min or max operation with the given operands.
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)
Canonicalize the bounds of the given loop.
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)
Simplify expr while exploiting information from the values in operands.
static bool isValidAffineIndexOperand(Value value, Region *region)
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)
Parse a for operation loop bounds.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands)
Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...
static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms)
Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression...
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType)
Verify common invariants of affine.vector_load and affine.vector_store.
static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)
Simplify the expressions in map while making use of lower or upper bounds of its operands.
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)
Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)
Check if e is known to be: 0 <= e < k.
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, OperationState &result, MinMaxKind kind)
Parses an affine map that can contain a min/max for groups of its results, e.g., max(expr-1,...
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds that may or may not be constants.
static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)
Prints dimension and symbol list.
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)
Returns the largest known divisor of e.
static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)
Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.
static LogicalResult foldLoopBounds(AffineForOp forOp)
Fold the constant bounds of a loop.
static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)
Utility function to verify that a set of operands are valid dimension and symbol identifiers.
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)
Returns true if the result of the dim op is a valid symbol for region.
static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr &quotientTimesDiv, AffineExpr &rem)
Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.
static ParseResult deduplicateAndResolveOperands(OpAsmParser &parser, ArrayRef< SmallVector< OpAsmParser::UnresolvedOperand >> operands, SmallVectorImpl< Value > &uniqueOperands, SmallVectorImpl< AffineExpr > &replacements, AffineExprKind kind)
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult verifyAffineMinMaxOp(T op)
static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)
static LogicalResult verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)
Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine.vector_store.
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)
Canonicalize the result expression order of an affine map and return success if the order changed.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static int64_t getNumElements(ShapedType type)
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
AffineExprKind getKind() const
Return the classification for this type.
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
MLIRContext * getContext() const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
MLIRContext * getContext() const
bool isFunctionOfDim(unsigned position) const
Return true if any affine expression involves AffineDimExpr position.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isFunctionOfSymbol(unsigned position) const
Return true if any affine expression involves AffineSymbolExpr position.
unsigned getNumResults() const
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
unsigned getNumInputs() const
AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
AffineExpr getResult(unsigned idx) const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
AffineMap getDimIdentityMap()
AffineMap getMultiDimIdentityMap(unsigned rank)
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
AffineMap getEmptyAffineMap()
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
MLIRContext * getContext() const
AffineMap getSymbolIdentityMap()
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense integer vector or tensor object.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This class provides support for representing a failure result, or a valid value of type T.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
An integer set representing a conjunction of one or more affine equalities and inequalities.
unsigned getNumDims() const
static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)
MLIRContext * getContext() const
unsigned getNumInputs() const
ArrayRef< AffineExpr > getConstraints() const
ArrayRef< bool > getEqFlags() const
Returns the equality bits, which specify whether each of the constraints is an equality or inequality...
unsigned getNumSymbols() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void pop_back()
Pop last element from list.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated operand references with a specified surround...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0
Parses an affine map attribute where dims and symbols are SSA operands.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0
Parses an affine expression where dims and symbols are SSA operands.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0
Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
operand_range::iterator operand_iterator
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
std::vector< SmallVector< int64_t, 8 > > operandExprStack
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
AffineBound represents a lower or upper bound in the for operation.
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
LogicalResult canonicalize()
Attempts to canonicalize the map and operands.
ArrayRef< Value > getOperands() const
AffineExpr getResult(unsigned i)
AffineMap getAffineMap() const
unsigned getNumResults() const
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)
Extracts the induction variables from a list of AffineForOps and places them in the output argument i...
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
bool isTopLevelValue(Value value)
A utility function to check if a value is defined at the top level of an op with trait AffineScope or...
void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)
Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.
void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)
Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns the AffineParallelOp whose induction variables include the provided value.
Region * getAffineScope(Operation *op)
Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list.
bool isAffineParallelInductionVar(Value val)
Returns true if val is the induction variable of an AffineParallelOp.
AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt gcd(const MPInt &a, const MPInt &b)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t >> constLowerBounds, ArrayRef< std::optional< int64_t >> constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
int64_t floorDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's floordiv operation on constants.
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Canonicalize the affine map result expression order of an affine min/max operation.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Remove duplicated expressions in affine min/max ops.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Merge an affine min/max op to its consumers if its consumer is also an affine min/max op.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.