#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"

using llvm::divideCeilSigned;
using llvm::divideFloorSigned;

#define DEBUG_TYPE "affine-ops"

#include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
if (auto arg = llvm::dyn_cast<BlockArgument>(value))
  return arg.getParentRegion() == region;

if (llvm::isa<BlockArgument>(value))
  return legalityCheck(mapping.lookup(value), dest);

bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());

return llvm::all_of(values, [&](Value v) {

template <typename OpTy>
static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
                              AffineWriteOpInterface>::value,
              "only ops with affine read/write interface are supported");

dimOperands, src, dest, mapping,
symbolOperands, src, dest, mapping,
op.getMapOperands(), src, dest, mapping,
op.getMapOperands(), src, dest, mapping,

if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
if (!llvm::hasSingleElement(*src))

if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
  if (iface.hasNoEffect())

.Case<AffineApplyOp, AffineReadOpInterface,
      AffineWriteOpInterface>([&](auto op) {

isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);

bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
void AffineDialect::initialize() {
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
  addInterfaces<AffineInlinerInterface>();
  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,

if (auto poison = dyn_cast<ub::PoisonAttr>(value))
  return builder.create<ub::PoisonOp>(loc, type, poison);
return arith::ConstantOp::materialize(builder, value, type, loc);

if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {

while (auto *parentOp = curOp->getParentOp()) {
  if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(parentOp))

auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();

if (auto applyOp = dyn_cast<AffineApplyOp>(op))
  return applyOp.isValidDim(region);
if (isa<AffineDelinearizeIndexOp, AffineLinearizeIndexOp>(op))
  return llvm::all_of(op->getOperands(),
                      [&](Value arg) { return ::isValidDim(arg, region); });
if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))

template <typename AnyMemRefDefOp>
MemRefType memRefType = memrefDefOp.getType();
if (index >= memRefType.getRank()) {
if (!memRefType.isDynamicDim(index))
unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),

if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
if (!index.has_value())
Operation *op = dimOp.getShapedValue().getDefiningOp();
while (auto castOp = dyn_cast<memref::CastOp>(op)) {
  if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
  op = castOp.getSource().getDefiningOp();
int64_t i = index.value();
.Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
.Default([](Operation *) { return false; });
if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
      return affine::isValidSymbol(operand, region);

if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))

printer << '(' << operands.take_front(numDims) << ')';
if (operands.size() > numDims)
  printer << '[' << operands.drop_front(numDims) << ']';

numDims = opInfos.size();

template <typename OpTy>
for (auto operand : operands) {
  if (opIt++ < numDims) {
    return op.emitOpError("operand cannot be used as a dimension id");
  return op.emitOpError("operand cannot be used as a symbol");

return AffineValueMap(getAffineMap(), getOperands(), getResult());

AffineMapAttr mapAttr;
auto map = mapAttr.getValue();
if (map.getNumDims() != numDims ||
    numDims + map.getNumSymbols() != result.operands.size()) {
  "dimension or symbol index mismatch");
result.types.append(map.getNumResults(), indexTy);

p << " " << getMapAttr();
getAffineMap().getNumDims(), p);

"operand count and affine map dimension and symbol count must match");
return emitOpError("mapping must produce one value");
for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
  return emitError("dimensional operand cannot be used as a symbol");
return llvm::all_of(getOperands(),
return llvm::all_of(getOperands(),
return llvm::all_of(getOperands(),
return llvm::all_of(getOperands(), [&](Value operand) {

auto map = getAffineMap();
auto expr = map.getResult(0);
if (auto dim = dyn_cast<AffineDimExpr>(expr))
  return getOperand(dim.getPosition());
if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
  return getOperand(map.getNumDims() + sym.getPosition());

bool hasPoison = false;
map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
if (failed(foldResult))

auto dimExpr = dyn_cast<AffineDimExpr>(e);
Value operand = operands[dimExpr.getPosition()];
int64_t operandDivisor = 1;
if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
  operandDivisor = forOp.getStepAsInt();
uint64_t lbLargestKnownDivisor =
    forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
return operandDivisor;

if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
  int64_t constVal = constExpr.getValue();
  return constVal >= 0 && constVal < k;
auto dimExpr = dyn_cast<AffineDimExpr>(e);
Value operand = operands[dimExpr.getPosition()];
if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
    forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {

auto bin = dyn_cast<AffineBinaryOpExpr>(e);
quotientTimesDiv = llhs;
quotientTimesDiv = rlhs;

if (forOp && forOp.hasConstantLowerBound())
  return forOp.getConstantLowerBound();

if (!forOp || !forOp.hasConstantUpperBound())
if (forOp.hasConstantLowerBound()) {
  return forOp.getConstantUpperBound() - 1 -
         (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
             forOp.getStepAsInt();
return forOp.getConstantUpperBound() - 1;
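// Illustrative arithmetic (added note, not from the original source): with
// constant bounds the largest value the IV actually reaches is
//   ub - 1 - ((ub - lb - 1) % step)
// e.g. lb = 0, ub = 10, step = 3 visits 0, 3, 6, 9, giving 10 - 1 - (9 % 3) = 9.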
constLowerBounds.reserve(operands.size());
constUpperBounds.reserve(operands.size());
for (Value operand : operands) {

if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
  return constExpr.getValue();

constLowerBounds.reserve(operands.size());
constUpperBounds.reserve(operands.size());
for (Value operand : operands) {

std::optional<int64_t> lowerBound;
if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
  lowerBound = constExpr.getValue();
constLowerBounds, constUpperBounds,

auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
lhs = binExpr.getLHS();
rhs = binExpr.getRHS();
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
int64_t rhsConstVal = rhsConst.getValue();
if (rhsConstVal <= 0)

std::optional<int64_t> lhsLbConst =
std::optional<int64_t> lhsUbConst =
if (lhsLbConst && lhsUbConst) {
  int64_t lhsLbConstVal = *lhsLbConst;
  int64_t lhsUbConstVal = *lhsUbConst;
  divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
      divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
    divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
  divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
      divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
  lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {

if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
  if (rhsConstVal % divisor == 0 &&
    expr = quotientTimesDiv.floorDiv(rhsConst);
  } else if (divisor % rhsConstVal == 0 &&
    expr = rem % rhsConst;
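// Worked example (added for illustration, not part of the original): if the
// bound analysis proves the dividend lies in [4, 5] and the constant divisor
// is 3, then floordiv(4, 3) == floordiv(5, 3) == 1, so the floordiv expression
// can be replaced by the constant 1; the analogous check applies to ceildiv,
// and a mod by the divisor folds away when the whole range already lies in
// [0, divisor).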
if (operands.empty())

constLowerBounds.reserve(operands.size());
constUpperBounds.reserve(operands.size());
for (Value operand : operands) {

if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
  lowerBounds.push_back(constExpr.getValue());
  upperBounds.push_back(constExpr.getValue());
lowerBounds.push_back(
    constLowerBounds, constUpperBounds,
upperBounds.push_back(
    constLowerBounds, constUpperBounds,

unsigned i = exprEn.index();
if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
if (!upperBounds[i]) {
  irredundantExprs.push_back(e);

auto otherLowerBound = en.value();
unsigned pos = en.index();
if (pos == i || !otherLowerBound)
if (*otherLowerBound > *upperBounds[i])
if (*otherLowerBound < *upperBounds[i])
if (upperBounds[pos] && lowerBounds[i] &&
    lowerBounds[i] == upperBounds[i] &&
    otherLowerBound == *upperBounds[pos] && i < pos)
irredundantExprs.push_back(e);

if (!lowerBounds[i]) {
  irredundantExprs.push_back(e);

auto otherUpperBound = en.value();
unsigned pos = en.index();
if (pos == i || !otherUpperBound)
if (*otherUpperBound < *lowerBounds[i])
if (*otherUpperBound > *lowerBounds[i])
if (lowerBounds[pos] && upperBounds[i] &&
    lowerBounds[i] == upperBounds[i] &&
    otherUpperBound == lowerBounds[pos] && i < pos)
irredundantExprs.push_back(e);

static void LLVM_ATTRIBUTE_UNUSED
assert(map.getNumInputs() == operands.size() && "invalid operands for map");
newResults.push_back(expr);
unsigned dimOrSymbolPosition,
bool isDimReplacement = (dimOrSymbolPosition < dims.size());
unsigned pos = isDimReplacement ? dimOrSymbolPosition
                                : dimOrSymbolPosition - dims.size();
Value &v = isDimReplacement ? dims[pos] : syms[pos];

AffineMap composeMap = affineApply.getAffineMap();
assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
affineApply.getMapOperands().end());
dims.append(composeDims.begin(), composeDims.end());
syms.append(composeSyms.begin(), composeSyms.end());
*map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());

for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)

unsigned nDims = 0, nSyms = 0;
dimReplacements.reserve(dims.size());
symReplacements.reserve(syms.size());
for (auto *container : {&dims, &syms}) {
  bool isDim = (container == &dims);
  auto &repls = isDim ? dimReplacements : symReplacements;
  Value v = en.value();
  "map is function of unexpected expr@pos");
  operands->push_back(v);

while (llvm::any_of(*operands, [](Value v) {

return b.create<AffineApplyOp>(loc, map, valueOperands);

for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
llvm::append_range(dims,
llvm::append_range(symbols,
operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));

assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");

AffineApplyOp applyOp =
for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
if (failed(applyOp->fold(constOperands, foldResults)) ||
    foldResults.empty()) {
  listener->notifyOperationInserted(applyOp, {});
  return applyOp.getResult();
return llvm::getSingleElement(foldResults);

return llvm::map_to_vector(llvm::seq<unsigned>(0, map.getNumResults()),
return makeComposedFoldedAffineApply(
    b, loc, map.getSubMap({i}), operands);
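// Usage sketch (illustrative, not from the original file): because operands
// are composed through producing affine.apply ops, the pair
//   %a = affine.apply affine_map<(d0) -> (d0 + 8)>(%i)
//   %b = affine.apply affine_map<(d0) -> (d0 * 2)>(%a)
// can be materialized directly as
//   %b = affine.apply affine_map<(d0) -> ((d0 + 8) * 2)>(%i)
// and the "folded" variant returns an Attribute instead of creating an op
// when every operand folds to a constant.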
template <typename OpTy>
return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);

template <typename OpTy>
auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
if (failed(minMaxOp->fold(constOperands, foldResults)) ||
    foldResults.empty()) {
  listener->notifyOperationInserted(minMaxOp, {});
  return minMaxOp.getResult();
return llvm::getSingleElement(foldResults);

return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);

template <class MapOrSet>
if (!mapOrSet || operands->empty())
assert(mapOrSet->getNumInputs() == operands->size() &&
       "map/set inputs must match number of operands");
auto *context = mapOrSet->getContext();
resultOperands.reserve(operands->size());
remappedSymbols.reserve(operands->size());
unsigned nextDim = 0;
unsigned nextSym = 0;
unsigned oldNumSyms = mapOrSet->getNumSymbols();
for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
  if (i < mapOrSet->getNumDims()) {
    remappedSymbols.push_back((*operands)[i]);
    resultOperands.push_back((*operands)[i]);
  resultOperands.push_back((*operands)[i]);
resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
*operands = resultOperands;
*mapOrSet = mapOrSet->replaceDimsAndSymbols(
    dimRemapping, {}, nextDim, oldNumSyms + nextSym);
assert(mapOrSet->getNumInputs() == operands->size() &&
       "map/set inputs must match number of operands");
template <class MapOrSet>
if (!mapOrSet || operands.empty())
unsigned numOperands = operands.size();
assert(mapOrSet.getNumInputs() == numOperands &&
       "map/set inputs must match number of operands");
auto *context = mapOrSet.getContext();
resultOperands.reserve(numOperands);
remappedDims.reserve(numOperands);
symOperands.reserve(mapOrSet.getNumSymbols());
unsigned nextSym = 0;
unsigned nextDim = 0;
unsigned oldNumDims = mapOrSet.getNumDims();
resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);
for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
  symRemapping[i - oldNumDims] =
  remappedDims.push_back(operands[i]);
  symOperands.push_back(operands[i]);
append_range(resultOperands, remappedDims);
append_range(resultOperands, symOperands);
operands = resultOperands;
mapOrSet = mapOrSet.replaceDimsAndSymbols(
    {}, symRemapping, oldNumDims + nextDim, nextSym);
assert(mapOrSet.getNumInputs() == operands.size() &&
       "map/set inputs must match number of operands");

template <class MapOrSet>
static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
              "Argument must be either of AffineMap or IntegerSet type");
if (!mapOrSet || operands->empty())
assert(mapOrSet->getNumInputs() == operands->size() &&
       "map/set inputs must match number of operands");
canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
legalizeDemotedDims<MapOrSet>(*mapOrSet, *operands);
llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
  usedDims[dimExpr.getPosition()] = true;
else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
  usedSyms[symExpr.getPosition()] = true;
auto *context = mapOrSet->getContext();
resultOperands.reserve(operands->size());
llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
unsigned nextDim = 0;
for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
  auto it = seenDims.find((*operands)[i]);
  if (it == seenDims.end()) {
    resultOperands.push_back((*operands)[i]);
    seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
  dimRemapping[i] = it->second;
llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
unsigned nextSym = 0;
for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
  IntegerAttr operandCst;
  if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
  auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
  if (it == seenSymbols.end()) {
    resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
    seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
  symRemapping[i] = it->second;
*mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
*operands = resultOperands;

canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
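// Example of the canonicalization performed here (added for illustration):
// a map (d0, d1) -> (d0 + d1) applied to the duplicate operands (%x, %x)
// becomes (d0) -> (d0 + d0) over the single operand (%x), and dims/symbols
// that no longer appear in any result are dropped together with their
// operands.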
template <typename AffineOpTy>
LogicalResult matchAndRewrite(AffineOpTy affineOp,
llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
                AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
                AffineVectorStoreOp, AffineVectorLoadOp>::value,
"affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
auto map = affineOp.getAffineMap();
auto oldOperands = affineOp.getMapOperands();
if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
                                resultOperands.begin()))
replaceAffineOp(rewriter, affineOp, map, resultOperands);

void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
    prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
    prefetch.getLocalityHint(), prefetch.getIsDataCache());
void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
    store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
    vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
    vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
template <typename AffineOpTy>
void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(

results.add<SimplifyAffineOp<AffineApplyOp>>(context);
p << " " << getSrcMemRef() << '[';
p << "], " << getDstMemRef() << '[';
p << "], " << getTagMemRef() << '[';
p << ", " << getNumElementsPerStride();
p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
  << getTagMemRefType();

AffineMapAttr srcMapAttr;
AffineMapAttr dstMapAttr;
AffineMapAttr tagMapAttr;
getSrcMapAttrStrName(),
getDstMapAttrStrName(),
getTagMapAttrStrName(),
if (!strideInfo.empty() && strideInfo.size() != 2) {
  "expected two stride related operands");
bool isStrided = strideInfo.size() == 2;
if (types.size() != 3)
if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
    dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
    tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
  "memref operand count not equal to map.numInputs");

LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
  if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
    return emitOpError("expected DMA source to be of memref type");
  if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
    return emitOpError("expected DMA destination to be of memref type");
  if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
    return emitOpError("expected DMA tag to be of memref type");
  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
                              getDstMap().getNumInputs() +
                              getTagMap().getNumInputs();
  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
      getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
    return emitOpError("incorrect number of operands");
  for (auto idx : getSrcIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("src index to dma_start must have 'index' type");
    "src index must be a valid dimension or symbol identifier");
  for (auto idx : getDstIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("dst index to dma_start must have 'index' type");
    "dst index must be a valid dimension or symbol identifier");
  for (auto idx : getTagIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("tag index to dma_start must have 'index' type");
    "tag index must be a valid dimension or symbol identifier");
void AffineDmaStartOp::getEffects(

p << " " << getTagMemRef() << '[';
p << " : " << getTagMemRef().getType();

AffineMapAttr tagMapAttr;
getTagMapAttrStrName(),
if (!llvm::isa<MemRefType>(type))
  "expected tag to be of memref type");
if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
  "tag memref operand count != to map.numInputs");

LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
  if (!llvm::isa<MemRefType>(getOperand(0).getType()))
    return emitOpError("expected DMA tag to be of memref type");
  for (auto idx : getTagIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("index to dma_wait must have 'index' type");
    "index must be a valid dimension or symbol identifier");

void AffineDmaWaitOp::getEffects(

ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
assert(((!lbMap && lbOperands.empty()) ||
       "lower bound operand count does not match the affine map");
assert(((!ubMap && ubOperands.empty()) ||
       "upper bound operand count does not match the affine map");
assert(step > 0 && "step has to be a positive integer constant");
getOperandSegmentSizeAttr(),
    static_cast<int32_t>(ubOperands.size()),
    static_cast<int32_t>(iterArgs.size())}));
for (Value val : iterArgs)
Value inductionVar =
for (Value val : iterArgs)
  bodyBlock->addArgument(val.getType(), val.getLoc());
if (iterArgs.empty() && !bodyBuilder) {
  ensureTerminator(*bodyRegion, builder, result.location);
} else if (bodyBuilder) {
  bodyBuilder(builder, result.location, inductionVar,

int64_t ub, int64_t step, ValueRange iterArgs,
BodyBuilderFn bodyBuilder) {
return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
LogicalResult AffineForOp::verifyRegions() {
  auto *body = getBody();
  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
    return emitOpError("expected body to have a single index argument for the "
                       "induction variable");
  if (getLowerBoundMap().getNumInputs() > 0)
      getLowerBoundMap().getNumDims())))
  if (getUpperBoundMap().getNumInputs() > 0)
      getUpperBoundMap().getNumDims())))
  if (getLowerBoundMap().getNumResults() < 1)
    return emitOpError("expected lower bound map to have at least one result");
  if (getUpperBoundMap().getNumResults() < 1)
    return emitOpError("expected upper bound map to have at least one result");
  unsigned opNumResults = getNumResults();
  if (opNumResults == 0)
  if (getNumIterOperands() != opNumResults)
    "mismatch between the number of loop-carried values and results");
  if (getNumRegionIterArgs() != opNumResults)
    "mismatch between the number of basic block args and results");

bool failedToParsedMinMax =
auto boundAttrStrName =
    isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
            : AffineForOp::getUpperBoundMapAttrName(result.name);
if (!boundOpInfos.empty()) {
  if (boundOpInfos.size() > 1)
    "expected only one loop bound operand");
if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
  unsigned currentNumOperands = result.operands.size();
  auto map = affineMapAttr.getValue();
  "dim operand count and affine map dim count must match");
  unsigned numDimAndSymbolOperands =
      result.operands.size() - currentNumOperands;
  if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
    "symbol operand count and affine map symbol count must match");
  return p.emitError(attrLoc, "lower loop bound affine map with "
                              "multiple results requires 'max' prefix");
  return p.emitError(attrLoc, "upper loop bound affine map with multiple "
                              "results requires 'min' prefix");
if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
"expected valid affine map representation for loop bounds");

int64_t numOperands = result.operands.size();
int64_t numLbOperands = result.operands.size() - numOperands;
numOperands = result.operands.size();
int64_t numUbOperands = result.operands.size() - numOperands;
getStepAttrName(result.name),
IntegerAttr stepAttr;
getStepAttrName(result.name).data(),
if (stepAttr.getValue().isNegative())
  "expected step to be representable as a positive signed integer");
regionArgs.push_back(inductionVariable);
for (auto argOperandType :
     llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
  Type type = std::get<2>(argOperandType);
  std::get<0>(argOperandType).type = type;
getOperandSegmentSizeAttr(),
    static_cast<int32_t>(numUbOperands),
    static_cast<int32_t>(operands.size())}));
if (regionArgs.size() != result.types.size() + 1)
  "mismatch between the number of loop-carried values and results");
AffineForOp::ensureTerminator(*body, builder, result.location);
if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
  p << constExpr.getValue();
if (isa<AffineSymbolExpr>(expr)) {

unsigned AffineForOp::getNumIterOperands() {
  AffineMap lbMap = getLowerBoundMapAttr().getValue();
  AffineMap ubMap = getUpperBoundMapAttr().getValue();

std::optional<MutableArrayRef<OpOperand>>
AffineForOp::getYieldedValuesMutable() {
  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();

if (getStepAsInt() != 1)
  p << " step " << getStepAsInt();
bool printBlockTerminators = false;
if (getNumIterOperands() > 0) {
  auto regionArgs = getRegionIterArgs();
  auto operands = getInits();
  llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
    p << std::get<0>(it) << " = " << std::get<1>(it);
  p << ") -> (" << getResultTypes() << ")";
  printBlockTerminators = true;
printBlockTerminators);
(*this)->getAttrs(),
{getLowerBoundMapAttrName(getOperation()->getName()),
 getUpperBoundMapAttrName(getOperation()->getName()),
 getStepAttrName(getOperation()->getName()),
 getOperandSegmentSizeAttr()});
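// The custom form printed here looks like (illustrative example, not an
// original comment):
//   affine.for %i = 0 to 10 step 2 iter_args(%acc = %init) -> (f32) { ... }
// with the bound maps, step and operand-segment-sizes attributes elided from
// the trailing attribute dictionary.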
auto foldLowerOrUpperBound = [&forOp](bool lower) {
auto boundOperands =
    lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
for (auto operand : boundOperands) {
  operandConstants.push_back(operandCst);
lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
"bound maps should have at least one result");
if (failed(boundMap.constantFold(operandConstants, foldedResults)))
assert(!foldedResults.empty() && "bounds should have at least one result");
auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
  auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
  maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
                   : llvm::APIntOps::smin(maxOrMin, foldedResult);
lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
      : forOp.setConstantUpperBound(maxOrMin.getSExtValue());

bool folded = false;
if (!forOp.hasConstantLowerBound())
  folded |= succeeded(foldLowerOrUpperBound(true));
if (!forOp.hasConstantUpperBound())
  folded |= succeeded(foldLowerOrUpperBound(false));
return success(folded);

auto lbMap = forOp.getLowerBoundMap();
auto ubMap = forOp.getUpperBoundMap();
auto prevLbMap = lbMap;
auto prevUbMap = ubMap;
if (lbMap == prevLbMap && ubMap == prevUbMap)
if (lbMap != prevLbMap)
  forOp.setLowerBound(lbOperands, lbMap);
if (ubMap != prevUbMap)
  forOp.setUpperBound(ubOperands, ubMap);

static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
  int64_t step = forOp.getStepAsInt();
  if (!forOp.hasConstantBounds() || step <= 0)
    return std::nullopt;
  int64_t lb = forOp.getConstantLowerBound();
  int64_t ub = forOp.getConstantUpperBound();
  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
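// Illustrative arithmetic (not in the original): for lb = 2, ub = 11,
// step = 3 the loop visits 2, 5, 8, and (ub - lb + step - 1) / step =
// (9 + 2) / 3 = 3 matches that trip count; ub <= lb yields 0.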
LogicalResult matchAndRewrite(AffineForOp forOp,
if (!llvm::hasSingleElement(*forOp.getBody()))
if (forOp.getNumResults() == 0)
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
if (tripCount == 0) {
  rewriter.replaceOp(forOp, forOp.getInits());
auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
auto iterArgs = forOp.getRegionIterArgs();
bool hasValDefinedOutsideLoop = false;
bool iterArgsNotInOrder = false;
for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
  Value val = yieldOp.getOperand(i);
  auto *iterArgIt = llvm::find(iterArgs, val);
  if (val == forOp.getInductionVar())
  if (iterArgIt == iterArgs.end()) {
    assert(forOp.isDefinedOutsideOfLoop(val) &&
           "must be defined outside of the loop");
    hasValDefinedOutsideLoop = true;
    replacements.push_back(val);
  unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
  iterArgsNotInOrder = true;
  replacements.push_back(forOp.getInits()[pos]);
if (!tripCount.has_value() &&
    (hasValDefinedOutsideLoop || iterArgsNotInOrder))
if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
rewriter.replaceOp(forOp, replacements);

results.add<AffineForEmptyLoopFolder>(context);

assert((point.isParent() || point == getRegion()) && "invalid region point");

void AffineForOp::getSuccessorRegions(
assert((point.isParent() || point == getRegion()) && "expected loop region");
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
if (point.isParent() && tripCount.has_value()) {
  if (tripCount.value() > 0) {
    regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
  if (tripCount.value() == 0) {
if (!point.isParent() && tripCount == 1) {
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));

return getTrivialConstantTripCount(op) == 0;

LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
results.assign(getInits().begin(), getInits().end());
return success(folded);
assert(map.getNumResults() >= 1 && "bound map has at least one result");
getLowerBoundOperandsMutable().assign(lbOperands);
setLowerBoundMap(map);

assert(map.getNumResults() >= 1 && "bound map has at least one result");
getUpperBoundOperandsMutable().assign(ubOperands);
setUpperBoundMap(map);

bool AffineForOp::hasConstantLowerBound() {
  return getLowerBoundMap().isSingleConstant();

bool AffineForOp::hasConstantUpperBound() {
  return getUpperBoundMap().isSingleConstant();

int64_t AffineForOp::getConstantLowerBound() {
  return getLowerBoundMap().getSingleConstantResult();

int64_t AffineForOp::getConstantUpperBound() {
  return getUpperBoundMap().getSingleConstantResult();

void AffineForOp::setConstantLowerBound(int64_t value) {
void AffineForOp::setConstantUpperBound(int64_t value) {

AffineForOp::operand_range AffineForOp::getControlOperands() {

bool AffineForOp::matchingBoundOperandList() {
  auto lbMap = getLowerBoundMap();
  auto ubMap = getUpperBoundMap();
  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
    if (getOperand(i) != getOperand(numOperands + i))

std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {

std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
  if (!hasConstantLowerBound())
    return std::nullopt;
  OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};

std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {

std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
  if (!hasConstantUpperBound())
  OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};

FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
    bool replaceInitOperandUsesInLoop,
  auto inits = llvm::to_vector(getInits());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  AffineForOp newLoop = rewriter.create<AffineForOp>(
  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
  assert(newInitOperands.size() == newYieldedValues.size() &&
         "expected as many new yield values as new iter operands");
  yieldOp.getOperandsMutable().append(newYieldedValues);
  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
                       newLoop.getBody()->getArguments().take_front(
                           getBody()->getNumArguments()));
  if (replaceInitOperandUsesInLoop) {
    for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
  newLoop->getResults().take_front(getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
auto ivArg = llvm::dyn_cast<BlockArgument>(val);
if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
  return AffineForOp();
ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
  return forOp.getInductionVar() == val ? forOp : AffineForOp();
return AffineForOp();

auto ivArg = llvm::dyn_cast<BlockArgument>(val);
if (!ivArg || !ivArg.getOwner())
auto parallelOp = dyn_cast_if_present<AffineParallelOp>(containingOp);
if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))

ivs->reserve(forInsts.size());
for (auto forInst : forInsts)
  ivs->push_back(forInst.getInductionVar());

ivs.reserve(affineOps.size());
if (auto forOp = dyn_cast<AffineForOp>(op))
  ivs.push_back(forOp.getInductionVar());
else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
  for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
    ivs.push_back(parallelOp.getBody()->getArgument(i));

template <typename BoundListTy, typename LoopCreatorTy>
LoopCreatorTy &&loopCreatorFn) {
assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
ivs.reserve(lbs.size());
for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
  if (i == e - 1 && bodyBuilderFn) {
    bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
  nestedBuilder.create<AffineYieldOp>(nestedLoc);
  auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);

int64_t ub, int64_t step,
AffineForOp::BodyBuilderFn bodyBuilderFn) {
return builder.create<AffineForOp>(loc, lb, ub, step,
                                   std::nullopt, bodyBuilderFn);

AffineForOp::BodyBuilderFn bodyBuilderFn) {
if (lbConst && ubConst)
  ubConst.value(), step, bodyBuilderFn);
std::nullopt, bodyBuilderFn);
LogicalResult matchAndRewrite(AffineIfOp ifOp,
if (ifOp.getElseRegion().empty() ||
    !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())

LogicalResult matchAndRewrite(AffineIfOp op,
auto isTriviallyFalse = [](IntegerSet iSet) {
  return iSet.isEmptyIntegerSet();
return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
        iSet.getConstraint(0) == 0);
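// Examples (added for illustration): the "trivially true" form matched above
// is a single equality constraint 0 == 0 with no inequalities, e.g.
// affine_set<() : (0 == 0)>, while the canonical empty set recognized by
// isEmptyIntegerSet(), e.g. affine_set<(d0) : (1 == 0)>, is trivially false.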
IntegerSet affineIfConditions = op.getIntegerSet();
if (isTriviallyFalse(affineIfConditions)) {
  if (op.getNumResults() == 0 && !op.hasElse()) {
  blockToMove = op.getElseBlock();
} else if (isTriviallyTrue(affineIfConditions)) {
  blockToMove = op.getThenBlock();
rewriter.eraseOp(blockToMoveTerminator);

void AffineIfOp::getSuccessorRegions(
if (getElseRegion().empty()) {
  regions.push_back(getResults());

auto conditionAttr =
    (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
return emitOpError("requires an integer set attribute named 'condition'");
IntegerSet condition = conditionAttr.getValue();
return emitOpError("operand count and condition integer set dimension and "
                   "symbol count must match");

IntegerSetAttr conditionAttr;
AffineIfOp::getConditionAttrStrName(),
auto set = conditionAttr.getValue();
if (set.getNumDims() != numDims)
  "dim operand count and integer set dim count must match");
if (numDims + set.getNumSymbols() != result.operands.size())
  "symbol operand count and integer set symbol count must match");
AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),

auto conditionAttr =
    (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
p << " " << conditionAttr;
conditionAttr.getValue().getNumDims(), p);
auto &elseRegion = this->getElseRegion();
if (!elseRegion.empty()) {
getConditionAttrStrName());
->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())

void AffineIfOp::setIntegerSet(IntegerSet newSet) {
(*this)->setOperands(operands);

bool withElseRegion) {
assert(resultTypes.empty() || withElseRegion);
if (resultTypes.empty())
  AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
if (withElseRegion) {
  if (resultTypes.empty())
    AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
AffineIfOp::build(builder, result, {}, set, args,

if (llvm::none_of(operands,
auto set = getIntegerSet();
if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
setConditional(set, operands);

results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
result.types.push_back(memrefType.getElementType());

assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
auto memrefType = llvm::cast<MemRefType>(memref.getType());
result.types.push_back(memrefType.getElementType());

auto memrefType = llvm::cast<MemRefType>(memref.getType());
int64_t rank = memrefType.getRank();
build(builder, result, memref, map, indices);

AffineMapAttr mapAttr;
AffineLoadOp::getMapAttrStrName(),

p << " " << getMemRef() << '[';
if (AffineMapAttr mapAttr =
        (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
{getMapAttrStrName()});

template <typename AffineMemOpTy>
static LogicalResult
MemRefType memrefType, unsigned numIndexOperands) {
  return op->emitOpError("affine map num results must equal memref rank");
  return op->emitOpError("expects as many subscripts as affine map inputs");
for (auto idx : mapOperands) {
  if (!idx.getType().isIndex())
    return op->emitOpError("index to load must have 'index' type");

if (getType() != memrefType.getElementType())
  return emitOpError("result type must match element type of memref");
*this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
getMapOperands(), memrefType,
getNumOperands() - 1)))

results.add<SimplifyAffineOp<AffineLoadOp>>(context);

auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
auto global = dyn_cast_or_null<memref::GlobalOp>(
llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
  return splatAttr.getSplatValue<Attribute>();
if (!getAffineMap().isConstant())
auto indices = llvm::to_vector<4>(
    llvm::map_range(getAffineMap().getConstantResults(),
                    [](int64_t v) -> uint64_t { return v; }));
return cstAttr.getValues<Attribute>()[indices];
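// Folding sketch (illustrative IR, not taken from the original file):
//   memref.global "private" constant @cst : memref<2xi32> = dense<[7, 9]>
//   %m = memref.get_global @cst : memref<2xi32>
//   %v = affine.load %m[1] : memref<2xi32>
// can fold %v to the constant 9, provided the load's affine map evaluates to
// constant indices; a splat initializer short-circuits to its splat value.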
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
auto memrefType = llvm::cast<MemRefType>(memref.getType());
int64_t rank = memrefType.getRank();
build(builder, result, valueToStore, memref, map, indices);

AffineMapAttr mapAttr;
mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),

p << " " << getValueToStore();
p << ", " << getMemRef() << '[';
if (AffineMapAttr mapAttr =
        (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
{getMapAttrStrName()});

if (getValueToStore().getType() != memrefType.getElementType())
  "value to store must have the same type as memref element type");
*this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
getMapOperands(), memrefType,
getNumOperands() - 2)))

results.add<SimplifyAffineOp<AffineStoreOp>>(context);

LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,

template <typename T>
if (op.getNumOperands() !=
    op.getMap().getNumDims() + op.getMap().getNumSymbols())
  return op.emitOpError(
      "operand count and affine map dimension and symbol count must match");
if (op.getMap().getNumResults() == 0)
  return op.emitOpError("affine map expect at least one result");

template <typename T>
p << ' ' << op->getAttr(T::getMapAttrStrName());
auto operands = op.getOperands();
unsigned numDims = op.getMap().getNumDims();
p << '(' << operands.take_front(numDims) << ')';
if (operands.size() != numDims)
  p << '[' << operands.drop_front(numDims) << ']';
{T::getMapAttrStrName()});

template <typename T>
AffineMapAttr mapAttr;

template <typename T>
static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
              "expected affine min or max op");
auto foldedMap = op.getMap().partialConstantFold(operands, &results);
if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
  return op.getOperand(0);
if (results.empty()) {
  if (foldedMap == op.getMap())
  return op.getResult();
auto resultIt = std::is_same<T, AffineMinOp>::value
                    ? llvm::min_element(results)
                    : llvm::max_element(results);
if (resultIt == results.end())
template <typename T>
AffineMap oldMap = affineOp.getAffineMap();
if (!llvm::is_contained(newExprs, expr))
  newExprs.push_back(expr);

template <typename T>
AffineMap oldMap = affineOp.getAffineMap();
affineOp.getMapOperands().take_front(oldMap.getNumDims());
affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
auto newDimOperands = llvm::to_vector<8>(dimOperands);
auto newSymOperands = llvm::to_vector<8>(symOperands);
if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
  Value symValue = symOperands[symExpr.getPosition()];
  producerOps.push_back(producerOp);
} else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
  Value dimValue = dimOperands[dimExpr.getPosition()];
  producerOps.push_back(producerOp);
newExprs.push_back(expr);
if (producerOps.empty())
for (T producerOp : producerOps) {
  AffineMap producerMap = producerOp.getAffineMap();
  unsigned numProducerDims = producerMap.getNumDims();
  producerOp.getMapOperands().take_front(numProducerDims);
  producerOp.getMapOperands().take_back(numProducerSyms);
  newDimOperands.append(dimValues.begin(), dimValues.end());
  newSymOperands.append(symValues.begin(), symValues.end());
  newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
                         .shiftSymbols(numProducerSyms, numUsedSyms));
  numUsedDims += numProducerDims;
  numUsedSyms += numProducerSyms;
llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));

if (!resultExpr.isPureAffine())
if (failed(flattenResult))
if (llvm::is_sorted(flattenedExprs))
llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
  return flattenedExprs[lhs] < flattenedExprs[rhs];
for (unsigned idx : resultPermutation)

template <typename T>
AffineMap map = affineOp.getAffineMap();

template <typename T>
if (affineOp.getMap().getNumResults() != 1)
affineOp.getOperands());

return parseAffineMinMaxOp<AffineMinOp>(parser, result);
return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
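// Custom form parsed here (example added for illustration):
//   %0 = affine.min affine_map<(d0)[s0] -> (d0 + 16, s0, 128)>(%i)[%n]
// i.e. the map attribute followed by dim operands in parentheses and symbol
// operands in square brackets; affine.max uses the same syntax.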
IntegerAttr hintInfo;
StringRef readOrWrite, cacheType;
AffineMapAttr mapAttr;
AffinePrefetchOp::getMapAttrStrName(),
AffinePrefetchOp::getLocalityHintAttrStrName(),
if (readOrWrite != "read" && readOrWrite != "write")
  "rw specifier has to be 'read' or 'write'");
result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
if (cacheType != "data" && cacheType != "instr")
  "cache type has to be 'data' or 'instr'");
result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),

p << " " << getMemref() << '[';
AffineMapAttr mapAttr =
    (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
  << "locality<" << getLocalityHint() << ">, "
  << (getIsDataCache() ? "data" : "instr");
(*this)->getAttrs(),
{getMapAttrStrName(), getLocalityHintAttrStrName(),
 getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});

auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
return emitOpError("affine.prefetch affine map num results must equal"
return emitOpError("too few operands");
if (getNumOperands() != 1)
  return emitOpError("too few operands");
for (auto idx : getMapOperands()) {
  "index must be a valid dimension or symbol identifier");

results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);

LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
build(builder, result, resultTypes, reductions, lbs, {}, ubs,

assert(llvm::all_of(lbMaps,
                    return m.getNumDims() == lbMaps[0].getNumDims() &&
                           m.getNumSymbols() == lbMaps[0].getNumSymbols();
       "expected all lower bounds maps to have the same number of dimensions "
assert(llvm::all_of(ubMaps,
                    return m.getNumDims() == ubMaps[0].getNumDims() &&
                           m.getNumSymbols() == ubMaps[0].getNumSymbols();
       "expected all upper bounds maps to have the same number of dimensions "
assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
       "expected lower bound maps to have as many inputs as lower bound "
assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
       "expected upper bound maps to have as many inputs as upper bound "
for (arith::AtomicRMWKind reduction : reductions)
  reductionAttrs.push_back(

groups.reserve(groups.size() + maps.size());
exprs.reserve(maps.size());
llvm::append_range(exprs, m.getResults());
groups.push_back(m.getNumResults());
return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,

AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
for (unsigned i = 0, e = steps.size(); i < e; ++i)
if (resultTypes.empty())
  ensureTerminator(*bodyRegion, builder, result.location);

return {&getRegion()};

unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }

AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
  return getOperands().take_front(getLowerBoundsMap().getNumInputs());

AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());

AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
  auto values = getLowerBoundsGroups().getValues<int32_t>();
  for (unsigned i = 0; i < pos; ++i)
  return getLowerBoundsMap().getSliceMap(start, values[pos]);

AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
  auto values = getUpperBoundsGroups().getValues<int32_t>();
  for (unsigned i = 0; i < pos; ++i)
  return getUpperBoundsMap().getSliceMap(start, values[pos]);

return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());

std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
  if (hasMinMaxBounds())
    return std::nullopt;
  AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
    auto expr = rangesValueMap.getResult(i);
    auto cst = dyn_cast<AffineConstantExpr>(expr);
      return std::nullopt;
    out.push_back(cst.getValue());

Block *AffineParallelOp::getBody() { return &getRegion().front(); }

OpBuilder AffineParallelOp::getBodyBuilder() {
  return OpBuilder(getBody(), std::prev(getBody()->end()));

"operands to map must match number of inputs");
auto ubOperands = getUpperBoundsOperands();
newOperands.append(ubOperands.begin(), ubOperands.end());
(*this)->setOperands(newOperands);
"operands to map must match number of inputs");
newOperands.append(ubOperands.begin(), ubOperands.end());
(*this)->setOperands(newOperands);

setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
arith::AtomicRMWKind op) {
case arith::AtomicRMWKind::addf:
  return isa<FloatType>(resultType);
case arith::AtomicRMWKind::addi:
  return isa<IntegerType>(resultType);
case arith::AtomicRMWKind::assign:
case arith::AtomicRMWKind::mulf:
  return isa<FloatType>(resultType);
case arith::AtomicRMWKind::muli:
  return isa<IntegerType>(resultType);
case arith::AtomicRMWKind::maximumf:
  return isa<FloatType>(resultType);
case arith::AtomicRMWKind::minimumf:
  return isa<FloatType>(resultType);
case arith::AtomicRMWKind::maxs: {
  auto intType = llvm::dyn_cast<IntegerType>(resultType);
  return intType && intType.isSigned();
}
case arith::AtomicRMWKind::mins: {
  auto intType = llvm::dyn_cast<IntegerType>(resultType);
  return intType && intType.isSigned();
}
case arith::AtomicRMWKind::maxu: {
  auto intType = llvm::dyn_cast<IntegerType>(resultType);
  return intType && intType.isUnsigned();
}
case arith::AtomicRMWKind::minu: {
  auto intType = llvm::dyn_cast<IntegerType>(resultType);
  return intType && intType.isUnsigned();
}
case arith::AtomicRMWKind::ori:
  return isa<IntegerType>(resultType);
case arith::AtomicRMWKind::andi:
  return isa<IntegerType>(resultType);
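// Consequence of the checks above (illustrative): "addf"/"mulf" and the float
// min/max reductions only validate against float result types,
// "addi"/"muli"/"ori"/"andi" against integers, and the signed/unsigned
// integer min/max variants additionally require matching signedness, so e.g.
// a "maxu" reduction producing a signed integer result is rejected.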
auto numDims = getNumDims();
    getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
  return emitOpError() << "the number of region arguments ("
                       << getBody()->getNumArguments()
                       << ") and the number of map groups for lower ("
                       << getLowerBoundsGroups().getNumElements()
                       << ") and upper bound ("
                       << getUpperBoundsGroups().getNumElements()
                       << "), and the number of steps (" << getSteps().size()
                       << ") must all match";

unsigned expectedNumLBResults = 0;
for (APInt v : getLowerBoundsGroups()) {
  unsigned results = v.getZExtValue();
    return emitOpError()
           << "expected lower bound map to have at least one result";
  expectedNumLBResults += results;
if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
  return emitOpError() << "expected lower bounds map to have "
                       << expectedNumLBResults << " results";
unsigned expectedNumUBResults = 0;
for (APInt v : getUpperBoundsGroups()) {
  unsigned results = v.getZExtValue();
    return emitOpError()
           << "expected upper bound map to have at least one result";
  expectedNumUBResults += results;
if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
  return emitOpError() << "expected upper bounds map to have "
                       << expectedNumUBResults << " results";

if (getReductions().size() != getNumResults())
  return emitOpError("a reduction must be specified for each output");

auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
  return emitOpError("invalid reduction attribute");
auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
  return emitOpError("result type cannot match reduction attribute");

    getLowerBoundsMap().getNumDims())))
    getUpperBoundsMap().getNumDims())))
LogicalResult AffineValueMap::canonicalize() {
  auto newMap = getAffineMap();
  if (newMap == getAffineMap() && newOperands == operands)
  reset(newMap, newOperands);

if (!lbCanonicalized && !ubCanonicalized)
if (lbCanonicalized)
if (ubCanonicalized)

LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,

StringRef keyword) {
ValueRange dimOperands = operands.take_front(numDims);
ValueRange symOperands = operands.drop_front(numDims);
for (llvm::APInt groupSize : group) {
  unsigned size = groupSize.getZExtValue();
p << keyword << '(';

p << " (" << getBody()->getArguments() << ") = (";
getLowerBoundsOperands(), "max");
getUpperBoundsOperands(), "min");
bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
llvm::interleaveComma(steps, p);
if (getNumResults()) {
  llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
    arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
        llvm::cast<IntegerAttr>(attr).getInt());
    p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
  p << ") -> (" << getResultTypes() << ")";
(*this)->getAttrs(),
{AffineParallelOp::getReductionsAttrStrName(),
 AffineParallelOp::getLowerBoundsMapAttrStrName(),
 AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
 AffineParallelOp::getUpperBoundsMapAttrStrName(),
 AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
 AffineParallelOp::getStepsAttrStrName()});
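// The printed form assembled here looks like (illustrative example):
//   affine.parallel (%i, %j) = (0, 0) to (%n, 64) step (1, 2)
//       reduce ("addf") -> (f32) { ... }
// with the bounds maps/groups, steps and reductions attributes elided from
// the trailing attribute dictionary.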
4165 "expected operands to be dim or symbol expression");
4168 for (
const auto &list : operands) {
4172 for (
Value operand : valueOperands) {
4173 unsigned pos = std::distance(uniqueOperands.begin(),
4174 llvm::find(uniqueOperands, operand));
4175 if (pos == uniqueOperands.size())
4176 uniqueOperands.push_back(operand);
4177 replacements.push_back(
4187 enum class MinMaxKind { Min, Max };
4211 const llvm::StringLiteral tmpAttrStrName =
"__pseudo_bound_map";
4213 StringRef mapName =
kind == MinMaxKind::Min
4214 ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4215 : AffineParallelOp::getLowerBoundsMapAttrStrName();
4216 StringRef groupsName =
4217 kind == MinMaxKind::Min
4218 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4219 : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4236 auto parseOperands = [&]() {
4238 kind == MinMaxKind::Min ?
"min" :
"max"))) {
4239 mapOperands.clear();
4246 llvm::append_range(flatExprs, map.getValue().getResults());
4248 auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4250 auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4252 flatDimOperands.append(map.getValue().getNumResults(), dims);
4253 flatSymOperands.append(map.getValue().getNumResults(), syms);
4254 numMapsPerGroup.push_back(map.getValue().getNumResults());
4257 flatSymOperands.emplace_back(),
4258 flatExprs.emplace_back())))
4260 numMapsPerGroup.push_back(1);
4267 unsigned totalNumDims = 0;
4268 unsigned totalNumSyms = 0;
4269 for (
unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4270 unsigned numDims = flatDimOperands[i].size();
4271 unsigned numSyms = flatSymOperands[i].size();
4272 flatExprs[i] = flatExprs[i]
4273 .shiftDims(numDims, totalNumDims)
4274 .shiftSymbols(numSyms, totalNumSyms);
4275 totalNumDims += numDims;
4276 totalNumSyms += numSyms;
4288 result.
operands.append(dimOperands.begin(), dimOperands.end());
4289 result.
operands.append(symOperands.begin(), symOperands.end());
4292 auto flatMap =
AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4294 flatMap = flatMap.replaceDimsAndSymbols(
4295 dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4319   AffineMapAttr stepsMapAttr;
4324     result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4328             AffineParallelOp::getStepsAttrStrName(),
4335     auto stepsMap = stepsMapAttr.getValue();
4336     for (const auto &result : stepsMap.getResults()) {
4337       auto constExpr = dyn_cast<AffineConstantExpr>(result);
4340             "steps must be constant integers");
4341       steps.push_back(constExpr.getValue());
4343     result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4353   auto parseAttributes = [&]() -> ParseResult {
4363     std::optional<arith::AtomicRMWKind> reduction =
4364         arith::symbolizeAtomicRMWKind(attrVal.getValue());
4366       return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4367     reductions.push_back(
4375   result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4384   for (auto &iv : ivs)
4385     iv.type = indexType;
4391   AffineParallelOp::ensureTerminator(*body, builder, result.location);
4400   auto *parentOp = (*this)->getParentOp();
4401   auto results = parentOp->getResults();
4402   auto operands = getOperands();
4404   if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4405     return emitOpError() << "only terminates affine.if/for/parallel regions";
4406   if (parentOp->getNumResults() != getNumOperands())
4407     return emitOpError() << "parent of yield must have same number of "
4408                             "results as the yield operands";
4409   for (auto it : llvm::zip(results, operands)) {
4411       return emitOpError() << "types mismatch between yield op and its parent";
4424   assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4428   result.types.push_back(resultType);
4432                                VectorType resultType, Value memref,
4434   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4438   result.types.push_back(resultType);
4442                                VectorType resultType, Value memref,
4444   auto memrefType = llvm::cast<MemRefType>(memref.getType());
4445   int64_t rank = memrefType.getRank();
4450   build(builder, result, resultType, memref, map, indices);
4453 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4455   results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4463   MemRefType memrefType;
4464   VectorType resultType;
4466   AffineMapAttr mapAttr;
4471           AffineVectorLoadOp::getMapAttrStrName(),
4482   p << " " << getMemRef() << '[';
4483   if (AffineMapAttr mapAttr =
4484           (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4488                           {getMapAttrStrName()});
4494                                            VectorType vectorType) {
4496   if (memrefType.getElementType() != vectorType.getElementType())
4498         "requires memref and vector types of the same elemental type");
4505           *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4506           getMapOperands(), memrefType,
4507           getNumOperands() - 1)))
4523   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4534   auto memrefType = llvm::cast<MemRefType>(memref.getType());
4535   int64_t rank = memrefType.getRank();
4540   build(builder, result, valueToStore, memref, map, indices);
4542 void AffineVectorStoreOp::getCanonicalizationPatterns(
4544   results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4551   MemRefType memrefType;
4552   VectorType resultType;
4555   AffineMapAttr mapAttr;
4561           AffineVectorStoreOp::getMapAttrStrName(),
4572   p << " " << getValueToStore();
4573   p << ", " << getMemRef() << '[';
4574   if (AffineMapAttr mapAttr =
4575           (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4579                           {getMapAttrStrName()});
4580   p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4586           *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4587           getMapOperands(), memrefType,
4588           getNumOperands() - 2)))
4601 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4605                                      bool hasOuterBound) {
4607                           : staticBasis.size() + 1,
4609   build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
4613 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4616                                      bool hasOuterBound) {
4617   if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
4618     hasOuterBound = false;
4619     basis = basis.drop_front();
4625   build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4629 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4633                                      bool hasOuterBound) {
4634   if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
4635     hasOuterBound = false;
4636     basis = basis.drop_front();
4641   build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4645 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4648                                      bool hasOuterBound) {
4649   build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
4654   if (getNumResults() != staticBasis.size() &&
4655       getNumResults() != staticBasis.size() + 1)
4656     return emitOpError("should return an index for each basis element and up "
4657                        "to one extra index");
4659   auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
4660   if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4662         "mismatch between dynamic and static basis (kDynamic marker but no "
4663         "corresponding dynamic basis entry) -- this can only happen due to an "
4664         "incorrect fold/rewrite");
4666   if (!llvm::all_of(staticBasis, [](int64_t v) {
4667         return v > 0 || ShapedType::isDynamic(v);
4669     return emitOpError("no basis element may be statically non-positive");
4678 static std::optional<SmallVector<int64_t>>
4682   uint64_t dynamicBasisIndex = 0;
4685       mutableDynamicBasis.erase(dynamicBasisIndex);
4687       ++dynamicBasisIndex;
4692   if (dynamicBasisIndex == dynamicBasis.size())
4693     return std::nullopt;
4699       staticBasis.push_back(ShapedType::kDynamic);
4701       staticBasis.push_back(*basisVal);
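This helper walks the mixed basis: whenever a nominally dynamic entry turns out to be a known constant, the SSA operand is erased and the constant is recorded in the static basis; if nothing was promoted it returns std::nullopt so the fold is a no-op. A minimal standalone sketch of that bookkeeping over plain data, with hypothetical types (kDynamic standing in for ShapedType::kDynamic), not the MLIR OpFoldResult machinery:

#include <cstdint>
#include <optional>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN; // stand-in marker for a dynamic entry

struct MixedEntry {
  bool isStatic;                     // already a constant in the static basis
  int64_t staticValue;               // valid when isStatic
  std::optional<int64_t> knownCst;   // constant value of the SSA operand, if known
};

// Promote dynamic entries with known constants into the static basis and drop
// their operands; return std::nullopt when nothing was promoted.
static std::optional<std::vector<int64_t>>
foldConstantBasis(const std::vector<MixedEntry> &mixed,
                  std::vector<int64_t> &dynamicOperands) {
  std::vector<int64_t> staticBasis;
  size_t dynamicIndex = 0;
  bool changed = false;
  for (const MixedEntry &entry : mixed) {
    if (entry.isStatic) {
      staticBasis.push_back(entry.staticValue);
      continue;
    }
    if (entry.knownCst) {
      staticBasis.push_back(*entry.knownCst);
      dynamicOperands.erase(dynamicOperands.begin() + dynamicIndex);
      changed = true;
    } else {
      staticBasis.push_back(kDynamic);
      ++dynamicIndex;
    }
  }
  if (!changed)
    return std::nullopt;
  return staticBasis;
}

int main() {
  std::vector<int64_t> dyn = {5};
  std::vector<MixedEntry> mixed = {{true, 4, std::nullopt}, {false, 0, 5}};
  auto folded = foldConstantBasis(mixed, dyn); // folded = {4, 5}, dyn now empty
  return folded && folded->size() == 2 ? 0 : 1;
}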
4708 AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
4710   std::optional<SmallVector<int64_t>> maybeStaticBasis =
4712                                 adaptor.getDynamicBasis());
4713   if (maybeStaticBasis) {
4714     setStaticBasis(*maybeStaticBasis);
4719   if (getNumResults() == 1) {
4720     result.push_back(getLinearIndex());
4724   if (adaptor.getLinearIndex() == nullptr)
4727   if (!adaptor.getDynamicBasis().empty())
4730   int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
4731   Type attrType = getLinearIndex().getType();
4734   if (hasOuterBound())
4735     staticBasis = staticBasis.drop_front();
4736   for (int64_t modulus : llvm::reverse(staticBasis)) {
4737     result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
4738     highPart = llvm::divideFloorSigned(highPart, modulus);
4741   std::reverse(result.begin(), result.end());
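The constant fold walks the basis from the innermost (last) element outwards: each step emits highPart mod modulus and floor-divides highPart by the modulus, and the collected results are reversed at the end. A minimal standalone sketch of that arithmetic, assuming a non-negative linear index (plain C++ illustration, not the MLIR fold itself, which uses llvm::mod and divideFloorSigned for correct behavior on negatives):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Delinearize `linear` over `basis` (innermost dimension last), mirroring the
// mod/floordiv loop in the fold above.
static std::vector<int64_t> delinearize(int64_t linear,
                                        const std::vector<int64_t> &basis) {
  std::vector<int64_t> result;
  int64_t highPart = linear;
  for (auto it = basis.rbegin(); it != basis.rend(); ++it) {
    result.push_back(highPart % *it);
    highPart /= *it;
  }
  std::reverse(result.begin(), result.end());
  return result;
}

int main() {
  // 23 over basis (2, 3, 5): 23 = 1*15 + 1*5 + 3, so the result is (1, 1, 3).
  for (int64_t v : delinearize(23, {2, 3, 5}))
    std::printf("%lld ", static_cast<long long>(v));
  std::printf("\n");
}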
4747 if (hasOuterBound()) {
4748 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
4750 getDynamicBasis().drop_front(), builder);
4752 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
4756 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
4761 if (!hasOuterBound())
4769 struct DropUnitExtentBasis
4773   LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4776     std::optional<Value> zero = std::nullopt;
4777     Location loc = delinearizeOp->getLoc();
4780         zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
4781       return zero.value();
4787     for (auto [index, basis] :
4789       std::optional<int64_t> basisVal =
4792         replacements[index] = getZero();
4794         newBasis.push_back(basis);
4797     if (newBasis.size() == delinearizeOp.getNumResults())
4799                                          "no unit basis elements");
4801     if (!newBasis.empty()) {
4803       auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
4804           loc, delinearizeOp.getLinearIndex(), newBasis);
4807       for (auto &replacement : replacements) {
4810           replacement = newDelinearizeOp->getResult(newIndex++);
4814     rewriter.replaceOp(delinearizeOp, replacements);
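A delinearize result over a size-1 dimension can only ever be 0, so the pattern replaces that result with a constant zero and re-delinearizes over the remaining basis. A minimal standalone sketch of the filtering step, over plain integer bases with a hypothetical helper name (the MLIR pattern does the same bookkeeping over Values):

#include <cstdint>
#include <vector>

// Split a basis into the entries worth keeping and the positions whose result
// is statically known to be zero (unit-extent dimensions).
static void dropUnitExtents(const std::vector<int64_t> &basis,
                            std::vector<int64_t> &newBasis,
                            std::vector<size_t> &zeroResultPositions) {
  for (size_t i = 0; i < basis.size(); ++i) {
    if (basis[i] == 1)
      zeroResultPositions.push_back(i); // this output is always 0
    else
      newBasis.push_back(basis[i]);
  }
}

int main() {
  std::vector<int64_t> newBasis;
  std::vector<size_t> zeros;
  dropUnitExtents({4, 1, 8}, newBasis, zeros);
  // newBasis = {4, 8}; result #1 of the original delinearize is replaced by 0.
  return newBasis.size() == 2 && zeros.size() == 1 ? 0 : 1;
}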
4829 struct CancelDelinearizeOfLinearizeDisjointExactTail
4833   LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4835     auto linearizeOp = delinearizeOp.getLinearIndex()
4836                            .getDefiningOp<affine::AffineLinearizeIndexOp>();
4839                                          "index doesn't come from linearize");
4841     if (!linearizeOp.getDisjoint())
4844     ValueRange linearizeIns = linearizeOp.getMultiIndex();
4848     size_t numMatches = 0;
4849     for (auto [linSize, delinSize] : llvm::zip(
4850              llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
4851       if (linSize != delinSize)
4856     if (numMatches == 0)
4858           delinearizeOp, "final basis element doesn't match linearize");
4861     if (numMatches == linearizeBasis.size() &&
4862         numMatches == delinearizeBasis.size() &&
4863         linearizeIns.size() == delinearizeOp.getNumResults()) {
4864       rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
4868     Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
4869         linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
4871         linearizeOp.getDisjoint());
4872     auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
4873         delinearizeOp.getLoc(), newLinearize,
4875         delinearizeOp.hasOuterBound());
4877     mergedResults.append(linearizeIns.take_back(numMatches).begin(),
4878                          linearizeIns.take_back(numMatches).end());
4879     rewriter.replaceOp(delinearizeOp, mergedResults);
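The pattern only needs to know how long the common suffix of the two bases is: delinearize results over that suffix can be taken straight from the corresponding linearize inputs, while the remaining prefix is re-linearized and re-delinearized. A minimal standalone sketch of the suffix-matching step, over plain integer bases with a hypothetical helper name:

#include <cstdint>
#include <cstdio>
#include <vector>

// Count how many trailing basis elements of the linearize and delinearize
// agree; those positions can be forwarded without any arithmetic.
static size_t countMatchingTail(const std::vector<int64_t> &linBasis,
                                const std::vector<int64_t> &delinBasis) {
  size_t numMatches = 0;
  while (numMatches < linBasis.size() && numMatches < delinBasis.size() &&
         linBasis[linBasis.size() - 1 - numMatches] ==
             delinBasis[delinBasis.size() - 1 - numMatches])
    ++numMatches;
  return numMatches;
}

int main() {
  // linearize basis (7, 4, 8) vs. delinearize basis (14, 2, 4, 8): the tail
  // (4, 8) matches, so the last two delinearize results are just the last two
  // linearize inputs. Prints 2.
  std::printf("%zu\n", countMatchingTail({7, 4, 8}, {14, 2, 4, 8}));
}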
4897 struct SplitDelinearizeSpanningLastLinearizeArg final
4901   LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4903     auto linearizeOp = delinearizeOp.getLinearIndex()
4904                            .getDefiningOp<affine::AffineLinearizeIndexOp>();
4907                                          "index doesn't come from linearize");
4909     if (!linearizeOp.getDisjoint())
4911                                          "linearize isn't disjoint");
4913     int64_t target = linearizeOp.getStaticBasis().back();
4914     if (ShapedType::isDynamic(target))
4916           linearizeOp, "linearize ends with dynamic basis value");
4918     int64_t sizeToSplit = 1;
4919     size_t elemsToSplit = 0;
4921     for (int64_t basisElem : llvm::reverse(basis)) {
4922       if (ShapedType::isDynamic(basisElem))
4924             delinearizeOp, "dynamic basis element while scanning for split");
4925       sizeToSplit *= basisElem;
4928     if (sizeToSplit > target)
4930                                            "overshot last argument size");
4931     if (sizeToSplit == target)
4935     if (sizeToSplit < target)
4937           delinearizeOp, "product of known basis elements doesn't exceed last "
4938                          "linearize argument");
4940     if (elemsToSplit < 2)
4943           "need at least two elements to form the basis product");
4945     Value linearizeWithoutBack =
4946         rewriter.create<affine::AffineLinearizeIndexOp>(
4947             linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
4948             linearizeOp.getDynamicBasis(),
4949             linearizeOp.getStaticBasis().drop_back(),
4950             linearizeOp.getDisjoint());
4951     auto delinearizeWithoutSplitPart =
4952         rewriter.create<affine::AffineDelinearizeIndexOp>(
4953             delinearizeOp.getLoc(), linearizeWithoutBack,
4954             delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
4955             delinearizeOp.hasOuterBound());
4956     auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
4957         delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
4958         basis.take_back(elemsToSplit), true);
4960         llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
4961                             delinearizeBack.getResults()));
4962     rewriter.replaceOp(delinearizeOp, results);
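The scan accumulates the product of the trailing delinearize basis elements and accepts the split only when that product exactly equals the extent of the last linearize argument; only then can the tail be delinearized directly from that argument. A minimal standalone sketch of the scan, with a hypothetical helper name that returns 0 where the pattern would bail out:

#include <cstdint>
#include <cstdio>
#include <vector>

// Return how many trailing basis elements multiply exactly to `target`, or 0
// if the running product overshoots or never reaches it.
static size_t countExactTrailingSplit(const std::vector<int64_t> &basis,
                                      int64_t target) {
  int64_t sizeToSplit = 1;
  size_t elemsToSplit = 0;
  for (auto it = basis.rbegin(); it != basis.rend(); ++it) {
    sizeToSplit *= *it;
    ++elemsToSplit;
    if (sizeToSplit == target)
      return elemsToSplit;
    if (sizeToSplit > target)
      return 0;
  }
  return 0;
}

int main() {
  // Basis (3, 4, 2, 8) against a last linearize argument of extent 16:
  // 8 * 2 == 16, so the last two results can be peeled off. Prints 2.
  std::printf("%zu\n", countExactTrailingSplit({3, 4, 2, 8}, 16));
}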
4969 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
4972       .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
4973               DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
4981 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4985   if (!basis.empty() && basis.front() == Value())
4986     basis = basis.drop_front();
4991   build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4994 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5000     basis = basis.drop_front();
5004   build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
5007 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5011   build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
5015   size_t numIndexes = getMultiIndex().size();
5016   size_t numBasisElems = getStaticBasis().size();
5017   if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
5018     return emitOpError("should be passed a basis element for each index except "
5019                        "possibly the first");
5021   auto dynamicMarkersCount =
5022       llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
5023   if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
5025         "mismatch between dynamic and static basis (kDynamic marker but no "
5026         "corresponding dynamic basis entry) -- this can only happen due to an "
5027         "incorrect fold/rewrite");
5032 OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
5033   std::optional<SmallVector<int64_t>> maybeStaticBasis =
5035                                 adaptor.getDynamicBasis());
5036   if (maybeStaticBasis) {
5037     setStaticBasis(*maybeStaticBasis);
5041   if (getMultiIndex().empty())
5045   if (getMultiIndex().size() == 1)
5046     return getMultiIndex().front();
5048   if (llvm::is_contained(adaptor.getMultiIndex(), nullptr))
5051   if (!adaptor.getDynamicBasis().empty())
5056   for (auto [length, indexAttr] :
5057        llvm::zip_first(llvm::reverse(getStaticBasis()),
5058                        llvm::reverse(adaptor.getMultiIndex()))) {
5059     result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
5060     stride = stride * length;
5063   if (!hasOuterBound())
5066         cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;
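This constant fold is the inverse of the delinearize fold: walking from the innermost index outwards it accumulates index * stride while the stride grows by each basis element, and a front index without an outer bound is added with the final stride. A minimal standalone sketch of that accumulation, assuming one index per basis element (plain C++ illustration):

#include <cstdint>
#include <cstdio>
#include <vector>

// Linearize `indices` over `basis` (innermost dimension last), mirroring the
// strided accumulation in the fold above.
static int64_t linearize(const std::vector<int64_t> &indices,
                         const std::vector<int64_t> &basis) {
  int64_t result = 0;
  int64_t stride = 1;
  for (size_t i = indices.size(); i-- > 0;) {
    result += indices[i] * stride;
    stride *= basis[i];
  }
  return result;
}

int main() {
  // (1, 1, 3) over basis (2, 3, 5): 1*15 + 1*5 + 3 = 23, inverting the
  // delinearize example above. Prints 23.
  std::printf("%lld\n", static_cast<long long>(linearize({1, 1, 3}, {2, 3, 5})));
}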
5073 if (hasOuterBound()) {
5074 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5076 getDynamicBasis().drop_front(), builder);
5078 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
5082 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
5087 if (!hasOuterBound())
5103 struct DropLinearizeUnitComponentsIfDisjointOrZero final
5107 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5110 size_t numIndices = multiIndex.size();
5112 newIndices.reserve(numIndices);
5114 newBasis.reserve(numIndices);
5116 if (!op.hasOuterBound()) {
5117 newIndices.push_back(multiIndex.front());
5118 multiIndex = multiIndex.drop_front();
5122     for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
5124 if (!basisEntry || *basisEntry != 1) {
5125 newIndices.push_back(index);
5126 newBasis.push_back(basisElem);
5131 if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
5132 newIndices.push_back(index);
5133 newBasis.push_back(basisElem);
5137 if (newIndices.size() == numIndices)
5139 "no unit basis entries to replace");
5141 if (newIndices.size() == 0) {
5146 op, newIndices, newBasis, op.getDisjoint());
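Mirroring DropUnitExtentBasis on the delinearize side, a linearize component over a unit-extent dimension contributes index * stride where the index is provably zero, so the pair can be dropped when the op is disjoint (the index is known in range, hence 0) or when the index itself is a constant 0. A minimal standalone sketch of the filtering over constants, with hypothetical types (the MLIR pattern works on Values and OpFoldResults):

#include <cstdint>
#include <optional>
#include <vector>

struct Component {
  std::optional<int64_t> constIndex; // known constant index value, if any
  int64_t basisElem;                 // extent of this dimension
};

// Drop (index, basis) pairs with unit extent when the index is provably zero:
// either the op is disjoint (indices are in range, so 0 <= index < 1) or the
// index itself is the constant 0.
static std::vector<Component> dropUnitComponents(std::vector<Component> comps,
                                                 bool disjoint) {
  std::vector<Component> kept;
  for (const Component &c : comps) {
    bool unitExtent = c.basisElem == 1;
    bool knownZero = disjoint || (c.constIndex && *c.constIndex == 0);
    if (!(unitExtent && knownZero))
      kept.push_back(c);
  }
  return kept;
}

int main() {
  std::vector<Component> comps = {{std::nullopt, 4}, {0, 1}, {std::nullopt, 8}};
  // Even with disjoint = false, the middle component (constant 0 over extent 1)
  // is dropped; two components remain.
  return dropUnitComponents(comps, /*disjoint=*/false).size() == 2 ? 0 : 1;
}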
5153   int64_t nDynamic = 0;
5163       dynamicPart.push_back(cast<Value>(term));
5167   if (auto constant = dyn_cast<AffineConstantExpr>(result))
5169   return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
5199 struct CancelLinearizeOfDelinearizePortion final
5210 unsigned linStart = 0;
5211 unsigned delinStart = 0;
5212 unsigned length = 0;
5216 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
5223 ValueRange multiIndex = linearizeOp.getMultiIndex();
5224 unsigned numLinArgs = multiIndex.size();
5225 unsigned linArgIdx = 0;
5229 while (linArgIdx < numLinArgs) {
5230 auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
5236 auto delinearizeOp =
5237 dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
5238 if (!delinearizeOp) {
5255 unsigned delinArgIdx = asResult.getResultNumber();
5257 OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
5259 bool boundsMatch = firstDelinBound == firstLinBound;
5260 bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
5261 bool knownByDisjoint =
5262 linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
5263 if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
5269 unsigned numDelinOuts = delinearizeOp.getNumResults();
5270     for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
5272       if (multiIndex[linArgIdx + j] !=
5273           delinearizeOp.getResult(delinArgIdx + j))
5275       if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
5281     if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
5285     matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
5289     if (matches.empty())
5291           linearizeOp, "no run of delinearize outputs to deal with");
5299 newIndex.reserve(numLinArgs);
5301 newBasis.reserve(numLinArgs);
5302 unsigned prevMatchEnd = 0;
5303 for (Match m : matches) {
5304 unsigned gap = m.linStart - prevMatchEnd;
5305 llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
5306 llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
5308 prevMatchEnd = m.linStart + m.length;
5310 PatternRewriter::InsertionGuard g(rewriter);
5314 linBasisRef.slice(m.linStart, m.length);
5321 if (m.length == m.delinearize.getNumResults()) {
5322 newIndex.push_back(m.delinearize.getLinearIndex());
5323 newBasis.push_back(newSize);
5331 newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
5332 newDelinBasis.begin() + m.delinStart + m.length);
5333 newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
5334       auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
5335           m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
5341       Value combinedElem = newDelinearize.getResult(m.delinStart);
5342       auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
5343           m.delinearize.getLoc(), combinedElem, basisToMerge);
5348 llvm::append_range(newDelinResults,
5349 newDelinearize.getResults().take_front(m.delinStart));
5350 llvm::append_range(newDelinResults, residualDelinearize.getResults());
5353 newDelinearize.getResults().drop_front(m.delinStart + 1));
5355 delinearizeReplacements.push_back(newDelinResults);
5356 newIndex.push_back(combinedElem);
5357 newBasis.push_back(newSize);
5359 llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
5360 llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
5362 linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
5364     for (auto [m, newResults] :
5365          llvm::zip_equal(matches, delinearizeReplacements)) {
5366       if (newResults.empty())
5368       rewriter.replaceOp(m.delinearize, newResults);
5379 struct DropLinearizeLeadingZero final
5383 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5385 Value leadingIdx = op.getMultiIndex().front();
5389 if (op.getMultiIndex().size() == 1) {
5396 if (op.hasOuterBound())
5397 newMixedBasis = newMixedBasis.drop_front();
5400 op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
5406 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
5408 patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
5409 DropLinearizeUnitComponentsIfDisjointOrZero>(context);
5416 #define GET_OP_CLASSES
5417 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"