25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallBitVector.h"
27 #include "llvm/ADT/SmallVectorExtras.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/LogicalResult.h"
31 #include "llvm/Support/MathExtras.h"
38 using llvm::divideCeilSigned;
39 using llvm::divideFloorSigned;
42 #define DEBUG_TYPE "affine-ops"
43 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
45 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
52 if (auto arg = llvm::dyn_cast<BlockArgument>(value))
53 return arg.getParentRegion() == region;
76 if (llvm::isa<BlockArgument>(value))
77 return legalityCheck(mapping.lookup(value), dest);
84 bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
95 return llvm::all_of(values, [&](Value v) {
102 template <typename OpTy>
105 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
106 AffineWriteOpInterface>::value,
107 "only ops with affine read/write interface are supported");
114 dimOperands, src, dest, mapping,
118 symbolOperands, src, dest, mapping,
135 op.getMapOperands(), src, dest, mapping,
140 op.getMapOperands(), src, dest, mapping,
167 if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
172 if (!llvm::hasSingleElement(*src))
180 if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
181 if (iface.hasNoEffect())
189 .Case<AffineApplyOp, AffineReadOpInterface,
190 AffineWriteOpInterface>([&](auto op) {
215 isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
219 bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
227 void AffineDialect::initialize() {
230 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
232 addInterfaces<AffineInlinerInterface>();
233 declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
242 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
243 return builder.create<ub::PoisonOp>(loc, type, poison);
244 return arith::ConstantOp::materialize(builder, value, type, loc);
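// Note (added for illustration, not from the original source): when a fold
// produces ub::PoisonAttr the dialect materializes ub.poison; any other
// foldable attribute goes through arith::ConstantOp::materialize, e.g. an
// IntegerAttr 42 of index type becomes `%c42 = arith.constant 42 : index`.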
252 if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
268 while (auto *parentOp = curOp->getParentOp()) {
279 if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(parentOp))
304 auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
332 if (auto applyOp = dyn_cast<AffineApplyOp>(op))
333 return applyOp.isValidDim(region);
336 if (isa<AffineDelinearizeIndexOp, AffineLinearizeIndexOp>(op))
337 return llvm::all_of(op->getOperands(),
338 [&](Value arg) { return ::isValidDim(arg, region); });
341 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
349 template <typename AnyMemRefDefOp>
352 MemRefType memRefType = memrefDefOp.getType();
355 if (index >= memRefType.getRank()) {
360 if (!memRefType.isDynamicDim(index))
363 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
364 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
376 if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
384 if (!index.has_value())
388 Operation *op = dimOp.getShapedValue().getDefiningOp();
389 while (auto castOp = dyn_cast<memref::CastOp>(op)) {
391 if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
393 op = castOp.getSource().getDefiningOp();
398 int64_t i = index.value();
400 .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
402 .Default([](Operation *) { return false; });
469 if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
470 return affine::isValidSymbol(operand, region);
476 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
500 printer << '(' << operands.take_front(numDims) << ')';
501 if (operands.size() > numDims)
502 printer << '[' << operands.drop_front(numDims) << ']';
512 numDims = opInfos.size();
526 template <typename OpTy>
531 for (auto operand : operands) {
532 if (opIt++ < numDims) {
534 return op.emitOpError("operand cannot be used as a dimension id");
536 return op.emitOpError("operand cannot be used as a symbol");
547 return AffineValueMap(getAffineMap(), getOperands(), getResult());
554 AffineMapAttr mapAttr;
560 auto map = mapAttr.getValue();
562 if (map.getNumDims() != numDims ||
563 numDims + map.getNumSymbols() != result.operands.size()) {
565 "dimension or symbol index mismatch");
568 result.types.append(map.getNumResults(), indexTy);
573 p << " " << getMapAttr();
575 getAffineMap().getNumDims(), p);
586 "operand count and affine map dimension and symbol count must match");
590 return emitOpError("mapping must produce one value");
596 for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
598 return emitError("dimensional operand cannot be used as a symbol");
607 return llvm::all_of(getOperands(),
615 return llvm::all_of(getOperands(),
622 return llvm::all_of(getOperands(),
629 return llvm::all_of(getOperands(), [&](Value operand) {
635 auto map = getAffineMap();
638 auto expr = map.getResult(0);
639 if (auto dim = dyn_cast<AffineDimExpr>(expr))
640 return getOperand(dim.getPosition());
641 if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
642 return getOperand(map.getNumDims() + sym.getPosition());
646 bool hasPoison = false;
648 map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
651 if (failed(foldResult))
668 auto dimExpr = dyn_cast<AffineDimExpr>(e);
678 Value operand = operands[dimExpr.getPosition()];
679 int64_t operandDivisor = 1;
683 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
684 operandDivisor = forOp.getStepAsInt();
686 uint64_t lbLargestKnownDivisor =
687 forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
688 operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
691 return operandDivisor;
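// Illustrative example (assumed, not from the original comments): if every
// result of the lower bound map is known to be a multiple of 8 and the step
// is 12, the induction variable is always a multiple of gcd(8, 12) = 4, so 4
// is returned as its largest known divisor.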
698 if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
699 int64_t constVal = constExpr.getValue();
700 return constVal >= 0 && constVal < k;
702 auto dimExpr = dyn_cast<AffineDimExpr>(e);
705 Value operand = operands[dimExpr.getPosition()];
709 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
710 forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
726 auto bin = dyn_cast<AffineBinaryOpExpr>(e);
734 quotientTimesDiv = llhs;
740 quotientTimesDiv = rlhs;
750 if (forOp && forOp.hasConstantLowerBound())
751 return forOp.getConstantLowerBound();
758 if (!forOp || !forOp.hasConstantUpperBound())
763 if (forOp.hasConstantLowerBound()) {
764 return forOp.getConstantUpperBound() - 1 -
765 (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
766 forOp.getStepAsInt();
768 return forOp.getConstantUpperBound() - 1;
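// Illustrative example (assumed): for constant bounds lb = 0, ub = 11 and
// step = 3 the loop visits {0, 3, 6, 9}, and the expression above evaluates
// to 10 - (10 % 3) = 9, the largest value the induction variable takes.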
779 constLowerBounds.reserve(operands.size());
780 constUpperBounds.reserve(operands.size());
781 for (Value operand : operands) {
786 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
787 return constExpr.getValue();
802 constLowerBounds.reserve(operands.size());
803 constUpperBounds.reserve(operands.size());
804 for (Value operand : operands) {
809 std::optional<int64_t> lowerBound;
810 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
811 lowerBound = constExpr.getValue();
814 constLowerBounds, constUpperBounds,
825 auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
836 binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
844 lhs = binExpr.getLHS();
845 rhs = binExpr.getRHS();
846 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
850 int64_t rhsConstVal = rhsConst.getValue();
852 if (rhsConstVal <= 0)
857 std::optional<int64_t> lhsLbConst =
859 std::optional<int64_t> lhsUbConst =
861 if (lhsLbConst && lhsUbConst) {
862 int64_t lhsLbConstVal = *lhsLbConst;
863 int64_t lhsUbConstVal = *lhsUbConst;
867 divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
868 divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
870 divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
876 divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
877 divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
884 lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
896 if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
897 if (rhsConstVal % divisor == 0 &&
899 expr = quotientTimesDiv.floorDiv(rhsConst);
900 } else if (divisor % rhsConstVal == 0 &&
902 expr = rem % rhsConst;
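// Illustrative examples of the bound-based simplifications above (assumed,
// not from the original comments): if d0 is known to lie in [4, 7], then
// d0 floordiv 4 folds to the constant 1 because both bounds fall in the same
// quotient; if d0 lies in [0, 3], then d0 mod 4 simplifies to d0 itself.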
928 if (operands.empty())
934 constLowerBounds.reserve(operands.size());
935 constUpperBounds.reserve(operands.size());
936 for (Value operand : operands) {
950 if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
951 lowerBounds.push_back(constExpr.getValue());
952 upperBounds.push_back(constExpr.getValue());
954 lowerBounds.push_back(
956 constLowerBounds, constUpperBounds,
958 upperBounds.push_back(
960 constLowerBounds, constUpperBounds,
969 unsigned i = exprEn.index();
971 if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
976 if (!upperBounds[i]) {
977 irredundantExprs.push_back(e);
983 auto otherLowerBound = en.value();
984 unsigned pos = en.index();
985 if (pos == i || !otherLowerBound)
987 if (*otherLowerBound > *upperBounds[i])
989 if (*otherLowerBound < *upperBounds[i])
994 if (upperBounds[pos] && lowerBounds[i] &&
995 lowerBounds[i] == upperBounds[i] &&
996 otherLowerBound == *upperBounds[pos] && i < pos)
1000 irredundantExprs.push_back(e);
1002 if (!lowerBounds[i]) {
1003 irredundantExprs.push_back(e);
1008 auto otherUpperBound = en.value();
1009 unsigned pos = en.index();
1010 if (pos == i || !otherUpperBound)
1012 if (*otherUpperBound < *lowerBounds[i])
1014 if (*otherUpperBound > *lowerBounds[i])
1016 if (lowerBounds[pos] && upperBounds[i] &&
1017 lowerBounds[i] == upperBounds[i] &&
1018 otherUpperBound == lowerBounds[pos] && i < pos)
1022 irredundantExprs.push_back(e);
1034 static void LLVM_ATTRIBUTE_UNUSED
1036 assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1042 newResults.push_back(expr);
1065 AffineMap affineMinMap = minOp.getAffineMap();
1068 DBGS() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`\n";
1072 for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
1076 ValueBoundsConstraintSet::ComparisonOperator::LT,
1078 minOp.getOperands())))
1087 auto it = llvm::find(dims, dim);
1088 if (it == dims.end()) {
1089 unmappedDims.push_back(i);
1096 auto it = llvm::find(syms, sym);
1097 if (it == syms.end()) {
1098 unmappedSyms.push_back(i);
1111 if (llvm::any_of(unmappedDims,
1112 [&](unsigned i) { return expr.isFunctionOfDim(i); }) ||
1113 llvm::any_of(unmappedSyms,
1114 [&](unsigned i) { return expr.isFunctionOfSymbol(i); }))
1120 repl[dimOrSym.ceilDiv(convertedExpr)] = c1;
1122 repl[(dimOrSym + convertedExpr - 1).floorDiv(convertedExpr)] = c1;
1127 return success(*map != initialMap);
1143 unsigned dimOrSymbolPosition,
1146 bool replaceAffineMin) {
1148 bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1149 unsigned pos = isDimReplacement ? dimOrSymbolPosition
1150 : dimOrSymbolPosition - dims.size();
1151 Value &v = isDimReplacement ? dims[pos] : syms[pos];
1155 if (auto minOp = v.getDefiningOp<AffineMinOp>(); minOp && replaceAffineMin) {
1171 AffineMap composeMap = affineApply.getAffineMap();
1172 assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
1174 affineApply.getMapOperands().end());
1188 dims.append(composeDims.begin(), composeDims.end());
1189 syms.append(composeSyms.begin(), composeSyms.end());
1190 *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
1200 bool composeAffineMin = false) {
1220 for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1233 unsigned nDims = 0, nSyms = 0;
1235 dimReplacements.reserve(dims.size());
1236 symReplacements.reserve(syms.size());
1237 for (auto *container : {&dims, &syms}) {
1238 bool isDim = (container == &dims);
1239 auto &repls = isDim ? dimReplacements : symReplacements;
1241 Value v = en.value();
1245 "map is function of unexpected expr@pos");
1251 operands->push_back(v);
1264 while (llvm::any_of(*operands, [](Value v) {
1270 if (composeAffineMin && llvm::any_of(*operands, [](Value v) {
1280 bool composeAffineMin) {
1285 return b.create<AffineApplyOp>(loc, map, valueOperands);
1291 bool composeAffineMin) {
1296 operands, composeAffineMin);
1303 bool composeAffineMin = false) {
1309 for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
1317 llvm::append_range(dims,
1319 llvm::append_range(symbols,
1326 operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1333 bool composeAffineMin) {
1334 assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1344 AffineApplyOp applyOp =
1349 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1354 if (failed(applyOp->fold(constOperands, foldResults)) ||
1355 foldResults.empty()) {
1357 listener->notifyOperationInserted(applyOp, {});
1358 return applyOp.getResult();
1362 return llvm::getSingleElement(foldResults);
1372 operands, composeAffineMin);
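// Usage sketch (illustrative only; assumes an OpBuilder `b`, a Location `loc`
// and two index values `x`/`y` are in scope):
//
//   AffineExpr d0, d1;
//   bindDims(b.getContext(), d0, d1);
//   OpFoldResult sum = makeComposedFoldedAffineApply(
//       b, loc, AffineMap::get(2, 0, d0 + d1),
//       getAsOpFoldResult(ValueRange{x, y}));
//
// If the operands fold to constants no affine.apply is materialized and an
// Attribute is returned; otherwise a fully composed affine.apply is created.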
1378 bool composeAffineMin) {
1379 return llvm::map_to_vector(
1380 llvm::seq<unsigned>(0, map.getNumResults()), [&](unsigned i) {
1381 return makeComposedFoldedAffineApply(b, loc, map.getSubMap({i}),
1382 operands, composeAffineMin);
1386 template <typename OpTy>
1398 return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
1401 template <typename OpTy>
1413 auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1417 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1422 if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1423 foldResults.empty()) {
1425 listener->notifyOperationInserted(minMaxOp, {});
1426 return minMaxOp.getResult();
1430 return llvm::getSingleElement(foldResults);
1437 return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1444 return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
1449 template <class MapOrSet>
1452 if (!mapOrSet || operands->empty())
1455 assert(mapOrSet->getNumInputs() == operands->size() &&
1456 "map/set inputs must match number of operands");
1458 auto *context = mapOrSet->getContext();
1460 resultOperands.reserve(operands->size());
1462 remappedSymbols.reserve(operands->size());
1463 unsigned nextDim = 0;
1464 unsigned nextSym = 0;
1465 unsigned oldNumSyms = mapOrSet->getNumSymbols();
1467 for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1468 if (i < mapOrSet->getNumDims()) {
1472 remappedSymbols.push_back((*operands)[i]);
1475 resultOperands.push_back((*operands)[i]);
1478 resultOperands.push_back((*operands)[i]);
1482 resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1483 *operands = resultOperands;
1484 *mapOrSet = mapOrSet->replaceDimsAndSymbols(
1485 dimRemapping, {}, nextDim, oldNumSyms + nextSym);
1487 assert(mapOrSet->getNumInputs() == operands->size() &&
1488 "map/set inputs must match number of operands");
1497 template <class MapOrSet>
1500 if (!mapOrSet || operands.empty())
1503 unsigned numOperands = operands.size();
1505 assert(mapOrSet.getNumInputs() == numOperands &&
1506 "map/set inputs must match number of operands");
1508 auto *context = mapOrSet.getContext();
1510 resultOperands.reserve(numOperands);
1512 remappedDims.reserve(numOperands);
1514 symOperands.reserve(mapOrSet.getNumSymbols());
1515 unsigned nextSym = 0;
1516 unsigned nextDim = 0;
1517 unsigned oldNumDims = mapOrSet.getNumDims();
1519 resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);
1520 for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
1523 symRemapping[i - oldNumDims] =
1525 remappedDims.push_back(operands[i]);
1528 symOperands.push_back(operands[i]);
1532 append_range(resultOperands, remappedDims);
1533 append_range(resultOperands, symOperands);
1534 operands = resultOperands;
1535 mapOrSet = mapOrSet.replaceDimsAndSymbols(
1536 {}, symRemapping, oldNumDims + nextDim, nextSym);
1538 assert(mapOrSet.getNumInputs() == operands.size() &&
1539 "map/set inputs must match number of operands");
1543 template <class MapOrSet>
1546 static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1547 "Argument must be either of AffineMap or IntegerSet type");
1549 if (!mapOrSet || operands->empty())
1552 assert(mapOrSet->getNumInputs() == operands->size() &&
1553 "map/set inputs must match number of operands");
1555 canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1556 legalizeDemotedDims<MapOrSet>(*mapOrSet, *operands);
1559 llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1560 llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1562 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
1563 usedDims[dimExpr.getPosition()] = true;
1564 else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
1565 usedSyms[symExpr.getPosition()] = true;
1568 auto *context = mapOrSet->getContext();
1571 resultOperands.reserve(operands->size());
1573 llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1575 unsigned nextDim = 0;
1576 for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1579 auto it = seenDims.find((*operands)[i]);
1580 if (it == seenDims.end()) {
1582 resultOperands.push_back((*operands)[i]);
1583 seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1585 dimRemapping[i] = it->second;
1589 llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1591 unsigned nextSym = 0;
1592 for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1598 IntegerAttr operandCst;
1599 if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
1606 auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1607 if (it == seenSymbols.end()) {
1609 resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1610 seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1613 symRemapping[i] = it->second;
1616 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1618 *operands = resultOperands;
1623 canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1628 canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
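// Illustrative example (assumed): canonicalizing the map (d0, d1) -> (d0 + d1)
// with the operand list (%x, %x) yields (d0) -> (d0 + d0) applied to (%x);
// duplicate operands are merged and unused dims/symbols are dropped.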
1635 template <typename AffineOpTy>
1644 LogicalResult matchAndRewrite(AffineOpTy affineOp,
1647 llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1648 AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1649 AffineVectorStoreOp, AffineVectorLoadOp>::value,
1650 "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1652 auto map = affineOp.getAffineMap();
1654 auto oldOperands = affineOp.getMapOperands();
1659 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1660 resultOperands.begin()))
1663 replaceAffineOp(rewriter, affineOp, map, resultOperands);
1671 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1678 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1682 prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1683 prefetch.getLocalityHint(), prefetch.getIsDataCache());
1686 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1690 store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1693 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1697 vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1701 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1705 vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1710 template <typename AffineOpTy>
1711 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1720 results.add<SimplifyAffineOp<AffineApplyOp>>(context);
1751 p << " " << getSrcMemRef() << '[';
1753 p << "], " << getDstMemRef() << '[';
1755 p << "], " << getTagMemRef() << '[';
1760 p << ", " << getNumElementsPerStride();
1762 p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1763 << getTagMemRefType();
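// Illustrative textual form produced by this printer (example IR; SSA names
// and types are invented):
//
//   affine.dma_start %src[%i, %j], %dst[%k], %tag[%c0], %num_elements
//     : memref<40x128xf32>, memref<2x1024xf32, 1>, memref<1xi32>
//
// With the optional stride operands present, ", %stride, %num_elt_per_stride"
// follows %num_elements.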
1775 AffineMapAttr srcMapAttr;
1778 AffineMapAttr dstMapAttr;
1781 AffineMapAttr tagMapAttr;
1796 getSrcMapAttrStrName(),
1800 getDstMapAttrStrName(),
1804 getTagMapAttrStrName(),
1813 if (!strideInfo.empty() && strideInfo.size() != 2) {
1815 "expected two stride related operands");
1817 bool isStrided = strideInfo.size() == 2;
1822 if (types.size() != 3)
1840 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1841 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1842 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1844 "memref operand count not equal to map.numInputs");
1848 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1849 if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1850 return emitOpError("expected DMA source to be of memref type");
1851 if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1852 return emitOpError("expected DMA destination to be of memref type");
1853 if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1854 return emitOpError("expected DMA tag to be of memref type");
1856 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1857 getDstMap().getNumInputs() +
1858 getTagMap().getNumInputs();
1859 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1860 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1861 return emitOpError("incorrect number of operands");
1865 for (auto idx : getSrcIndices()) {
1866 if (!idx.getType().isIndex())
1867 return emitOpError("src index to dma_start must have 'index' type");
1870 "src index must be a valid dimension or symbol identifier");
1872 for (auto idx : getDstIndices()) {
1873 if (!idx.getType().isIndex())
1874 return emitOpError("dst index to dma_start must have 'index' type");
1877 "dst index must be a valid dimension or symbol identifier");
1879 for (auto idx : getTagIndices()) {
1880 if (!idx.getType().isIndex())
1881 return emitOpError("tag index to dma_start must have 'index' type");
1884 "tag index must be a valid dimension or symbol identifier");
1895 void AffineDmaStartOp::getEffects(
1921 p << " " << getTagMemRef() << '[';
1926 p << " : " << getTagMemRef().getType();
1937 AffineMapAttr tagMapAttr;
1946 getTagMapAttrStrName(),
1955 if (!llvm::isa<MemRefType>(type))
1957 "expected tag to be of memref type");
1959 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1961 "tag memref operand count != to map.numInputs");
1965 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1966 if (!llvm::isa<MemRefType>(getOperand(0).getType()))
1967 return emitOpError("expected DMA tag to be of memref type");
1969 for (auto idx : getTagIndices()) {
1970 if (!idx.getType().isIndex())
1971 return emitOpError("index to dma_wait must have 'index' type");
1974 "index must be a valid dimension or symbol identifier");
1985 void AffineDmaWaitOp::getEffects(
2001 ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
2002 assert(((!lbMap && lbOperands.empty()) ||
2004 "lower bound operand count does not match the affine map");
2005 assert(((!ubMap && ubOperands.empty()) ||
2007 "upper bound operand count does not match the affine map");
2008 assert(step > 0 && "step has to be a positive integer constant");
2014 getOperandSegmentSizeAttr(),
2016 static_cast<int32_t>(ubOperands.size()),
2017 static_cast<int32_t>(iterArgs.size())}));
2019 for (Value val : iterArgs)
2041 Value inductionVar =
2043 for (Value val : iterArgs)
2044 bodyBlock->addArgument(val.getType(), val.getLoc());
2049 if (iterArgs.empty() && !bodyBuilder) {
2050 ensureTerminator(*bodyRegion, builder, result.location);
2051 } else if (bodyBuilder) {
2054 bodyBuilder(builder, result.location, inductionVar,
2060 int64_t ub, int64_t step, ValueRange iterArgs,
2061 BodyBuilderFn bodyBuilder) {
2064 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
2068 LogicalResult AffineForOp::verifyRegions() {
2071 auto *body = getBody();
2072 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
2073 return emitOpError("expected body to have a single index argument for the "
2074 "induction variable");
2078 if (getLowerBoundMap().getNumInputs() > 0)
2080 getLowerBoundMap().getNumDims())))
2083 if (getUpperBoundMap().getNumInputs() > 0)
2085 getUpperBoundMap().getNumDims())))
2087 if (getLowerBoundMap().getNumResults() < 1)
2088 return emitOpError("expected lower bound map to have at least one result");
2089 if (getUpperBoundMap().getNumResults() < 1)
2090 return emitOpError("expected upper bound map to have at least one result");
2092 unsigned opNumResults = getNumResults();
2093 if (opNumResults == 0)
2099 if (getNumIterOperands() != opNumResults)
2101 "mismatch between the number of loop-carried values and results");
2102 if (getNumRegionIterArgs() != opNumResults)
2104 "mismatch between the number of basic block args and results");
2114 bool failedToParsedMinMax =
2118 auto boundAttrStrName =
2119 isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
2120 : AffineForOp::getUpperBoundMapAttrName(result.name);
2127 if (!boundOpInfos.empty()) {
2129 if (boundOpInfos.size() > 1)
2131 "expected only one loop bound operand");
2156 if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
2157 unsigned currentNumOperands = result.operands.size();
2162 auto map = affineMapAttr.getValue();
2166 "dim operand count and affine map dim count must match");
2168 unsigned numDimAndSymbolOperands =
2169 result.operands.size() - currentNumOperands;
2170 if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
2173 "symbol operand count and affine map symbol count must match");
2179 return p.emitError(attrLoc, "lower loop bound affine map with "
2180 "multiple results requires 'max' prefix");
2182 return p.emitError(attrLoc, "upper loop bound affine map with multiple "
2183 "results requires 'min' prefix");
2189 if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
2199 "expected valid affine map representation for loop bounds");
2211 int64_t numOperands = result.operands.size();
2214 int64_t numLbOperands = result.operands.size() - numOperands;
2217 numOperands = result.operands.size();
2220 int64_t numUbOperands = result.operands.size() - numOperands;
2225 getStepAttrName(result.name),
2229 IntegerAttr stepAttr;
2231 getStepAttrName(result.name).data(),
2235 if (stepAttr.getValue().isNegative())
2238 "expected step to be representable as a positive signed integer");
2246 regionArgs.push_back(inductionVariable);
2254 for (auto argOperandType :
2255 llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2256 Type type = std::get<2>(argOperandType);
2257 std::get<0>(argOperandType).type = type;
2265 getOperandSegmentSizeAttr(),
2267 static_cast<int32_t>(numUbOperands),
2268 static_cast<int32_t>(operands.size())}));
2272 if (regionArgs.size() != result.types.size() + 1)
2275 "mismatch between the number of loop-carried values and results");
2279 AffineForOp::ensureTerminator(*body, builder, result.location);
2301 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2302 p << constExpr.getValue();
2310 if (isa<AffineSymbolExpr>(expr)) {
2326 unsigned AffineForOp::getNumIterOperands() {
2327 AffineMap lbMap = getLowerBoundMapAttr().getValue();
2328 AffineMap ubMap = getUpperBoundMapAttr().getValue();
2333 std::optional<MutableArrayRef<OpOperand>>
2334 AffineForOp::getYieldedValuesMutable() {
2335 return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2347 if (getStepAsInt() != 1)
2348 p << " step " << getStepAsInt();
2350 bool printBlockTerminators = false;
2351 if (getNumIterOperands() > 0) {
2353 auto regionArgs = getRegionIterArgs();
2354 auto operands = getInits();
2356 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2357 p << std::get<0>(it) << " = " << std::get<1>(it);
2359 p << ") -> (" << getResultTypes() << ")";
2360 printBlockTerminators = true;
2365 printBlockTerminators);
2367 (*this)->getAttrs(),
2368 {getLowerBoundMapAttrName(getOperation()->getName()),
2369 getUpperBoundMapAttrName(getOperation()->getName()),
2370 getStepAttrName(getOperation()->getName()),
2371 getOperandSegmentSizeAttr()});
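// Illustrative textual form (example IR; values invented). The iter_args
// clause and result types only appear when the loop carries values:
//
//   %sum = affine.for %i = 0 to 128 step 8 iter_args(%acc = %init) -> (f32) {
//     ...
//     affine.yield %next : f32
//   }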
2376 auto foldLowerOrUpperBound = [&forOp](bool lower) {
2380 auto boundOperands =
2381 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2382 for (auto operand : boundOperands) {
2385 operandConstants.push_back(operandCst);
2389 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2391 "bound maps should have at least one result");
2393 if (failed(boundMap.constantFold(operandConstants, foldedResults)))
2397 assert(!foldedResults.empty() && "bounds should have at least one result");
2398 auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2399 for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2400 auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2401 maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2402 : llvm::APIntOps::smin(maxOrMin, foldedResult);
2404 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2405 : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2410 bool folded = false;
2411 if (!forOp.hasConstantLowerBound())
2412 folded |= succeeded(foldLowerOrUpperBound(true));
2415 if (!forOp.hasConstantUpperBound())
2416 folded |= succeeded(foldLowerOrUpperBound(false));
2417 return success(folded);
2425 auto lbMap = forOp.getLowerBoundMap();
2426 auto ubMap = forOp.getUpperBoundMap();
2427 auto prevLbMap = lbMap;
2428 auto prevUbMap = ubMap;
2441 if (lbMap == prevLbMap && ubMap == prevUbMap)
2444 if (lbMap != prevLbMap)
2445 forOp.setLowerBound(lbOperands, lbMap);
2446 if (ubMap != prevUbMap)
2447 forOp.setUpperBound(ubOperands, ubMap);
2453 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2454 int64_t step = forOp.getStepAsInt();
2455 if (!forOp.hasConstantBounds() || step <= 0)
2456 return std::nullopt;
2457 int64_t lb = forOp.getConstantLowerBound();
2458 int64_t ub = forOp.getConstantUpperBound();
2459 return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
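// Illustrative example (assumed): lb = 0, ub = 10, step = 3 gives
// (10 - 0 + 3 - 1) / 3 = 4 iterations {0, 3, 6, 9}; when ub <= lb the trip
// count is 0.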
2467 LogicalResult matchAndRewrite(AffineForOp forOp,
2470 if (!llvm::hasSingleElement(*forOp.getBody()))
2472 if (forOp.getNumResults() == 0)
2474 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2475 if (tripCount == 0) {
2478 rewriter.replaceOp(forOp, forOp.getInits());
2482 auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2483 auto iterArgs = forOp.getRegionIterArgs();
2484 bool hasValDefinedOutsideLoop = false;
2485 bool iterArgsNotInOrder = false;
2486 for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2487 Value val = yieldOp.getOperand(i);
2488 auto *iterArgIt = llvm::find(iterArgs, val);
2491 if (val == forOp.getInductionVar())
2493 if (iterArgIt == iterArgs.end()) {
2495 assert(forOp.isDefinedOutsideOfLoop(val) &&
2496 "must be defined outside of the loop");
2497 hasValDefinedOutsideLoop = true;
2498 replacements.push_back(val);
2500 unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2502 iterArgsNotInOrder = true;
2503 replacements.push_back(forOp.getInits()[pos]);
2508 if (!tripCount.has_value() &&
2509 (hasValDefinedOutsideLoop || iterArgsNotInOrder)
2513 if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2515 rewriter.replaceOp(forOp, replacements);
2523 results.add<AffineForEmptyLoopFolder>(context);
2527 assert((point.isParent() || point == getRegion()) && "invalid region point");
2534 void AffineForOp::getSuccessorRegions(
2536 assert((point.isParent() || point == getRegion()) && "expected loop region");
2541 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2542 if (point.isParent() && tripCount.has_value()) {
2543 if (tripCount.value() > 0) {
2544 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2547 if (tripCount.value() == 0) {
2555 if (!point.isParent() && tripCount == 1) {
2562 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2568 return getTrivialConstantTripCount(op) == 0;
2571 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2581 results.assign(getInits().begin(), getInits().end());
2584 return success(folded);
2597 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2598 getLowerBoundOperandsMutable().assign(lbOperands);
2599 setLowerBoundMap(map);
2604 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2605 getUpperBoundOperandsMutable().assign(ubOperands);
2606 setUpperBoundMap(map);
2609 bool AffineForOp::hasConstantLowerBound() {
2610 return getLowerBoundMap().isSingleConstant();
2613 bool AffineForOp::hasConstantUpperBound() {
2614 return getUpperBoundMap().isSingleConstant();
2617 int64_t AffineForOp::getConstantLowerBound() {
2618 return getLowerBoundMap().getSingleConstantResult();
2621 int64_t AffineForOp::getConstantUpperBound() {
2622 return getUpperBoundMap().getSingleConstantResult();
2625 void AffineForOp::setConstantLowerBound(int64_t value) {
2629 void AffineForOp::setConstantUpperBound(int64_t value) {
2633 AffineForOp::operand_range AffineForOp::getControlOperands() {
2638 bool AffineForOp::matchingBoundOperandList() {
2639 auto lbMap = getLowerBoundMap();
2640 auto ubMap = getUpperBoundMap();
2646 for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2648 if (getOperand(i) != getOperand(numOperands + i))
2656 std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2660 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2661 if (!hasConstantLowerBound())
2662 return std::nullopt;
2665 OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2668 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2674 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2675 if (!hasConstantUpperBound())
2679 OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2682 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2684 bool replaceInitOperandUsesInLoop,
2689 auto inits = llvm::to_vector(getInits());
2690 inits.append(newInitOperands.begin(), newInitOperands.end());
2691 AffineForOp newLoop = rewriter.create<AffineForOp>(
2696 auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2698 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2703 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2704 assert(newInitOperands.size() == newYieldedValues.size() &&
2705 "expected as many new yield values as new iter operands");
2707 yieldOp.getOperandsMutable().append(newYieldedValues);
2712 rewriter.mergeBlocks(getBody(), newLoop.getBody(),
2713 newLoop.getBody()->getArguments().take_front(
2714 getBody()->getNumArguments()));
2716 if (replaceInitOperandUsesInLoop) {
2719 for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
2730 newLoop->getResults().take_front(getNumResults()));
2731 return cast<LoopLikeOpInterface>(newLoop.getOperation());
2759 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2760 if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2761 return AffineForOp();
2763 ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2765 return forOp.getInductionVar() == val ? forOp : AffineForOp();
2766 return AffineForOp();
2770 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2771 if (!ivArg || !ivArg.getOwner())
2774 auto parallelOp = dyn_cast_if_present<AffineParallelOp>(containingOp);
2775 if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2784 ivs->reserve(forInsts.size());
2785 for (auto forInst : forInsts)
2786 ivs->push_back(forInst.getInductionVar());
2791 ivs.reserve(affineOps.size());
2794 if (auto forOp = dyn_cast<AffineForOp>(op))
2795 ivs.push_back(forOp.getInductionVar());
2796 else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2797 for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2798 ivs.push_back(parallelOp.getBody()->getArgument(i));
2804 template <typename BoundListTy, typename LoopCreatorTy>
2809 LoopCreatorTy &&loopCreatorFn) {
2810 assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
2811 assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2823 ivs.reserve(lbs.size());
2824 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
2830 if (i == e - 1 && bodyBuilderFn) {
2832 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2834 nestedBuilder.create<AffineYieldOp>(nestedLoc);
2839 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
2847 int64_t ub, int64_t step,
2848 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2849 return builder.create<AffineForOp>(loc, lb, ub, step,
2857 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2860 if (lbConst && ubConst)
2862 ubConst.value(), step, bodyBuilderFn);
2893 LogicalResult matchAndRewrite(AffineIfOp ifOp,
2895 if (ifOp.getElseRegion().empty() ||
2896 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2911 LogicalResult matchAndRewrite(AffineIfOp op,
2914 auto isTriviallyFalse = [](IntegerSet iSet) {
2915 return iSet.isEmptyIntegerSet();
2919 return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2920 iSet.getConstraint(0) == 0);
2923 IntegerSet affineIfConditions = op.getIntegerSet();
2925 if (isTriviallyFalse(affineIfConditions)) {
2929 if (op.getNumResults() == 0 && !op.hasElse()) {
2935 blockToMove = op.getElseBlock();
2936 } else if (isTriviallyTrue(affineIfConditions)) {
2937 blockToMove = op.getThenBlock();
2955 rewriter.eraseOp(blockToMoveTerminator);
2963 void AffineIfOp::getSuccessorRegions(
2972 if (getElseRegion().empty()) {
2973 regions.push_back(getResults());
2989 auto conditionAttr =
2990 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2992 return emitOpError("requires an integer set attribute named 'condition'");
2995 IntegerSet condition = conditionAttr.getValue();
2997 return emitOpError("operand count and condition integer set dimension and "
2998 "symbol count must match");
3010 IntegerSetAttr conditionAttr;
3013 AffineIfOp::getConditionAttrStrName(),
3019 auto set = conditionAttr.getValue();
3020 if (set.getNumDims() != numDims)
3023 "dim operand count and integer set dim count must match");
3024 if (numDims + set.getNumSymbols() != result.operands.size())
3027 "symbol operand count and integer set symbol count must match");
3041 AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
3048 AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
3060 auto conditionAttr =
3061 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
3062 p << " " << conditionAttr;
3064 conditionAttr.getValue().getNumDims(), p);
3071 auto &elseRegion = this->getElseRegion();
3072 if (!elseRegion.empty()) {
3081 getConditionAttrStrName());
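// Illustrative textual form (example IR; the integer set and operands are
// invented):
//
//   affine.if affine_set<(d0)[s0] : (d0 - 10 >= 0, s0 - d0 - 9 >= 0)>(%i)[%N] {
//     ...
//   } else {
//     ...
//   }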
3086 ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
3090 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
3096 (*this)->setOperands(operands);
3101 bool withElseRegion) {
3102 assert(resultTypes.empty() || withElseRegion);
3111 if (resultTypes.empty())
3112 AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
3115 if (withElseRegion) {
3117 if (resultTypes.empty())
3118 AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
3124 AffineIfOp::build(builder, result, {}, set, args,
3133 bool composeAffineMin = false) {
3140 if (llvm::none_of(operands,
3151 auto set = getIntegerSet();
3157 if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
3160 setConditional(set, operands);
3166 results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
3175 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
3179 auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
3180 result.types.push_back(memrefType.getElementType());
3185 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3188 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3190 result.types.push_back(memrefType.getElementType());
3195 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3196 int64_t rank = memrefType.getRank();
3201 build(builder, result, memref, map, indices);
3210 AffineMapAttr mapAttr;
3215 AffineLoadOp::getMapAttrStrName(),
3225 p << " " << getMemRef() << '[';
3226 if (AffineMapAttr mapAttr =
3227 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3231 {getMapAttrStrName()});
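// Illustrative textual form (example IR; names and types invented):
//
//   %v = affine.load %A[%i + 3, %j] : memref<100x100xf32>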
3237 template <typename AffineMemOpTy>
3238 static LogicalResult
3241 MemRefType memrefType,
unsigned numIndexOperands) {
3244 return op->emitOpError("affine map num results must equal memref rank");
3246 return op->emitOpError("expects as many subscripts as affine map inputs");
3248 for (auto idx : mapOperands) {
3249 if (!idx.getType().isIndex())
3250 return op->emitOpError("index to load must have 'index' type");
3260 if (getType() != memrefType.getElementType())
3261 return emitOpError("result type must match element type of memref");
3264 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3265 getMapOperands(), memrefType,
3266 getNumOperands() - 1)))
3274 results.add<SimplifyAffineOp<AffineLoadOp>>(context);
3283 auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3290 auto global = dyn_cast_or_null<memref::GlobalOp>(
3297 llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3301 if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
3302 return splatAttr.getSplatValue<Attribute>();
3304 if (!getAffineMap().isConstant())
3306 auto indices = llvm::to_vector<4>(
3307 llvm::map_range(getAffineMap().getConstantResults(),
3308 [](int64_t v) -> uint64_t { return v; }));
3309 return cstAttr.getValues<Attribute>()[indices];
3319 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3330 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3331 int64_t rank = memrefType.getRank();
3336 build(builder, result, valueToStore, memref, map, indices);
3345 AffineMapAttr mapAttr;
3350 mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3361 p << " " << getValueToStore();
3362 p << ", " << getMemRef() << '[';
3363 if (AffineMapAttr mapAttr =
3364 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3368 {getMapAttrStrName()});
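// Illustrative textual form (example IR; names and types invented):
//
//   affine.store %v, %A[%i + 3, %j] : memref<100x100xf32>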
3375 if (getValueToStore().getType() != memrefType.getElementType())
3377 "value to store must have the same type as memref element type");
3380 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3381 getMapOperands(), memrefType,
3382 getNumOperands() - 2)))
3390 results.add<SimplifyAffineOp<AffineStoreOp>>(context);
3393 LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3403 template <typename T>
3406 if (op.getNumOperands() !=
3407 op.getMap().getNumDims() + op.getMap().getNumSymbols())
3408 return op.emitOpError(
3409 "operand count and affine map dimension and symbol count must match");
3411 if (op.getMap().getNumResults() == 0)
3412 return op.emitOpError("affine map expect at least one result");
3416 template <typename T>
3418 p << ' ' << op->getAttr(T::getMapAttrStrName());
3419 auto operands = op.getOperands();
3420 unsigned numDims = op.getMap().getNumDims();
3421 p << '(' << operands.take_front(numDims) << ')';
3423 if (operands.size() != numDims)
3424 p << '[' << operands.drop_front(numDims) << ']';
3426 {T::getMapAttrStrName()});
3429 template <typename T>
3436 AffineMapAttr mapAttr;
3452 template <typename T>
3454 static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3455 "expected affine min or max op");
3461 auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3463 if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3464 return op.getOperand(0);
3467 if (results.empty()) {
3469 if (foldedMap == op.getMap())
3472 return op.getResult();
3476 auto resultIt = std::is_same<T, AffineMinOp>::value
3477 ? llvm::min_element(results)
3478 : llvm::max_element(results);
3479 if (resultIt == results.end())
3485 template <typename T>
3491 AffineMap oldMap = affineOp.getAffineMap();
3497 if (!llvm::is_contained(newExprs, expr))
3498 newExprs.push_back(expr);
3528 template <typename T>
3534 AffineMap oldMap = affineOp.getAffineMap();
3536 affineOp.getMapOperands().take_front(oldMap.getNumDims());
3538 affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3540 auto newDimOperands = llvm::to_vector<8>(dimOperands);
3541 auto newSymOperands = llvm::to_vector<8>(symOperands);
3549 if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
3550 Value symValue = symOperands[symExpr.getPosition()];
3552 producerOps.push_back(producerOp);
3555 } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
3556 Value dimValue = dimOperands[dimExpr.getPosition()];
3558 producerOps.push_back(producerOp);
3565 newExprs.push_back(expr);
3568 if (producerOps.empty())
3575 for (T producerOp : producerOps) {
3576 AffineMap producerMap = producerOp.getAffineMap();
3577 unsigned numProducerDims = producerMap.getNumDims();
3582 producerOp.getMapOperands().take_front(numProducerDims);
3584 producerOp.getMapOperands().take_back(numProducerSyms);
3585 newDimOperands.append(dimValues.begin(), dimValues.end());
3586 newSymOperands.append(symValues.begin(), symValues.end());
3590 newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3591 .shiftSymbols(numProducerSyms, numUsedSyms));
3594 numUsedDims += numProducerDims;
3595 numUsedSyms += numProducerSyms;
3601 llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3620 if (!resultExpr.isPureAffine())
3625 if (failed(flattenResult))
3638 if (llvm::is_sorted(flattenedExprs))
3643 llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
3644 llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
3645 return flattenedExprs[lhs] < flattenedExprs[rhs];
3648 for (unsigned idx : resultPermutation)
3669 template <typename T>
3675 AffineMap map = affineOp.getAffineMap();
3683 template <typename T>
3689 if (affineOp.getMap().getNumResults() != 1)
3692 affineOp.getOperands());
3720 return parseAffineMinMaxOp<AffineMinOp>(parser, result);
3748 return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
3767 IntegerAttr hintInfo;
3769 StringRef readOrWrite, cacheType;
3771 AffineMapAttr mapAttr;
3775 AffinePrefetchOp::getMapAttrStrName(),
3781 AffinePrefetchOp::getLocalityHintAttrStrName(),
3791 if (readOrWrite != "read" && readOrWrite != "write")
3793 "rw specifier has to be 'read' or 'write'");
3794 result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
3797 if (cacheType != "data" && cacheType != "instr")
3799 "cache type has to be 'data' or 'instr'");
3801 result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
3808 p << " " << getMemref() << '[';
3809 AffineMapAttr mapAttr =
3810 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3813 p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
3814 << "locality<" << getLocalityHint() << ">, "
3815 << (getIsDataCache() ? "data" : "instr");
3817 (*this)->getAttrs(),
3818 {getMapAttrStrName(), getLocalityHintAttrStrName(),
3819 getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
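// Illustrative textual form matching the printer above (example IR; names and
// types invented):
//
//   affine.prefetch %A[%i + 5], read, locality<3>, data : memref<400xi32>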
3824 auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3828 return emitOpError("affine.prefetch affine map num results must equal"
3831 return emitOpError("too few operands");
3833 if (getNumOperands() != 1)
3834 return emitOpError("too few operands");
3838 for (auto idx : getMapOperands()) {
3841 "index must be a valid dimension or symbol identifier");
3849 results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
3852 LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
3867 auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3871 build(builder, result, resultTypes, reductions, lbs, {}, ubs,
3881 assert(llvm::all_of(lbMaps,
3883 return m.getNumDims() == lbMaps[0].getNumDims() &&
3886 "expected all lower bounds maps to have the same number of dimensions "
3888 assert(llvm::all_of(ubMaps,
3890 return m.getNumDims() == ubMaps[0].getNumDims() &&
3893 "expected all upper bounds maps to have the same number of dimensions "
3895 assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3896 "expected lower bound maps to have as many inputs as lower bound "
3898 assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3899 "expected upper bound maps to have as many inputs as upper bound "
3907 for (arith::AtomicRMWKind reduction : reductions)
3908 reductionAttrs.push_back(
3920 groups.reserve(groups.size() + maps.size());
3921 exprs.reserve(maps.size());
3926 return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3932 AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3933 AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
3951 for (unsigned i = 0, e = steps.size(); i < e; ++i)
3953 if (resultTypes.empty())
3954 ensureTerminator(*bodyRegion, builder, result.location);
3958 return {&getRegion()};
3961 unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3963 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
3964 return getOperands().take_front(getLowerBoundsMap().getNumInputs());
3967 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
3968 return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
3971 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
3972 auto values = getLowerBoundsGroups().getValues<int32_t>();
3974 for (unsigned i = 0; i < pos; ++i)
3976 return getLowerBoundsMap().getSliceMap(start, values[pos]);
3979 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
3980 auto values = getUpperBoundsGroups().getValues<int32_t>();
3982 for (unsigned i = 0; i < pos; ++i)
3984 return getUpperBoundsMap().getSliceMap(start, values[pos]);
3988 return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
3992 return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
3995 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3996 if (hasMinMaxBounds())
3997 return std::nullopt;
4002 AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
4005 for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
4006 auto expr = rangesValueMap.getResult(i);
4007 auto cst = dyn_cast<AffineConstantExpr>(expr);
4009 return std::nullopt;
4010 out.push_back(cst.getValue());
4015 Block *AffineParallelOp::getBody() { return &getRegion().front(); }
4017 OpBuilder AffineParallelOp::getBodyBuilder() {
4018 return OpBuilder(getBody(), std::prev(getBody()->end()));
4023 "operands to map must match number of inputs");
4025 auto ubOperands = getUpperBoundsOperands();
4028 newOperands.append(ubOperands.begin(), ubOperands.end());
4029 (*this)->setOperands(newOperands);
4036 "operands to map must match number of inputs");
4039 newOperands.append(ubOperands.begin(), ubOperands.end());
4040 (*this)->setOperands(newOperands);
4046 setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
4051 arith::AtomicRMWKind op) {
4053 case arith::AtomicRMWKind::addf:
4054 return isa<FloatType>(resultType);
4055 case arith::AtomicRMWKind::addi:
4056 return isa<IntegerType>(resultType);
4057 case arith::AtomicRMWKind::assign:
4059 case arith::AtomicRMWKind::mulf:
4060 return isa<FloatType>(resultType);
4061 case arith::AtomicRMWKind::muli:
4062 return isa<IntegerType>(resultType);
4063 case arith::AtomicRMWKind::maximumf:
4064 return isa<FloatType>(resultType);
4065 case arith::AtomicRMWKind::minimumf:
4066 return isa<FloatType>(resultType);
4067 case arith::AtomicRMWKind::maxs: {
4068 auto intType = llvm::dyn_cast<IntegerType>(resultType);
4069 return intType && intType.isSigned();
4071 case arith::AtomicRMWKind::mins: {
4072 auto intType = llvm::dyn_cast<IntegerType>(resultType);
4073 return intType && intType.isSigned();
4075 case arith::AtomicRMWKind::maxu: {
4076 auto intType = llvm::dyn_cast<IntegerType>(resultType);
4077 return intType && intType.isUnsigned();
4079 case arith::AtomicRMWKind::minu: {
4080 auto intType = llvm::dyn_cast<IntegerType>(resultType);
4081 return intType && intType.isUnsigned();
4083 case arith::AtomicRMWKind::ori:
4084 return isa<IntegerType>(resultType);
4085 case arith::AtomicRMWKind::andi:
4086 return isa<IntegerType>(resultType);
4093 auto numDims = getNumDims();
4096 getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
4097 return emitOpError() << "the number of region arguments ("
4098 << getBody()->getNumArguments()
4099 << ") and the number of map groups for lower ("
4100 << getLowerBoundsGroups().getNumElements()
4101 << ") and upper bound ("
4102 << getUpperBoundsGroups().getNumElements()
4103 << "), and the number of steps (" << getSteps().size()
4104 << ") must all match";
4107 unsigned expectedNumLBResults = 0;
4108 for (APInt v : getLowerBoundsGroups()) {
4109 unsigned results = v.getZExtValue();
4111 return emitOpError()
4112 <<
"expected lower bound map to have at least one result";
4113 expectedNumLBResults += results;
4115 if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
4116 return emitOpError() <<
"expected lower bounds map to have "
4117 << expectedNumLBResults <<
" results";
4118 unsigned expectedNumUBResults = 0;
4119 for (APInt v : getUpperBoundsGroups()) {
4120 unsigned results = v.getZExtValue();
4122 return emitOpError()
4123 <<
"expected upper bound map to have at least one result";
4124 expectedNumUBResults += results;
4126 if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
4127 return emitOpError() <<
"expected upper bounds map to have "
4128 << expectedNumUBResults <<
" results";
4130 if (getReductions().size() != getNumResults())
4131 return emitOpError(
"a reduction must be specified for each output");
4137 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
4138 if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
4139 return emitOpError(
"invalid reduction attribute");
4140 auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
4142 return emitOpError(
"result type cannot match reduction attribute");
4148 getLowerBoundsMap().getNumDims())))
4152 getUpperBoundsMap().getNumDims())))
4157 LogicalResult AffineValueMap::canonicalize() {
4159 auto newMap = getAffineMap();
4161 if (newMap == getAffineMap() && newOperands == operands)
4163 reset(newMap, newOperands);
4176 if (!lbCanonicalized && !ubCanonicalized)
4179 if (lbCanonicalized)
4181 if (ubCanonicalized)
4187 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
4199 StringRef keyword) {
4202 ValueRange dimOperands = operands.take_front(numDims);
4203 ValueRange symOperands = operands.drop_front(numDims);
4205 for (llvm::APInt groupSize : group) {
4209 unsigned size = groupSize.getZExtValue();
4214 p << keyword << '(';
4224 p << " (" << getBody()->getArguments() << ") = (";
4226 getLowerBoundsOperands(), "max");
4229 getUpperBoundsOperands(), "min");
4232 bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4235 llvm::interleaveComma(steps, p);
4235 llvm::interleaveComma(steps, p);
4238 if (getNumResults()) {
4240 llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4241 arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4242 llvm::cast<IntegerAttr>(attr).getInt());
4243 p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4245 p << ") -> (" << getResultTypes() << ")";
4252 (*this)->getAttrs(),
4253 {AffineParallelOp::getReductionsAttrStrName(),
4254 AffineParallelOp::getLowerBoundsMapAttrStrName(),
4255 AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4256 AffineParallelOp::getUpperBoundsMapAttrStrName(),
4257 AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4258 AffineParallelOp::getStepsAttrStrName()});
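// Illustrative textual form (example IR invented); the reduce clause and
// result types are printed only when the op has results, and steps are elided
// when they are all 1:
//
//   %sum = affine.parallel (%i, %j) = (0, 0) to (128, 128)
//       reduce ("addf") -> (f32) {
//     ...
//     affine.yield %v : f32
//   }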
4271 "expected operands to be dim or symbol expression");
4274 for (const auto &list : operands) {
4278 for (Value operand : valueOperands) {
4279 unsigned pos = std::distance(uniqueOperands.begin(),
4280 llvm::find(uniqueOperands, operand));
4281 if (pos == uniqueOperands.size())
4282 uniqueOperands.push_back(operand);
4283 replacements.push_back(
4293 enum class MinMaxKind { Min, Max };
4317 const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4319 StringRef mapName = kind == MinMaxKind::Min
4320 ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4321 : AffineParallelOp::getLowerBoundsMapAttrStrName();
4322 StringRef groupsName =
4323 kind == MinMaxKind::Min
4324 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4325 : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4322 StringRef groupsName =
4323 kind == MinMaxKind::Min
4324 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4325 : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4342 auto parseOperands = [&]() {
4344 kind == MinMaxKind::Min ? "min" : "max"))) {
4345 mapOperands.clear();
4352 llvm::append_range(flatExprs, map.getValue().getResults());
4354 auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4356 auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4358 flatDimOperands.append(map.getValue().getNumResults(), dims);
4359 flatSymOperands.append(map.getValue().getNumResults(), syms);
4360 numMapsPerGroup.push_back(map.getValue().getNumResults());
4363 flatSymOperands.emplace_back(),
4364 flatExprs.emplace_back())))
4366 numMapsPerGroup.push_back(1);
  unsigned totalNumDims = 0;
  unsigned totalNumSyms = 0;
  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
    unsigned numDims = flatDimOperands[i].size();
    unsigned numSyms = flatSymOperands[i].size();
    flatExprs[i] = flatExprs[i]
                       .shiftDims(numDims, totalNumDims)
                       .shiftSymbols(numSyms, totalNumSyms);
    totalNumDims += numDims;
    totalNumSyms += numSyms;
  }

  result.operands.append(dimOperands.begin(), dimOperands.end());
  result.operands.append(symOperands.begin(), symOperands.end());

  auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
                                parser.getContext());
  flatMap = flatMap.replaceDimsAndSymbols(
      dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
  // Parse the optional `step` clause; steps are stored as an I64ArrayAttr.
  AffineMapAttr stepsMapAttr;
  NamedAttrList stepsAttrs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> stepsMapOperands;
  if (failed(parser.parseOptionalKeyword("step"))) {
    // Without a step clause, every induction variable steps by 1.
    SmallVector<int64_t, 4> steps(ivs.size(), 1);
    result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
                        builder.getI64ArrayAttr(steps));
  } else {
    if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
                                      AffineParallelOp::getStepsAttrStrName(),
                                      stepsAttrs,
                                      OpAsmParser::Delimiter::Paren))
      return failure();

    // Convert the parsed map into an I64ArrayAttr; only constant steps are
    // accepted.
    SmallVector<int64_t, 4> steps;
    auto stepsMap = stepsMapAttr.getValue();
    for (const auto &result : stepsMap.getResults()) {
      auto constExpr = dyn_cast<AffineConstantExpr>(result);
      if (!constExpr)
        return parser.emitError(parser.getNameLoc(),
                                "steps must be constant integers");
      steps.push_back(constExpr.getValue());
    }
    result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
                        builder.getI64ArrayAttr(steps));
  }

  // Parse the optional `reduce` clause: quoted arith::AtomicRMWKind names.
  SmallVector<Attribute, 4> reductions;
  auto parseAttributes = [&]() -> ParseResult {
    // ...
    std::optional<arith::AtomicRMWKind> reduction =
        arith::symbolizeAtomicRMWKind(attrVal.getValue());
    if (!reduction)
      return parser.emitError(loc, "invalid reduction value: ") << attrVal;
    reductions.push_back(
        builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
    return success();
  };
  // ...
  result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
                      builder.getArrayAttr(reductions));

  // Finally, parse the loop body, typing each induction variable as `index`.
  for (auto &iv : ivs)
    iv.type = indexType;
  // ...
  AffineParallelOp::ensureTerminator(*body, builder, result.location);
  return success();
}
LogicalResult AffineYieldOp::verify() {
  auto *parentOp = (*this)->getParentOp();
  auto results = parentOp->getResults();
  auto operands = getOperands();

  if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
    return emitOpError() << "only terminates affine.if/for/parallel regions";
  if (parentOp->getNumResults() != getNumOperands())
    return emitOpError() << "parent of yield must have same number of "
                            "results as the yield operands";
  for (auto it : llvm::zip(results, operands)) {
    if (std::get<0>(it).getType() != std::get<1>(it).getType())
      return emitOpError() << "types mismatch between yield op and its parent";
  }
  return success();
}
void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
                               VectorType resultType, AffineMap map,
                               ValueRange operands) {
  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
  // ...
  result.types.push_back(resultType);
}

void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
                               VectorType resultType, Value memref,
                               AffineMap map, ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  // ...
  result.types.push_back(resultType);
}

void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
                               VectorType resultType, Value memref,
                               ValueRange indices) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  int64_t rank = memrefType.getRank();
  // Use the rank-d identity map (or the empty map for rank-0 memrefs).
  auto map =
      rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
  build(builder, result, resultType, memref, map, indices);
}
void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
}

ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
  MemRefType memrefType;
  VectorType resultType;
  AffineMapAttr mapAttr;
  // ... the memref operand and the subscript map are parsed under
  // AffineVectorLoadOp::getMapAttrStrName() ...
}

void AffineVectorLoadOp::print(OpAsmPrinter &p) {
  p << " " << getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType() << ", " << getType();
}

/// Verify common invariants of affine.vector_load and affine.vector_store.
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
                                          VectorType vectorType) {
  if (memrefType.getElementType() != vectorType.getElementType())
    return op->emitOpError(
        "requires memref and vector types of the same elemental type");
  return success();
}

LogicalResult AffineVectorLoadOp::verify() {
  // ...
  if (failed(verifyMemoryOpIndexing(
          *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
          getMapOperands(), memrefType,
          /*numIndexOperands=*/getNumOperands() - 1)))
    return failure();
  // ...
}
void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
                                Value valueToStore, Value memref, AffineMap map,
                                ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  // ...
}

void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
                                Value valueToStore, Value memref,
                                ValueRange indices) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  int64_t rank = memrefType.getRank();
  // Use the rank-d identity map (or the empty map for rank-0 memrefs).
  auto map =
      rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
  build(builder, result, valueToStore, memref, map, indices);
}

void AffineVectorStoreOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
}

ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  MemRefType memrefType;
  VectorType resultType;
  AffineMapAttr mapAttr;
  // ... the value to store, the memref, and the subscript map (stored under
  // AffineVectorStoreOp::getMapAttrStrName()) are parsed here ...
}

void AffineVectorStoreOp::print(OpAsmPrinter &p) {
  p << " " << getValueToStore();
  p << ", " << getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType() << ", " << getValueToStore().getType();
}

LogicalResult AffineVectorStoreOp::verify() {
  // ...
  if (failed(verifyMemoryOpIndexing(
          *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
          getMapOperands(), memrefType,
          /*numIndexOperands=*/getNumOperands() - 2)))
    return failure();
  // ...
}
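// Both affine.vector_load and affine.vector_store funnel their verification
// through the shared helpers above: verifyMemoryOpIndexing checks the
// subscript map and its dim/symbol operands against the memref type, and
// verifyVectorMemoryOp checks that the memref and vector element types match.
// Only the number of non-index operands (1 for loads, 2 for stores) differs.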
void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState, Value linearIndex,
                                     ValueRange dynamicBasis,
                                     ArrayRef<int64_t> staticBasis,
                                     bool hasOuterBound) {
  SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
                                              : staticBasis.size() + 1,
                                odsBuilder.getIndexType());
  build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
        staticBasis);
}

void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState, Value linearIndex,
                                     ValueRange basis, bool hasOuterBound) {
  // A null leading basis entry means "no outer bound".
  if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
    hasOuterBound = false;
    basis = basis.drop_front();
  }
  SmallVector<Value> dynamicBasis;
  SmallVector<int64_t> staticBasis;
  // ... split the mixed basis into dynamic (Value) and static (int64_t) parts.
  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
        hasOuterBound);
}

void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState, Value linearIndex,
                                     ArrayRef<OpFoldResult> basis,
                                     bool hasOuterBound) {
  if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
    hasOuterBound = false;
    basis = basis.drop_front();
  }
  SmallVector<Value> dynamicBasis;
  SmallVector<int64_t> staticBasis;
  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
        hasOuterBound);
}

void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState, Value linearIndex,
                                     ArrayRef<int64_t> basis, bool hasOuterBound) {
  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
}

LogicalResult AffineDelinearizeIndexOp::verify() {
  ArrayRef<int64_t> staticBasis = getStaticBasis();
  if (getNumResults() != staticBasis.size() &&
      getNumResults() != staticBasis.size() + 1)
    return emitOpError("should return an index for each basis element and up "
                       "to one extra index");

  auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
    return emitOpError(
        "mismatch between dynamic and static basis (kDynamic marker but no "
        "corresponding dynamic basis entry) -- this can only happen due to an "
        "incorrect fold/rewrite");

  if (!llvm::all_of(staticBasis, [](int64_t v) {
        return v > 0 || ShapedType::isDynamic(v);
      }))
    return emitOpError("no basis element may be statically non-positive");

  return success();
}
/// Given the mixed basis of an affine.delinearize_index/linearize_index op,
/// replace constant SSA basis values with their constant integer value and
/// return the new static basis; return std::nullopt if nothing folded.
static std::optional<SmallVector<int64_t>>
foldCstValueToCstAttrBasis(ArrayRef<OpFoldResult> mixedBasis,
                           MutableOperandRange mutableDynamicBasis,
                           ArrayRef<Attribute> dynamicBasis) {
  uint64_t dynamicBasisIndex = 0;
  for (Attribute basis : dynamicBasis) {
    if (basis)
      mutableDynamicBasis.erase(dynamicBasisIndex);
    else
      ++dynamicBasisIndex;
  }

  // No dynamic basis element folded to a constant.
  if (dynamicBasisIndex == dynamicBasis.size())
    return std::nullopt;

  SmallVector<int64_t> staticBasis;
  for (OpFoldResult basis : mixedBasis) {
    std::optional<int64_t> basisVal = getConstantIntValue(basis);
    if (!basisVal)
      staticBasis.push_back(ShapedType::kDynamic);
    else
      staticBasis.push_back(*basisVal);
  }
  return staticBasis;
}
LogicalResult
AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &result) {
  std::optional<SmallVector<int64_t>> maybeStaticBasis =
      foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
                                 adaptor.getDynamicBasis());
  if (maybeStaticBasis) {
    setStaticBasis(*maybeStaticBasis);
    return success();
  }

  // A single result means no division or modulo is needed; forward the input.
  if (getNumResults() == 1) {
    result.push_back(getLinearIndex());
    return success();
  }

  if (adaptor.getLinearIndex() == nullptr)
    return failure();
  if (!adaptor.getDynamicBasis().empty())
    return failure();

  // Constant-fold the delinearization: peel off each (innermost-first) basis
  // element with mod, then carry the quotient outward with floordiv.
  int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
  Type attrType = getLinearIndex().getType();

  ArrayRef<int64_t> staticBasis = getStaticBasis();
  if (hasOuterBound())
    staticBasis = staticBasis.drop_front();
  for (int64_t modulus : llvm::reverse(staticBasis)) {
    result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
    highPart = llvm::divideFloorSigned(highPart, modulus);
  }
  result.push_back(IntegerAttr::get(attrType, highPart));
  std::reverse(result.begin(), result.end());
  return success();
}

// Mixed (static/dynamic) basis with the outer-bound entry dropped when present.
  if (hasOuterBound()) {
    if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
      return getMixedValues(getStaticBasis().drop_front(),
                            getDynamicBasis().drop_front(), builder);
    return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
                          builder);
  }
  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);

// When there is no outer bound, the basis is padded with a leading empty entry.
  if (!hasOuterBound())
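// Illustrative sketch (hypothetical helper, not part of the dialect): the
// constant-folding loop in fold() above is the usual mixed-radix
// decomposition. Assuming a fully static basis, a non-negative linear index,
// and no outer bound (so one extra outermost digit is produced), it behaves
// like the standalone function below.
#include <algorithm>
#include <cstdint>
#include <vector>

static std::vector<int64_t>
delinearizeConstant(int64_t linearIndex, const std::vector<int64_t> &basis) {
  std::vector<int64_t> digits;
  for (auto it = basis.rbegin(); it != basis.rend(); ++it) {
    digits.push_back(linearIndex % *it); // innermost digit first
    linearIndex /= *it;                  // carry the quotient outward
  }
  digits.push_back(linearIndex); // unbounded outermost component
  std::reverse(digits.begin(), digits.end());
  return digits; // e.g. 23 over basis {4, 3} -> {1, 3, 2} (1*12 + 3*3 + 2)
}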
struct DropUnitExtentBasis
    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
    std::optional<Value> zero = std::nullopt;
    Location loc = delinearizeOp->getLoc();
    auto getZero = [&]() -> Value {
      if (!zero)
        zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
      return zero.value();
    };

    // Replace all indices corresponding to unit-extent basis elements with 0;
    // the remaining basis forms a new, smaller affine.delinearize_index.
    SmallVector<OpFoldResult> newBasis;
    for (auto [index, basis] :
         llvm::enumerate(delinearizeOp.getPaddedBasis())) {
      std::optional<int64_t> basisVal =
          basis ? getConstantIntValue(basis) : std::nullopt;
      if (basisVal && *basisVal == 1)
        replacements[index] = getZero();
      else
        newBasis.push_back(basis);
    }

    if (newBasis.size() == delinearizeOp.getNumResults())
      return rewriter.notifyMatchFailure(delinearizeOp,
                                         "no unit basis elements");

    if (!newBasis.empty()) {
      auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
          loc, delinearizeOp.getLinearIndex(), newBasis);
      int newIndex = 0;
      // Map the new delinearized indices back onto the unfilled replacements.
      for (auto &replacement : replacements) {
        if (replacement)
          continue;
        replacement = newDelinearizeOp->getResult(newIndex++);
      }
    }

    rewriter.replaceOp(delinearizeOp, replacements);
    return success();
  }
};
struct CancelDelinearizeOfLinearizeDisjointExactTail
    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                PatternRewriter &rewriter) const override {
    auto linearizeOp = delinearizeOp.getLinearIndex()
                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
    if (!linearizeOp)
      return rewriter.notifyMatchFailure(delinearizeOp,
                                         "index doesn't come from linearize");
    if (!linearizeOp.getDisjoint())
      return rewriter.notifyMatchFailure(linearizeOp, "linearize isn't disjoint");

    // Count how many trailing basis elements of the linearize and the
    // delinearize agree.
    ValueRange linearizeIns = linearizeOp.getMultiIndex();
    SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
    SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
    size_t numMatches = 0;
    for (auto [linSize, delinSize] : llvm::zip(
             llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
      if (linSize != delinSize)
        break;
      ++numMatches;
    }

    if (numMatches == 0)
      return rewriter.notifyMatchFailure(
          delinearizeOp, "final basis element doesn't match linearize");

    // If the entire basis matches, the delinearize results are exactly the
    // linearize inputs.
    if (numMatches == linearizeBasis.size() &&
        numMatches == delinearizeBasis.size() &&
        linearizeIns.size() == delinearizeOp.getNumResults()) {
      rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
      return success();
    }

    // Otherwise, strip the matching tail from both ops and forward the
    // corresponding linearize inputs directly.
    Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
        linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
        ArrayRef<OpFoldResult>(linearizeBasis).drop_back(numMatches),
        linearizeOp.getDisjoint());
    auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
        delinearizeOp.getLoc(), newLinearize,
        ArrayRef<OpFoldResult>(delinearizeBasis).drop_back(numMatches),
        delinearizeOp.hasOuterBound());
    SmallVector<Value> mergedResults(newDelinearize.getResults());
    mergedResults.append(linearizeIns.take_back(numMatches).begin(),
                         linearizeIns.take_back(numMatches).end());
    rewriter.replaceOp(delinearizeOp, mergedResults);
    return success();
  }
};
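// In short: when the trailing basis elements of a disjoint
// affine.linearize_index feed an affine.delinearize_index with the same
// trailing basis, that tail of the delinearization just reproduces the
// corresponding linearize inputs, so the pattern above forwards those inputs
// and shrinks both ops to the non-matching prefix (or cancels them entirely
// when everything matches).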
struct SplitDelinearizeSpanningLastLinearizeArg final
    : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                PatternRewriter &rewriter) const override {
    auto linearizeOp = delinearizeOp.getLinearIndex()
                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
    if (!linearizeOp)
      return rewriter.notifyMatchFailure(delinearizeOp,
                                         "index doesn't come from linearize");
    if (!linearizeOp.getDisjoint())
      return rewriter.notifyMatchFailure(linearizeOp,
                                         "linearize isn't disjoint");

    int64_t target = linearizeOp.getStaticBasis().back();
    if (ShapedType::isDynamic(target))
      return rewriter.notifyMatchFailure(
          linearizeOp, "linearize ends with dynamic basis value");

    // Scan the delinearize basis from the back until its product reaches the
    // size of the last linearize argument.
    int64_t sizeToSplit = 1;
    size_t elemsToSplit = 0;
    ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
    for (int64_t basisElem : llvm::reverse(basis)) {
      if (ShapedType::isDynamic(basisElem))
        return rewriter.notifyMatchFailure(
            delinearizeOp, "dynamic basis element while scanning for split");
      sizeToSplit *= basisElem;
      elemsToSplit += 1;

      if (sizeToSplit > target)
        return rewriter.notifyMatchFailure(delinearizeOp,
                                           "overshot last argument size");
      if (sizeToSplit == target)
        break;
    }

    if (sizeToSplit < target)
      return rewriter.notifyMatchFailure(
          delinearizeOp, "product of known basis elements doesn't exceed last "
                         "linearize argument");

    if (elemsToSplit < 2)
      return rewriter.notifyMatchFailure(
          delinearizeOp,
          "need at least two elements to form the basis product");

    // Split: delinearize the linearization of everything but the last
    // argument, and delinearize the last argument separately.
    Value linearizeWithoutBack =
        rewriter.create<affine::AffineLinearizeIndexOp>(
            linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
            linearizeOp.getDynamicBasis(),
            linearizeOp.getStaticBasis().drop_back(),
            linearizeOp.getDisjoint());
    auto delinearizeWithoutSplitPart =
        rewriter.create<affine::AffineDelinearizeIndexOp>(
            delinearizeOp.getLoc(), linearizeWithoutBack,
            delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
            delinearizeOp.hasOuterBound());
    auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
        delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
        basis.take_back(elemsToSplit),
        /*hasOuterBound=*/true);
    SmallVector<Value> results = llvm::to_vector(
        llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
                            delinearizeBack.getResults()));
    rewriter.replaceOp(delinearizeOp, results);
    return success();
  }
};
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns
      .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
              DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
          context);
}
void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                   OperationState &odsState,
                                   ValueRange multiIndex, ValueRange basis,
                                   bool disjoint) {
  // A null leading basis entry means the first index has no outer bound.
  if (!basis.empty() && basis.front() == Value())
    basis = basis.drop_front();
  SmallVector<Value> dynamicBasis;
  SmallVector<int64_t> staticBasis;
  // ... split the mixed basis into dynamic (Value) and static (int64_t) parts.
  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
}

void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                   OperationState &odsState,
                                   ValueRange multiIndex,
                                   ArrayRef<OpFoldResult> basis,
                                   bool disjoint) {
  if (!basis.empty() && basis.front() == OpFoldResult())
    basis = basis.drop_front();
  SmallVector<Value> dynamicBasis;
  SmallVector<int64_t> staticBasis;
  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
}

void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                   OperationState &odsState,
                                   ValueRange multiIndex,
                                   ArrayRef<int64_t> basis, bool disjoint) {
  build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
}

LogicalResult AffineLinearizeIndexOp::verify() {
  size_t numIndexes = getMultiIndex().size();
  size_t numBasisElems = getStaticBasis().size();
  if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
    return emitOpError("should be passed a basis element for each index except "
                       "possibly the first");

  auto dynamicMarkersCount =
      llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
    return emitOpError(
        "mismatch between dynamic and static basis (kDynamic marker but no "
        "corresponding dynamic basis entry) -- this can only happen due to an "
        "incorrect fold/rewrite");

  return success();
}
OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
  std::optional<SmallVector<int64_t>> maybeStaticBasis =
      foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
                                 adaptor.getDynamicBasis());
  if (maybeStaticBasis) {
    setStaticBasis(*maybeStaticBasis);
    return getResult();
  }

  // No indices linearizes to zero.
  if (getMultiIndex().empty())
    return IntegerAttr::get(getResult().getType(), 0);

  // A single index with at most one basis element is a no-op.
  if (getMultiIndex().size() == 1)
    return getMultiIndex().front();

  if (llvm::is_contained(adaptor.getMultiIndex(), nullptr))
    return nullptr;
  if (!adaptor.getDynamicBasis().empty())
    return nullptr;

  // Constant-fold the linearization: accumulate index * stride, with strides
  // built innermost-first from the static basis.
  int64_t result = 0;
  int64_t stride = 1;
  for (auto [length, indexAttr] :
       llvm::zip_first(llvm::reverse(getStaticBasis()),
                       llvm::reverse(adaptor.getMultiIndex()))) {
    result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
    stride = stride * length;
  }
  // The unbounded outermost index still contributes its stride.
  if (!hasOuterBound())
    result = result +
             cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;

  return IntegerAttr::get(getResult().getType(), result);
}

// Mixed (static/dynamic) basis with the outer-bound entry dropped when present.
  if (hasOuterBound()) {
    if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
      return getMixedValues(getStaticBasis().drop_front(),
                            getDynamicBasis().drop_front(), builder);
    return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
                          builder);
  }
  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);

// When there is no outer bound, the basis is padded with a leading empty entry.
  if (!hasOuterBound())
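// Illustrative sketch (hypothetical helper, not part of the dialect): the fold
// above is the inverse of the delinearize fold. Assuming all indices and basis
// elements are known non-negative constants, it computes
// sum(index[i] * stride[i]) with strides accumulated innermost-first, as in
// the standalone function below. The index list may have one more entry than
// the basis when the outermost index is unbounded.
#include <cassert>
#include <cstdint>
#include <vector>

static int64_t linearizeConstant(const std::vector<int64_t> &indices,
                                 const std::vector<int64_t> &basis) {
  assert(indices.size() == basis.size() || indices.size() == basis.size() + 1);
  int64_t result = 0;
  int64_t stride = 1;
  auto idx = indices.rbegin();
  for (auto b = basis.rbegin(); b != basis.rend(); ++b, ++idx) {
    result += *idx * stride; // index times the running stride
    stride *= *b;            // grow the stride innermost-first
  }
  if (idx != indices.rend())
    result += *idx * stride; // unbounded outermost component
  return result;             // e.g. {1, 3, 2} over basis {4, 3} -> 23
}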
struct DropLinearizeUnitComponentsIfDisjointOrZero final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    ValueRange multiIndex = op.getMultiIndex();
    size_t numIndices = multiIndex.size();
    SmallVector<Value> newIndices;
    newIndices.reserve(numIndices);
    SmallVector<OpFoldResult> newBasis;
    newBasis.reserve(numIndices);

    if (!op.hasOuterBound()) {
      newIndices.push_back(multiIndex.front());
      multiIndex = multiIndex.drop_front();
    }

    SmallVector<OpFoldResult> basis = op.getMixedBasis();
    for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
      std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
      if (!basisEntry || *basisEntry != 1) {
        newIndices.push_back(index);
        newBasis.push_back(basisElem);
        continue;
      }
      std::optional<int64_t> indexValue = getConstantIntValue(index);
      if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
        newIndices.push_back(index);
        newBasis.push_back(basisElem);
        continue;
      }
    }

    if (newIndices.size() == numIndices)
      return rewriter.notifyMatchFailure(op,
                                         "no unit basis entries to replace");

    if (newIndices.size() == 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
      return success();
    }
    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
        op, newIndices, newBasis, op.getDisjoint());
    return success();
  }
};
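// In effect: a multi-index component whose basis element is the constant 1 can
// only ever be 0 when the linearization is disjoint (or when the index itself
// is a known zero), so the pattern above drops such components; if every
// component is dropped, the linearized value is simply the constant 0.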
// Compute the product of a mixed list of terms: constants are folded directly,
// while dynamic terms become symbols of an affine.apply.
static OpFoldResult computeProduct(Location loc, OpBuilder &builder,
                                   ArrayRef<OpFoldResult> terms) {
  int64_t nDynamic = 0;
  SmallVector<Value> dynamicPart;
  AffineExpr result = builder.getAffineConstantExpr(1);
  for (OpFoldResult term : terms) {
    if (std::optional<int64_t> maybeConst = getConstantIntValue(term)) {
      result = result * builder.getAffineConstantExpr(*maybeConst);
    } else {
      dynamicPart.push_back(cast<Value>(term));
      result = result * builder.getAffineSymbolExpr(nDynamic++);
    }
  }
  if (auto constant = dyn_cast<AffineConstantExpr>(result))
    return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
  return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
}
struct CancelLinearizeOfDelinearizePortion final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  // A run of delinearize outputs feeding the linearize with a matching basis:
  // `length` consecutive linearize arguments starting at `linStart` that are
  // the results starting at `delinStart` of `delinearize`.
  struct Match {
    AffineDelinearizeIndexOp delinearize;
    unsigned linStart = 0;
    unsigned delinStart = 0;
    unsigned length = 0;
  };

  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Match> matches;

    const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
    ArrayRef<OpFoldResult> linBasisRef = linBasis;

    ValueRange multiIndex = linearizeOp.getMultiIndex();
    unsigned numLinArgs = multiIndex.size();
    unsigned linArgIdx = 0;
    // Handle each delinearize at most once per rewrite.
    llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize;
    while (linArgIdx < numLinArgs) {
      auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
      if (!asResult) {
        linArgIdx++;
        continue;
      }

      auto delinearizeOp =
          dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
      if (!delinearizeOp) {
        linArgIdx++;
        continue;
      }

      // The bounds must agree at the start of the run: either literally equal,
      // both at the leading (unbounded) position, or implied by disjointness.
      OpFoldResult firstLinBound = linBasis[linArgIdx];
      unsigned delinArgIdx = asResult.getResultNumber();
      SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
      OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
      bool boundsMatch = firstDelinBound == firstLinBound;
      bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
      bool knownByDisjoint =
          linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
      if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
        linArgIdx++;
        continue;
      }

      // Measure the run of consecutive matching arguments/results.
      unsigned j = 1;
      unsigned numDelinOuts = delinearizeOp.getNumResults();
      for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
           ++j) {
        if (multiIndex[linArgIdx + j] !=
            delinearizeOp.getResult(delinArgIdx + j))
          break;
        if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
          break;
      }
      // A run of length 1 isn't worth rewriting.
      if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
        linArgIdx++;
        continue;
      }
      matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
      linArgIdx += j;
    }

    if (matches.empty())
      return rewriter.notifyMatchFailure(
          linearizeOp, "no run of delinearize outputs to deal with");

    // Rewrite the linearize, replacing each matched run with its combined
    // element and splitting the corresponding delinearize accordingly.
    SmallVector<Value> newIndex;
    newIndex.reserve(numLinArgs);
    SmallVector<OpFoldResult> newBasis;
    newBasis.reserve(numLinArgs);
    SmallVector<SmallVector<Value>> delinearizeReplacements;
    unsigned prevMatchEnd = 0;
    for (Match m : matches) {
      unsigned gap = m.linStart - prevMatchEnd;
      llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
      llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
      prevMatchEnd = m.linStart + m.length;

      PatternRewriter::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(m.delinearize);

      ArrayRef<OpFoldResult> basisToMerge =
          linBasisRef.slice(m.linStart, m.length);
      // The product of the matched basis elements becomes the merged size.
      OpFoldResult newSize =
          computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);

      // If the run covers the entire delinearize, forward its linear index.
      if (m.length == m.delinearize.getNumResults()) {
        newIndex.push_back(m.delinearize.getLinearIndex());
        newBasis.push_back(newSize);
        delinearizeReplacements.push_back(SmallVector<Value>());
        continue;
      }

      // Otherwise, merge the matched slice of the delinearize basis into a
      // single element and re-delinearize that element separately.
      SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
      newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
                          newDelinBasis.begin() + m.delinStart + m.length);
      newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
      auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
          m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
          newDelinBasis);

      Value combinedElem = newDelinearize.getResult(m.delinStart);
      auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
          m.delinearize.getLoc(), combinedElem, basisToMerge);

      // Stitch together replacements for the original delinearize results.
      SmallVector<Value> newDelinResults;
      llvm::append_range(newDelinResults,
                         newDelinearize.getResults().take_front(m.delinStart));
      llvm::append_range(newDelinResults, residualDelinearize.getResults());
      llvm::append_range(
          newDelinResults,
          newDelinearize.getResults().drop_front(m.delinStart + 1));

      delinearizeReplacements.push_back(newDelinResults);
      newIndex.push_back(combinedElem);
      newBasis.push_back(newSize);
    }
    llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
    llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
        linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());

    for (auto [m, newResults] :
         llvm::zip_equal(matches, delinearizeReplacements)) {
      if (newResults.empty())
        continue;
      rewriter.replaceOp(m.delinearize, newResults);
    }

    return success();
  }
};
/// Strip a leading zero index from affine.linearize_index.
struct DropLinearizeLeadingZero final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    Value leadingIdx = op.getMultiIndex().front();
    if (!matchPattern(leadingIdx, m_Zero()))
      return failure();

    if (op.getMultiIndex().size() == 1) {
      rewriter.replaceOp(op, leadingIdx);
      return success();
    }

    SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
    ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
    if (op.hasOuterBound())
      newMixedBasis = newMixedBasis.drop_front();
    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
        op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
    return success();
  }
};

void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
               DropLinearizeUnitComponentsIfDisjointOrZero>(context);
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"