#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"

using llvm::divideCeilSigned;
using llvm::divideFloorSigned;

#define DEBUG_TYPE "affine-ops"

#include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
if (auto arg = dyn_cast<BlockArgument>(value))
  return arg.getParentRegion() == region;

if (llvm::isa<BlockArgument>(value))
  return legalityCheck(mapping.lookup(value), dest);

bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());

return llvm::all_of(values, [&](Value v) {

template <typename OpTy>
  static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
                                AffineWriteOpInterface>::value,
                "only ops with affine read/write interface are supported");

      dimOperands, src, dest, mapping,
      symbolOperands, src, dest, mapping,

      op.getMapOperands(), src, dest, mapping,
      op.getMapOperands(), src, dest, mapping,

if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))

if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
  if (iface.hasNoEffect())

    .Case<AffineApplyOp, AffineReadOpInterface,
          AffineWriteOpInterface>([&](auto op) {

isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);

bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
void AffineDialect::initialize() {
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
  addInterfaces<AffineInlinerInterface>();
  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,

if (auto poison = dyn_cast<ub::PoisonAttr>(value))
  return ub::PoisonOp::create(builder, loc, type, poison);
return arith::ConstantOp::materialize(builder, value, type, loc);

if (auto arg = dyn_cast<BlockArgument>(value)) {

while (auto *parentOp = curOp->getParentOp()) {

if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(parentOp))

auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();

if (auto applyOp = dyn_cast<AffineApplyOp>(op))
  return applyOp.isValidDim(region);

if (isa<AffineDelinearizeIndexOp, AffineLinearizeIndexOp>(op))
  return llvm::all_of(op->getOperands(),
                      [&](Value arg) { return ::isValidDim(arg, region); });

if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
template <typename AnyMemRefDefOp>
  MemRefType memRefType = memrefDefOp.getType();

  if (index >= memRefType.getRank()) {

  if (!memRefType.isDynamicDim(index))

  unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
  return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),

if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))

if (!index.has_value())

Operation *op = dimOp.getShapedValue().getDefiningOp();
while (auto castOp = dyn_cast<memref::CastOp>(op)) {
  if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
  op = castOp.getSource().getDefiningOp();

int64_t i = index.value();
    .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
    .Default([](Operation *) { return false; });

if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
      return affine::isValidSymbol(operand, region);

if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))

printer << '(' << operands.take_front(numDims) << ')';
if (operands.size() > numDims)
  printer << '[' << operands.drop_front(numDims) << ']';

numDims = opInfos.size();
template <typename OpTy>
  for (auto operand : operands) {
    if (opIt++ < numDims) {
        return op.emitOpError("operand cannot be used as a dimension id");
      return op.emitOpError("operand cannot be used as a symbol");

return AffineValueMap(getAffineMap(), getOperands(), getResult());

AffineMapAttr mapAttr;
auto map = mapAttr.getValue();

if (map.getNumDims() != numDims ||
    numDims + map.getNumSymbols() != result.operands.size()) {
                   "dimension or symbol index mismatch");

result.types.append(map.getNumResults(), indexTy);

p << " " << getMapAttr();
                        getAffineMap().getNumDims(), p);

    "operand count and affine map dimension and symbol count must match");

return emitOpError("mapping must produce one value");

for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
    return emitError("dimensional operand cannot be used as a symbol");

return llvm::all_of(getOperands(),
return llvm::all_of(getOperands(),
return llvm::all_of(getOperands(),
return llvm::all_of(getOperands(), [&](Value operand) {
auto map = getAffineMap();
auto expr = map.getResult(0);
if (auto dim = dyn_cast<AffineDimExpr>(expr))
  return getOperand(dim.getPosition());
if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
  return getOperand(map.getNumDims() + sym.getPosition());

bool hasPoison = false;
    map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
if (failed(foldResult))

auto dimExpr = dyn_cast<AffineDimExpr>(e);

Value operand = operands[dimExpr.getPosition()];
int64_t operandDivisor = 1;

  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
    operandDivisor = forOp.getStepAsInt();

    uint64_t lbLargestKnownDivisor =
        forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
    operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());

return operandDivisor;
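// Worked example: for the IV of `affine.for %i = 0 to %n step 4`, the lower
// bound is the constant 0, so the divisor is simply the step, 4: every value
// the IV takes is a multiple of 4. With a non-zero lower bound, the result is
// the gcd of the step and the largest known divisor of the lower bound map.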
if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
  int64_t constVal = constExpr.getValue();
  return constVal >= 0 && constVal < k;

auto dimExpr = dyn_cast<AffineDimExpr>(e);

Value operand = operands[dimExpr.getPosition()];

  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
      forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {

auto bin = dyn_cast<AffineBinaryOpExpr>(e);
    quotientTimesDiv = llhs;
    quotientTimesDiv = rlhs;

if (forOp && forOp.hasConstantLowerBound())
  return forOp.getConstantLowerBound();

if (!forOp || !forOp.hasConstantUpperBound())

if (forOp.hasConstantLowerBound()) {
  return forOp.getConstantUpperBound() - 1 -
         (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
             forOp.getStepAsInt();
return forOp.getConstantUpperBound() - 1;
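// Worked example: for `affine.for %i = 0 to 11 step 3` the largest value the
// IV takes is 11 - 1 - (11 - 0 - 1) % 3 = 10 - 1 = 9 (the IV visits 0, 3, 6,
// 9). Without a constant lower bound, ub - 1 is the best available bound.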
constLowerBounds.reserve(operands.size());
constUpperBounds.reserve(operands.size());
for (Value operand : operands) {

if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
  return constExpr.getValue();

constLowerBounds.reserve(operands.size());
constUpperBounds.reserve(operands.size());
for (Value operand : operands) {

std::optional<int64_t> lowerBound;
if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
  lowerBound = constExpr.getValue();
                                     constLowerBounds, constUpperBounds,

auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);

binExpr = dyn_cast<AffineBinaryOpExpr>(expr);

lhs = binExpr.getLHS();
rhs = binExpr.getRHS();
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);

int64_t rhsConstVal = rhsConst.getValue();

if (rhsConstVal <= 0)

std::optional<int64_t> lhsLbConst =
std::optional<int64_t> lhsUbConst =
if (lhsLbConst && lhsUbConst) {
  int64_t lhsLbConstVal = *lhsLbConst;
  int64_t lhsUbConstVal = *lhsUbConst;

      divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
          divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
        divideFloorSigned(lhsLbConstVal, rhsConstVal), context);

      divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
          divideCeilSigned(lhsUbConstVal, rhsConstVal)) {

      lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {

if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
  if (rhsConstVal % divisor == 0 &&
    expr = quotientTimesDiv.floorDiv(rhsConst);
  } else if (divisor % rhsConstVal == 0 &&
    expr = rem % rhsConst;
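// Worked example (sketch of the intent of the bounds-based simplification
// above): with %i the IV of `affine.for %i = 0 to 32 step 4`, %i ranges over
// [0, 28], so `%i floordiv 32` folds to 0 (both bounds floor-divide to 0) and
// `%i mod 32` simplifies to `%i` (the operand already lies in [0, 32)).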
if (operands.empty())

constLowerBounds.reserve(operands.size());
constUpperBounds.reserve(operands.size());
for (Value operand : operands) {

if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
  lowerBounds.push_back(constExpr.getValue());
  upperBounds.push_back(constExpr.getValue());
  lowerBounds.push_back(
      constLowerBounds, constUpperBounds,
  upperBounds.push_back(
      constLowerBounds, constUpperBounds,

unsigned i = exprEn.index();

if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])

if (!upperBounds[i]) {
  irredundantExprs.push_back(e);

auto otherLowerBound = en.value();
unsigned pos = en.index();
if (pos == i || !otherLowerBound)
if (*otherLowerBound > *upperBounds[i])
if (*otherLowerBound < *upperBounds[i])
if (upperBounds[pos] && lowerBounds[i] &&
    lowerBounds[i] == upperBounds[i] &&
    otherLowerBound == *upperBounds[pos] && i < pos)

irredundantExprs.push_back(e);

if (!lowerBounds[i]) {
  irredundantExprs.push_back(e);

auto otherUpperBound = en.value();
unsigned pos = en.index();
if (pos == i || !otherUpperBound)
if (*otherUpperBound < *lowerBounds[i])
if (*otherUpperBound > *lowerBounds[i])
if (lowerBounds[pos] && upperBounds[i] &&
    lowerBounds[i] == upperBounds[i] &&
    otherUpperBound == lowerBounds[pos] && i < pos)

irredundantExprs.push_back(e);

static void LLVM_ATTRIBUTE_UNUSED
  assert(map.getNumInputs() == operands.size() && "invalid operands for map");

newResults.push_back(expr);
LDBG() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`";

AffineMap affineMinMap = minOp.getAffineMap();

for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
          ValueBoundsConstraintSet::ComparisonOperator::LT,
          minOp.getOperands())))

auto it = llvm::find(dims, dim);
if (it == dims.end()) {
  unmappedDims.push_back(i);

auto it = llvm::find(syms, sym);
if (it == syms.end()) {
  unmappedSyms.push_back(i);

if (llvm::any_of(unmappedDims,
                 [&](unsigned i) { return expr.isFunctionOfDim(i); }) ||
    llvm::any_of(unmappedSyms,
                 [&](unsigned i) { return expr.isFunctionOfSymbol(i); }))

repl[dimOrSym.ceilDiv(convertedExpr)] = c1;
repl[(dimOrSym + convertedExpr - 1).floorDiv(convertedExpr)] = c1;

return success(*map != initialMap);
    unsigned dimOrSymbolPosition,
    bool replaceAffineMin) {

bool isDimReplacement = (dimOrSymbolPosition < dims.size());
unsigned pos = isDimReplacement ? dimOrSymbolPosition
                                : dimOrSymbolPosition - dims.size();
Value &v = isDimReplacement ? dims[pos] : syms[pos];

if (auto minOp = v.getDefiningOp<AffineMinOp>(); minOp && replaceAffineMin) {

AffineMap composeMap = affineApply.getAffineMap();
assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
                   affineApply.getMapOperands().end());

dims.append(composeDims.begin(), composeDims.end());
syms.append(composeSyms.begin(), composeSyms.end());
*map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());

    bool composeAffineMin = false) {

for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)

unsigned nDims = 0, nSyms = 0;
dimReplacements.reserve(dims.size());
symReplacements.reserve(syms.size());
for (auto *container : {&dims, &syms}) {
  bool isDim = (container == &dims);
  auto &repls = isDim ? dimReplacements : symReplacements;
    Value v = en.value();
           "map is function of unexpected expr@pos");
    operands->push_back(v);

while (llvm::any_of(*operands, [](Value v) {

if (composeAffineMin && llvm::any_of(*operands, [](Value v) {

    bool composeAffineMin) {
  return AffineApplyOp::create(b, loc, map, valueOperands);

    bool composeAffineMin) {
                              operands, composeAffineMin);
    bool composeAffineMin = false) {

for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {

llvm::append_range(dims,
llvm::append_range(symbols,

operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));

    bool composeAffineMin) {
  assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");

AffineApplyOp applyOp =

for (unsigned i = 0, e = constOperands.size(); i != e; ++i)

if (failed(applyOp->fold(constOperands, foldResults)) ||
    foldResults.empty()) {
    listener->notifyOperationInserted(applyOp, {});
  return applyOp.getResult();

return llvm::getSingleElement(foldResults);

                                     operands, composeAffineMin);
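// Usage sketch (editor's illustration, assuming the public declarations in
// AffineOps.h): makeComposedFoldedAffineApply composes any producing
// affine.apply ops into `map`, then folds; when everything folds to a
// constant, no affine.apply is created:
//
//   OpFoldResult r = makeComposedFoldedAffineApply(
//       b, loc,
//       AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
//                      b.getAffineSymbolExpr(0) + 4),
//       {b.getIndexAttr(8)});
//   // r holds the index attribute 12; no affine.apply is left in the IR.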
    bool composeAffineMin) {
  return llvm::map_to_vector(
      llvm::seq<unsigned>(0, map.getNumResults()), [&](unsigned i) {
        return makeComposedFoldedAffineApply(b, loc, map.getSubMap({i}),
                                             operands, composeAffineMin);

template <typename OpTy>
  return OpTy::create(b, loc, b.getIndexType(), map, valueOperands);

return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);

template <typename OpTy>

auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);

for (unsigned i = 0, e = constOperands.size(); i != e; ++i)

if (failed(minMaxOp->fold(constOperands, foldResults)) ||
    foldResults.empty()) {
    listener->notifyOperationInserted(minMaxOp, {});
  return minMaxOp.getResult();

return llvm::getSingleElement(foldResults);

return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);

return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
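// Usage sketch (editor's illustration): the folded min/max builders mirror
// makeComposedFoldedAffineApply but materialize an affine.min / affine.max
// only when folding fails. For example, composing the map (d0) -> (d0, 16)
// over a constant operand 8 lets makeComposedFoldedAffineMin fold to the
// attribute 8 without creating an op.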
template <class MapOrSet>
  if (!mapOrSet || operands->empty())

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  auto *context = mapOrSet->getContext();
  resultOperands.reserve(operands->size());
  remappedSymbols.reserve(operands->size());
  unsigned nextDim = 0;
  unsigned nextSym = 0;
  unsigned oldNumSyms = mapOrSet->getNumSymbols();

  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
    if (i < mapOrSet->getNumDims()) {
        remappedSymbols.push_back((*operands)[i]);
        resultOperands.push_back((*operands)[i]);
      resultOperands.push_back((*operands)[i]);

  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
  *operands = resultOperands;
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(
      dimRemapping, {}, nextDim, oldNumSyms + nextSym);

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

template <class MapOrSet>
  if (!mapOrSet || operands.empty())

  unsigned numOperands = operands.size();

  assert(mapOrSet.getNumInputs() == numOperands &&
         "map/set inputs must match number of operands");

  auto *context = mapOrSet.getContext();
  resultOperands.reserve(numOperands);
  remappedDims.reserve(numOperands);
  symOperands.reserve(mapOrSet.getNumSymbols());
  unsigned nextSym = 0;
  unsigned nextDim = 0;
  unsigned oldNumDims = mapOrSet.getNumDims();

  resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);
  for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
      symRemapping[i - oldNumDims] =
      remappedDims.push_back(operands[i]);
    symOperands.push_back(operands[i]);

  append_range(resultOperands, remappedDims);
  append_range(resultOperands, symOperands);
  operands = resultOperands;
  mapOrSet = mapOrSet.replaceDimsAndSymbols(
      {}, symRemapping, oldNumDims + nextDim, nextSym);

  assert(mapOrSet.getNumInputs() == operands.size() &&
         "map/set inputs must match number of operands");
template <class MapOrSet>
  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
                "Argument must be either of AffineMap or IntegerSet type");

  if (!mapOrSet || operands->empty())

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
  legalizeDemotedDims<MapOrSet>(*mapOrSet, *operands);

  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      usedDims[dimExpr.getPosition()] = true;
    else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
      usedSyms[symExpr.getPosition()] = true;

  auto *context = mapOrSet->getContext();
  resultOperands.reserve(operands->size());

  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
  unsigned nextDim = 0;
  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
      auto it = seenDims.find((*operands)[i]);
      if (it == seenDims.end()) {
        resultOperands.push_back((*operands)[i]);
        seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
        dimRemapping[i] = it->second;

  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
  unsigned nextSym = 0;
  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
      IntegerAttr operandCst;
      if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
      auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
      if (it == seenSymbols.end()) {
        resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
        seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
        symRemapping[i] = it->second;

  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
  *operands = resultOperands;

  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);

  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
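// Example of the effect (editor's note): given the map (d0, d1) -> (d0 + d1)
// with operands (%x, %x), canonicalization deduplicates the operands to yield
// (d0) -> (d0 + d0) over the single operand %x; unused dims/symbols are
// dropped and constant symbol operands are folded directly into the map.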
template <typename AffineOpTy>
  LogicalResult matchAndRewrite(AffineOpTy affineOp,
        llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
                        AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
                        AffineVectorStoreOp, AffineVectorLoadOp>::value,
        "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
    auto map = affineOp.getAffineMap();
    auto oldOperands = affineOp.getMapOperands();
    if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
                                    resultOperands.begin()))
    replaceAffineOp(rewriter, affineOp, map, resultOperands);

void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(

void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
      prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
      prefetch.getLocalityHint(), prefetch.getIsDataCache());

void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
      store, store.getValueToStore(), store.getMemRef(), map, mapOperands);

void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
      vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,

void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
      vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,

template <typename AffineOpTy>
void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(

results.add<SimplifyAffineOp<AffineApplyOp>>(context);
    Value elementsPerStride) {
  build(builder, state, srcMemRef, srcMap, srcIndices, destMemRef, dstMap,
        destIndices, tagMemRef, tagMap, tagIndices, numElements, stride,
  auto result = dyn_cast<AffineDmaStartOp>(builder.create(state));
  assert(result && "builder didn't return the right type");

    Value elementsPerStride) {
  return create(builder, builder.getLoc(), srcMemRef, srcMap, srcIndices,
                destMemRef, dstMap, destIndices, tagMemRef, tagMap, tagIndices,
                numElements, stride, elementsPerStride);

p << " " << getSrcMemRef() << '[';
p << "], " << getDstMemRef() << '[';
p << "], " << getTagMemRef() << '[';
p << ", " << getNumElementsPerStride();
p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
  << getTagMemRefType();

AffineMapAttr srcMapAttr;
AffineMapAttr dstMapAttr;
AffineMapAttr tagMapAttr;
    getSrcMapAttrStrName(),
    getDstMapAttrStrName(),
    getTagMapAttrStrName(),

if (!strideInfo.empty() && strideInfo.size() != 2) {
    "expected two stride related operands");
bool isStrided = strideInfo.size() == 2;

if (types.size() != 3)

if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
    dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
    tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    "memref operand count not equal to map.numInputs");
LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
  if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
    return emitOpError("expected DMA source to be of memref type");
  if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
    return emitOpError("expected DMA destination to be of memref type");
  if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
    return emitOpError("expected DMA tag to be of memref type");

  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
                              getDstMap().getNumInputs() +
                              getTagMap().getNumInputs();
  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
      getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
    return emitOpError("incorrect number of operands");

  for (auto idx : getSrcIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("src index to dma_start must have 'index' type");
        "src index must be a valid dimension or symbol identifier");

  for (auto idx : getDstIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("dst index to dma_start must have 'index' type");
        "dst index must be a valid dimension or symbol identifier");

  for (auto idx : getTagIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("tag index to dma_start must have 'index' type");
        "tag index must be a valid dimension or symbol identifier");

void AffineDmaStartOp::getEffects(
    Value numElements) {
  build(builder, state, tagMemRef, tagMap, tagIndices, numElements);
  auto result = dyn_cast<AffineDmaWaitOp>(builder.create(state));
  assert(result && "builder didn't return the right type");

    Value numElements) {
  return create(builder, builder.getLoc(), tagMemRef, tagMap, tagIndices,

p << " " << getTagMemRef() << '[';
p << " : " << getTagMemRef().getType();

AffineMapAttr tagMapAttr;
    getTagMapAttrStrName(),

if (!llvm::isa<MemRefType>(type))
    "expected tag to be of memref type");

if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    "tag memref operand count != to map.numInputs");

LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
  if (!llvm::isa<MemRefType>(getOperand(0).getType()))
    return emitOpError("expected DMA tag to be of memref type");

  for (auto idx : getTagIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("index to dma_wait must have 'index' type");
        "index must be a valid dimension or symbol identifier");

void AffineDmaWaitOp::getEffects(
    ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
  assert(((!lbMap && lbOperands.empty()) ||
         "lower bound operand count does not match the affine map");
  assert(((!ubMap && ubOperands.empty()) ||
         "upper bound operand count does not match the affine map");
  assert(step > 0 && "step has to be a positive integer constant");

      getOperandSegmentSizeAttr(),
       static_cast<int32_t>(ubOperands.size()),
       static_cast<int32_t>(iterArgs.size())}));

  for (Value val : iterArgs)

  Value inductionVar =
  for (Value val : iterArgs)
    bodyBlock->addArgument(val.getType(), val.getLoc());

  if (iterArgs.empty() && !bodyBuilder) {
    ensureTerminator(*bodyRegion, builder, result.location);
  } else if (bodyBuilder) {
    bodyBuilder(builder, result.location, inductionVar,

    int64_t ub, int64_t step, ValueRange iterArgs,
    BodyBuilderFn bodyBuilder) {
  return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,

LogicalResult AffineForOp::verifyRegions() {
  auto *body = getBody();
  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
    return emitOpError("expected body to have a single index argument for the "
                       "induction variable");

  if (getLowerBoundMap().getNumInputs() > 0)
          getLowerBoundMap().getNumDims())))
  if (getUpperBoundMap().getNumInputs() > 0)
          getUpperBoundMap().getNumDims())))

  if (getLowerBoundMap().getNumResults() < 1)
    return emitOpError("expected lower bound map to have at least one result");
  if (getUpperBoundMap().getNumResults() < 1)
    return emitOpError("expected upper bound map to have at least one result");

  unsigned opNumResults = getNumResults();
  if (opNumResults == 0)

  if (getNumIterOperands() != opNumResults)
      "mismatch between the number of loop-carried values and results");
  if (getNumRegionIterArgs() != opNumResults)
      "mismatch between the number of basic block args and results");
bool failedToParsedMinMax =

auto boundAttrStrName =
    isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
            : AffineForOp::getUpperBoundMapAttrName(result.name);

if (!boundOpInfos.empty()) {
  if (boundOpInfos.size() > 1)
      "expected only one loop bound operand");

if (auto affineMapAttr = dyn_cast<AffineMapAttr>(boundAttr)) {
  unsigned currentNumOperands = result.operands.size();

  auto map = affineMapAttr.getValue();
      "dim operand count and affine map dim count must match");

  unsigned numDimAndSymbolOperands =
      result.operands.size() - currentNumOperands;
  if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
        "symbol operand count and affine map symbol count must match");

      return p.emitError(attrLoc, "lower loop bound affine map with "
                                  "multiple results requires 'max' prefix");
      return p.emitError(attrLoc, "upper loop bound affine map with multiple "
                                  "results requires 'min' prefix");

if (auto integerAttr = dyn_cast<IntegerAttr>(boundAttr)) {

    "expected valid affine map representation for loop bounds");

int64_t numOperands = result.operands.size();
int64_t numLbOperands = result.operands.size() - numOperands;
numOperands = result.operands.size();
int64_t numUbOperands = result.operands.size() - numOperands;

    getStepAttrName(result.name),

IntegerAttr stepAttr;
    getStepAttrName(result.name).data(),
if (stepAttr.getValue().isNegative())
    "expected step to be representable as a positive signed integer");

regionArgs.push_back(inductionVariable);

for (auto argOperandType :
     llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
  Type type = std::get<2>(argOperandType);
  std::get<0>(argOperandType).type = type;

    getOperandSegmentSizeAttr(),
     static_cast<int32_t>(numUbOperands),
     static_cast<int32_t>(operands.size())}));

if (regionArgs.size() != result.types.size() + 1)
    "mismatch between the number of loop-carried values and results");

AffineForOp::ensureTerminator(*body, builder, result.location);
if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
  p << constExpr.getValue();

if (isa<AffineSymbolExpr>(expr)) {

unsigned AffineForOp::getNumIterOperands() {
  AffineMap lbMap = getLowerBoundMapAttr().getValue();
  AffineMap ubMap = getUpperBoundMapAttr().getValue();

std::optional<MutableArrayRef<OpOperand>>
AffineForOp::getYieldedValuesMutable() {
  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();

if (getStepAsInt() != 1)
  p << " step " << getStepAsInt();

bool printBlockTerminators = false;
if (getNumIterOperands() > 0) {
  auto regionArgs = getRegionIterArgs();
  auto operands = getInits();

  llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
    p << std::get<0>(it) << " = " << std::get<1>(it);
  p << ") -> (" << getResultTypes() << ")";
  printBlockTerminators = true;

                        printBlockTerminators);
    (*this)->getAttrs(),
    {getLowerBoundMapAttrName(getOperation()->getName()),
     getUpperBoundMapAttrName(getOperation()->getName()),
     getStepAttrName(getOperation()->getName()),
     getOperandSegmentSizeAttr()});
auto foldLowerOrUpperBound = [&forOp](bool lower) {
  auto boundOperands =
      lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
  for (auto operand : boundOperands) {
    operandConstants.push_back(operandCst);

      lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
         "bound maps should have at least one result");
  if (failed(boundMap.constantFold(operandConstants, foldedResults)))

  assert(!foldedResults.empty() && "bounds should have at least one result");
  auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
  for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
    auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
    maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
                     : llvm::APIntOps::smin(maxOrMin, foldedResult);
  lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
        : forOp.setConstantUpperBound(maxOrMin.getSExtValue());

bool folded = false;
if (!forOp.hasConstantLowerBound())
  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
if (!forOp.hasConstantUpperBound())
  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
return success(folded);

auto lbMap = forOp.getLowerBoundMap();
auto ubMap = forOp.getUpperBoundMap();
auto prevLbMap = lbMap;
auto prevUbMap = ubMap;

if (lbMap == prevLbMap && ubMap == prevUbMap)

if (lbMap != prevLbMap)
  forOp.setLowerBound(lbOperands, lbMap);
if (ubMap != prevUbMap)
  forOp.setUpperBound(ubOperands, ubMap);

static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
  int64_t step = forOp.getStepAsInt();
  if (!forOp.hasConstantBounds() || step <= 0)
    return std::nullopt;
  int64_t lb = forOp.getConstantLowerBound();
  int64_t ub = forOp.getConstantUpperBound();
  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
}
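// Worked example: `affine.for %i = 0 to 10 step 3` has trip count
// (10 - 0 + 3 - 1) / 3 = 4, i.e. ceil(10 / 3); a loop whose constant upper
// bound does not exceed its lower bound reports a trip count of 0.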
LogicalResult matchAndRewrite(AffineForOp forOp,
  if (!llvm::hasSingleElement(*forOp.getBody()))
  if (forOp.getNumResults() == 0)
  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
  if (tripCount == 0) {
    rewriter.replaceOp(forOp, forOp.getInits());

  auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
  auto iterArgs = forOp.getRegionIterArgs();
  bool hasValDefinedOutsideLoop = false;
  bool iterArgsNotInOrder = false;
  for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
    Value val = yieldOp.getOperand(i);
    auto *iterArgIt = llvm::find(iterArgs, val);
    if (val == forOp.getInductionVar())
    if (iterArgIt == iterArgs.end()) {
      assert(forOp.isDefinedOutsideOfLoop(val) &&
             "must be defined outside of the loop");
      hasValDefinedOutsideLoop = true;
      replacements.push_back(val);
      unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
        iterArgsNotInOrder = true;
      replacements.push_back(forOp.getInits()[pos]);

  if (!tripCount.has_value() &&
      (hasValDefinedOutsideLoop || iterArgsNotInOrder))
  if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)

  rewriter.replaceOp(forOp, replacements);

results.add<AffineForEmptyLoopFolder>(context);
assert((point.isParent() || point == getRegion()) && "invalid region point");

void AffineForOp::getSuccessorRegions(
  assert((point.isParent() || point == getRegion()) && "expected loop region");

  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
  if (point.isParent() && tripCount.has_value()) {
    if (tripCount.value() > 0) {
      regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
    if (tripCount.value() == 0) {

  if (!point.isParent() && tripCount == 1) {

  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));

return getTrivialConstantTripCount(op) == 0;

LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
  results.assign(getInits().begin(), getInits().end());
  return success(folded);
assert(map.getNumResults() >= 1 && "bound map has at least one result");
getLowerBoundOperandsMutable().assign(lbOperands);
setLowerBoundMap(map);

assert(map.getNumResults() >= 1 && "bound map has at least one result");
getUpperBoundOperandsMutable().assign(ubOperands);
setUpperBoundMap(map);

bool AffineForOp::hasConstantLowerBound() {
  return getLowerBoundMap().isSingleConstant();

bool AffineForOp::hasConstantUpperBound() {
  return getUpperBoundMap().isSingleConstant();

int64_t AffineForOp::getConstantLowerBound() {
  return getLowerBoundMap().getSingleConstantResult();

int64_t AffineForOp::getConstantUpperBound() {
  return getUpperBoundMap().getSingleConstantResult();

void AffineForOp::setConstantLowerBound(int64_t value) {

void AffineForOp::setConstantUpperBound(int64_t value) {

AffineForOp::operand_range AffineForOp::getControlOperands() {

bool AffineForOp::matchingBoundOperandList() {
  auto lbMap = getLowerBoundMap();
  auto ubMap = getUpperBoundMap();
  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
    if (getOperand(i) != getOperand(numOperands + i))

std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {

std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
  if (!hasConstantLowerBound())
    return std::nullopt;
      OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};

std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {

std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
  if (!hasConstantUpperBound())
      OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
    bool replaceInitOperandUsesInLoop,

  auto inits = llvm::to_vector(getInits());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  AffineForOp newLoop = AffineForOp::create(

  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
      newLoop.getBody()->getArguments().take_back(newInitOperands.size());
      newYieldValuesFn(rewriter, getLoc(), newIterArgs);
  assert(newInitOperands.size() == newYieldedValues.size() &&
         "expected as many new yield values as new iter operands");
  yieldOp.getOperandsMutable().append(newYieldedValues);

  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
                       newLoop.getBody()->getArguments().take_front(
                           getBody()->getNumArguments()));

  if (replaceInitOperandUsesInLoop) {
    for (auto it : llvm::zip(newInitOperands, newIterArgs)) {

                 newLoop->getResults().take_front(getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());

auto ivArg = dyn_cast<BlockArgument>(val);
if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
  return AffineForOp();
        ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
  return forOp.getInductionVar() == val ? forOp : AffineForOp();
return AffineForOp();

auto ivArg = dyn_cast<BlockArgument>(val);
if (!ivArg || !ivArg.getOwner())
auto parallelOp = dyn_cast_if_present<AffineParallelOp>(containingOp);
if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))

ivs->reserve(forInsts.size());
for (auto forInst : forInsts)
  ivs->push_back(forInst.getInductionVar());

ivs.reserve(affineOps.size());
  if (auto forOp = dyn_cast<AffineForOp>(op))
    ivs.push_back(forOp.getInductionVar());
  else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
    for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
      ivs.push_back(parallelOp.getBody()->getArgument(i));
template <typename BoundListTy, typename LoopCreatorTy>
    LoopCreatorTy &&loopCreatorFn) {
  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");

  ivs.reserve(lbs.size());
  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
      if (i == e - 1 && bodyBuilderFn) {
        bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
      AffineYieldOp::create(nestedBuilder, nestedLoc);
    auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);

    int64_t ub, int64_t step,
    AffineForOp::BodyBuilderFn bodyBuilderFn) {
  return AffineForOp::create(builder, loc, lb, ub, step,

    AffineForOp::BodyBuilderFn bodyBuilderFn) {
  if (lbConst && ubConst)
        ubConst.value(), step, bodyBuilderFn);
LogicalResult matchAndRewrite(AffineIfOp ifOp,
  if (ifOp.getElseRegion().empty() ||
      !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())

LogicalResult matchAndRewrite(AffineIfOp op,
  auto isTriviallyFalse = [](IntegerSet iSet) {
    return iSet.isEmptyIntegerSet();

    return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
            iSet.getConstraint(0) == 0);

  IntegerSet affineIfConditions = op.getIntegerSet();
  if (isTriviallyFalse(affineIfConditions)) {
    if (op.getNumResults() == 0 && !op.hasElse()) {
    blockToMove = op.getElseBlock();
  } else if (isTriviallyTrue(affineIfConditions)) {
    blockToMove = op.getThenBlock();

  rewriter.eraseOp(blockToMoveTerminator);

void AffineIfOp::getSuccessorRegions(
  if (getElseRegion().empty()) {
    regions.push_back(getResults());
auto conditionAttr =
    (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
  return emitOpError("requires an integer set attribute named 'condition'");

IntegerSet condition = conditionAttr.getValue();
  return emitOpError("operand count and condition integer set dimension and "
                     "symbol count must match");

IntegerSetAttr conditionAttr;
    AffineIfOp::getConditionAttrStrName(),

auto set = conditionAttr.getValue();
if (set.getNumDims() != numDims)
    "dim operand count and integer set dim count must match");
if (numDims + set.getNumSymbols() != result.operands.size())
    "symbol operand count and integer set symbol count must match");

AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),

auto conditionAttr =
    (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
p << " " << conditionAttr;
                        conditionAttr.getValue().getNumDims(), p);

auto &elseRegion = this->getElseRegion();
if (!elseRegion.empty()) {

                                   getConditionAttrStrName());
    ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())

void AffineIfOp::setIntegerSet(IntegerSet newSet) {

(*this)->setOperands(operands);

    bool withElseRegion) {
  assert(resultTypes.empty() || withElseRegion);

  if (resultTypes.empty())
    AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);

  if (withElseRegion) {
    if (resultTypes.empty())
      AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);

AffineIfOp::build(builder, result, {}, set, args,

    bool composeAffineMin = false) {

if (llvm::none_of(operands,

auto set = getIntegerSet();
if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
setConditional(set, operands);

results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
result.types.push_back(memrefType.getElementType());

assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
auto memrefType = llvm::cast<MemRefType>(memref.getType());
result.types.push_back(memrefType.getElementType());

auto memrefType = llvm::cast<MemRefType>(memref.getType());
int64_t rank = memrefType.getRank();
build(builder, result, memref, map, indices);

AffineMapAttr mapAttr;
    AffineLoadOp::getMapAttrStrName(),

p << " " << getMemRef() << '[';
if (AffineMapAttr mapAttr =
        (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
                      {getMapAttrStrName()});
template <typename AffineMemOpTy>
static LogicalResult
                       MemRefType memrefType, unsigned numIndexOperands) {
    return op->emitOpError("affine map num results must equal memref rank");
    return op->emitOpError("expects as many subscripts as affine map inputs");

  for (auto idx : mapOperands) {
    if (!idx.getType().isIndex())
      return op->emitOpError("index to load must have 'index' type");

if (getType() != memrefType.getElementType())
  return emitOpError("result type must match element type of memref");

        *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
        getMapOperands(), memrefType,
        getNumOperands() - 1)))

results.add<SimplifyAffineOp<AffineLoadOp>>(context);
auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();

auto global = dyn_cast_or_null<memref::GlobalOp>(

    dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());

if (auto splatAttr = dyn_cast<SplatElementsAttr>(cstAttr))
  return splatAttr.getSplatValue<Attribute>();

if (!getAffineMap().isConstant())

auto indices = llvm::to_vector<4>(
    llvm::map_range(getAffineMap().getConstantResults(),
                    [](int64_t v) -> uint64_t { return v; }));
return cstAttr.getValues<Attribute>()[indices];
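// Example of the fold (editor's note): an affine.load with constant map
// results reading from a memref.get_global whose memref.global carries a
// constant DenseElementsAttr initializer folds to the addressed element
// attribute, or to the splat value when the initializer is a splat.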
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");

auto memrefType = llvm::cast<MemRefType>(memref.getType());
int64_t rank = memrefType.getRank();
build(builder, result, valueToStore, memref, map, indices);

AffineMapAttr mapAttr;
    mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),

p << " " << getValueToStore();
p << ", " << getMemRef() << '[';
if (AffineMapAttr mapAttr =
        (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
                      {getMapAttrStrName()});

if (getValueToStore().getType() != memrefType.getElementType())
    "value to store must have the same type as memref element type");

        *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
        getMapOperands(), memrefType,
        getNumOperands() - 2)))

results.add<SimplifyAffineOp<AffineStoreOp>>(context);

LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
template <typename T>
  if (op.getNumOperands() !=
      op.getMap().getNumDims() + op.getMap().getNumSymbols())
    return op.emitOpError(
        "operand count and affine map dimension and symbol count must match");

  if (op.getMap().getNumResults() == 0)
    return op.emitOpError("affine map expect at least one result");

template <typename T>
  p << ' ' << op->getAttr(T::getMapAttrStrName());
  auto operands = op.getOperands();
  unsigned numDims = op.getMap().getNumDims();
  p << '(' << operands.take_front(numDims) << ')';

  if (operands.size() != numDims)
    p << '[' << operands.drop_front(numDims) << ']';
                          {T::getMapAttrStrName()});

template <typename T>
  AffineMapAttr mapAttr;

template <typename T>
  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
                "expected affine min or max op");

  auto foldedMap = op.getMap().partialConstantFold(operands, &results);

  if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
    return op.getOperand(0);

  if (results.empty()) {
    if (foldedMap == op.getMap())
    return op.getResult();

  auto resultIt = std::is_same<T, AffineMinOp>::value
                      ? llvm::min_element(results)
                      : llvm::max_element(results);
  if (resultIt == results.end())
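// Worked example: affine.min with map (d0) -> (d0, 10, 20) and a constant
// operand 16 constant-folds all results to {16, 10, 20}; the smallest, 10, is
// returned. For AffineMaxOp the largest folded result, 20, is chosen instead.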
template <typename T>
  AffineMap oldMap = affineOp.getAffineMap();

    if (!llvm::is_contained(newExprs, expr))
      newExprs.push_back(expr);

template <typename T>
  AffineMap oldMap = affineOp.getAffineMap();
      affineOp.getMapOperands().take_front(oldMap.getNumDims());
      affineOp.getMapOperands().take_back(oldMap.getNumSymbols());

  auto newDimOperands = llvm::to_vector<8>(dimOperands);
  auto newSymOperands = llvm::to_vector<8>(symOperands);

    if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
      Value symValue = symOperands[symExpr.getPosition()];
        producerOps.push_back(producerOp);
    } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      Value dimValue = dimOperands[dimExpr.getPosition()];
        producerOps.push_back(producerOp);

    newExprs.push_back(expr);

  if (producerOps.empty())

  for (T producerOp : producerOps) {
    AffineMap producerMap = producerOp.getAffineMap();
    unsigned numProducerDims = producerMap.getNumDims();
        producerOp.getMapOperands().take_front(numProducerDims);
        producerOp.getMapOperands().take_back(numProducerSyms);
    newDimOperands.append(dimValues.begin(), dimValues.end());
    newSymOperands.append(symValues.begin(), symValues.end());

      newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
                             .shiftSymbols(numProducerSyms, numUsedSyms));

    numUsedDims += numProducerDims;
    numUsedSyms += numProducerSyms;

      llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));

  if (!resultExpr.isPureAffine())
  if (failed(flattenResult))

  if (llvm::is_sorted(flattenedExprs))

      llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
  llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
    return flattenedExprs[lhs] < flattenedExprs[rhs];

  for (unsigned idx : resultPermutation)

template <typename T>
  AffineMap map = affineOp.getAffineMap();

template <typename T>
  if (affineOp.getMap().getNumResults() != 1)
                                     affineOp.getOperands());

return parseAffineMinMaxOp<AffineMinOp>(parser, result);

return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
IntegerAttr hintInfo;
StringRef readOrWrite, cacheType;

AffineMapAttr mapAttr;
    AffinePrefetchOp::getMapAttrStrName(),
    AffinePrefetchOp::getLocalityHintAttrStrName(),

if (readOrWrite != "read" && readOrWrite != "write")
    "rw specifier has to be 'read' or 'write'");
result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),

if (cacheType != "data" && cacheType != "instr")
    "cache type has to be 'data' or 'instr'");
result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),

p << " " << getMemref() << '[';
AffineMapAttr mapAttr =
    (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
  << "locality<" << getLocalityHint() << ">, "
  << (getIsDataCache() ? "data" : "instr");
    (*this)->getAttrs(),
    {getMapAttrStrName(), getLocalityHintAttrStrName(),
     getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});

auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
  return emitOpError("affine.prefetch affine map num results must equal"
  return emitOpError("too few operands");
if (getNumOperands() != 1)
  return emitOpError("too few operands");

for (auto idx : getMapOperands()) {
    "index must be a valid dimension or symbol identifier");

results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);

LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
build(builder, result, resultTypes, reductions, lbs, {}, ubs,

assert(llvm::all_of(lbMaps,
                    return m.getNumDims() == lbMaps[0].getNumDims() &&
       "expected all lower bounds maps to have the same number of dimensions "
assert(llvm::all_of(ubMaps,
                    return m.getNumDims() == ubMaps[0].getNumDims() &&
       "expected all upper bounds maps to have the same number of dimensions "
assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
       "expected lower bound maps to have as many inputs as lower bound "
assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
       "expected upper bound maps to have as many inputs as upper bound "

for (arith::AtomicRMWKind reduction : reductions)
  reductionAttrs.push_back(

groups.reserve(groups.size() + maps.size());
exprs.reserve(maps.size());
return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,

AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);

for (unsigned i = 0, e = steps.size(); i < e; ++i)
if (resultTypes.empty())
  ensureTerminator(*bodyRegion, builder, result.location);

return {&getRegion()};
unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }

AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
  return getOperands().take_front(getLowerBoundsMap().getNumInputs());

AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());

AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
  auto values = getLowerBoundsGroups().getValues<int32_t>();
  for (unsigned i = 0; i < pos; ++i)
  return getLowerBoundsMap().getSliceMap(start, values[pos]);

AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
  auto values = getUpperBoundsGroups().getValues<int32_t>();
  for (unsigned i = 0; i < pos; ++i)
  return getUpperBoundsMap().getSliceMap(start, values[pos]);

return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());

return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());

std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
  if (hasMinMaxBounds())
    return std::nullopt;

  AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),

  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
    auto expr = rangesValueMap.getResult(i);
    auto cst = dyn_cast<AffineConstantExpr>(expr);
      return std::nullopt;
    out.push_back(cst.getValue());

Block *AffineParallelOp::getBody() { return &getRegion().front(); }

OpBuilder AffineParallelOp::getBodyBuilder() {
  return OpBuilder(getBody(), std::prev(getBody()->end()));

    "operands to map must match number of inputs");
auto ubOperands = getUpperBoundsOperands();
newOperands.append(ubOperands.begin(), ubOperands.end());
(*this)->setOperands(newOperands);

    "operands to map must match number of inputs");
newOperands.append(ubOperands.begin(), ubOperands.end());
(*this)->setOperands(newOperands);

setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
                                         arith::AtomicRMWKind op) {
  case arith::AtomicRMWKind::addf:
    return isa<FloatType>(resultType);
  case arith::AtomicRMWKind::addi:
    return isa<IntegerType>(resultType);
  case arith::AtomicRMWKind::assign:
  case arith::AtomicRMWKind::mulf:
    return isa<FloatType>(resultType);
  case arith::AtomicRMWKind::muli:
    return isa<IntegerType>(resultType);
  case arith::AtomicRMWKind::maximumf:
    return isa<FloatType>(resultType);
  case arith::AtomicRMWKind::minimumf:
    return isa<FloatType>(resultType);
  case arith::AtomicRMWKind::maxs: {
    auto intType = dyn_cast<IntegerType>(resultType);
    return intType && intType.isSigned();
  case arith::AtomicRMWKind::mins: {
    auto intType = dyn_cast<IntegerType>(resultType);
    return intType && intType.isSigned();
  case arith::AtomicRMWKind::maxu: {
    auto intType = dyn_cast<IntegerType>(resultType);
    return intType && intType.isUnsigned();
  case arith::AtomicRMWKind::minu: {
    auto intType = dyn_cast<IntegerType>(resultType);
    return intType && intType.isUnsigned();
  case arith::AtomicRMWKind::ori:
    return isa<IntegerType>(resultType);
  case arith::AtomicRMWKind::andi:
    return isa<IntegerType>(resultType);
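// Example (editor's note): this predicate backs the affine.parallel verifier;
// e.g. an "addf" or "mulf" reduction requires a float result type, "addi"
// requires an integer, and the signed/unsigned integer min/max kinds also
// require matching signedness.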
auto numDims = getNumDims();
    getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
  return emitOpError() << "the number of region arguments ("
                       << getBody()->getNumArguments()
                       << ") and the number of map groups for lower ("
                       << getLowerBoundsGroups().getNumElements()
                       << ") and upper bound ("
                       << getUpperBoundsGroups().getNumElements()
                       << "), and the number of steps (" << getSteps().size()
                       << ") must all match";

unsigned expectedNumLBResults = 0;
for (APInt v : getLowerBoundsGroups()) {
  unsigned results = v.getZExtValue();
    return emitOpError()
           << "expected lower bound map to have at least one result";
  expectedNumLBResults += results;
if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
  return emitOpError() << "expected lower bounds map to have "
                       << expectedNumLBResults << " results";

unsigned expectedNumUBResults = 0;
for (APInt v : getUpperBoundsGroups()) {
  unsigned results = v.getZExtValue();
    return emitOpError()
           << "expected upper bound map to have at least one result";
  expectedNumUBResults += results;
if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
  return emitOpError() << "expected upper bounds map to have "
                       << expectedNumUBResults << " results";

if (getReductions().size() != getNumResults())
  return emitOpError("a reduction must be specified for each output");

  auto intAttr = dyn_cast<IntegerAttr>(attr);
  if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
    return emitOpError("invalid reduction attribute");
  auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
    return emitOpError("result type cannot match reduction attribute");

        getLowerBoundsMap().getNumDims())))
        getUpperBoundsMap().getNumDims())))
LogicalResult AffineValueMap::canonicalize() {
  auto newMap = getAffineMap();
  if (newMap == getAffineMap() && newOperands == operands)
  reset(newMap, newOperands);

if (!lbCanonicalized && !ubCanonicalized)
if (lbCanonicalized)
if (ubCanonicalized)

LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,

    StringRef keyword) {
  ValueRange dimOperands = operands.take_front(numDims);
  ValueRange symOperands = operands.drop_front(numDims);
  for (llvm::APInt groupSize : group) {
    unsigned size = groupSize.getZExtValue();
    p << keyword << '(';

p << " (" << getBody()->getArguments() << ") = (";
                     getLowerBoundsOperands(), "max");
                     getUpperBoundsOperands(), "min");

bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
  llvm::interleaveComma(steps, p);

if (getNumResults()) {
  llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
    arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
        llvm::cast<IntegerAttr>(attr).getInt());
    p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
  p << ") -> (" << getResultTypes() << ")";

    (*this)->getAttrs(),
    {AffineParallelOp::getReductionsAttrStrName(),
     AffineParallelOp::getLowerBoundsMapAttrStrName(),
     AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
     AffineParallelOp::getUpperBoundsMapAttrStrName(),
     AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
     AffineParallelOp::getStepsAttrStrName()});
4312 "expected operands to be dim or symbol expression");
4315 for (
const auto &list : operands) {
4319 for (
Value operand : valueOperands) {
4320 unsigned pos = std::distance(uniqueOperands.begin(),
4321 llvm::find(uniqueOperands, operand));
4322 if (pos == uniqueOperands.size())
4323 uniqueOperands.push_back(operand);
4324 replacements.push_back(
enum class MinMaxKind { Min, Max };

/// Parses an affine map that can contain a min/max for groups of its results,
/// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6), populating
/// `result` with a flattened map and per-group expression counts.
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
                                            OperationState &result,
                                            MinMaxKind kind) {
  const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";

  StringRef mapName = kind == MinMaxKind::Min
                          ? AffineParallelOp::getUpperBoundsMapAttrStrName()
                          : AffineParallelOp::getLowerBoundsMapAttrStrName();
  StringRef groupsName =
      kind == MinMaxKind::Min
          ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
          : AffineParallelOp::getLowerBoundsGroupsAttrStrName();

  if (failed(parser.parseLParen()))
    return failure();

  if (succeeded(parser.parseOptionalRParen())) {
    result.addAttribute(
        mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
    result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
    return success();
  }

  SmallVector<AffineExpr> flatExprs;
  SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatDimOperands;
  SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatSymOperands;
  SmallVector<int32_t> numMapsPerGroup;
  SmallVector<OpAsmParser::UnresolvedOperand> mapOperands;
  auto parseOperands = [&]() {
    if (succeeded(parser.parseOptionalKeyword(
            kind == MinMaxKind::Min ? "min" : "max"))) {
      mapOperands.clear();
      AffineMapAttr map;
      if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map,
                                               tmpAttrStrName,
                                               result.attributes,
                                               OpAsmParser::Delimiter::Paren)))
        return failure();
      result.attributes.erase(tmpAttrStrName);
      llvm::append_range(flatExprs, map.getValue().getResults());
      ArrayRef<OpAsmParser::UnresolvedOperand> operandsRef = mapOperands;
      auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
      SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef);
      auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
      SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef);
      flatDimOperands.append(map.getValue().getNumResults(), dims);
      flatSymOperands.append(map.getValue().getNumResults(), syms);
      numMapsPerGroup.push_back(map.getValue().getNumResults());
    } else {
      if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
                                                flatSymOperands.emplace_back(),
                                                flatExprs.emplace_back())))
        return failure();
      numMapsPerGroup.push_back(1);
    }
    return success();
  };
  if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
    return failure();

  unsigned totalNumDims = 0;
  unsigned totalNumSyms = 0;
  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
    unsigned numDims = flatDimOperands[i].size();
    unsigned numSyms = flatSymOperands[i].size();
    flatExprs[i] = flatExprs[i]
                       .shiftDims(numDims, totalNumDims)
                       .shiftSymbols(numSyms, totalNumSyms);
    totalNumDims += numDims;
    totalNumSyms += numSyms;
  }

  // Deduplicate map operands.
  SmallVector<Value> dimOperands, symOperands;
  SmallVector<AffineExpr> dimRplacements, symRepacements;
  if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
                                    dimRplacements, AffineExprKind::DimId) ||
      deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
                                    symRepacements, AffineExprKind::SymbolId))
    return failure();

  result.operands.append(dimOperands.begin(), dimOperands.end());
  result.operands.append(symOperands.begin(), symOperands.end());

  Builder &builder = parser.getBuilder();
  auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
                                parser.getContext());
  flatMap = flatMap.replaceDimsAndSymbols(
      dimRplacements, symRepacements, dimOperands.size(), symOperands.size());

  result.addAttribute(mapName, AffineMapAttr::get(flatMap));
  result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
  return success();
}
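// Bound syntax handled by the parser above (illustrative sketch):
//
//   parallel-bound ::= `(` parallel-group (`,` parallel-group)* `)`
//   parallel-group ::= expr-of-ssa-ids
//                    | (`min` | `max`) `(` expr-of-ssa-ids-list `)`
//
// e.g. `(%a, max(%b + 1, %c), %d floordiv 2)` as a lower bound. Each group
// contributes one count (the number of its expressions) to the groups
// tensor attribute.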
ParseResult AffineParallelOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexType = builder.getIndexType();
  SmallVector<OpAsmParser::Argument, 3> ivs;
  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
      parser.parseEqual() ||
      parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
      parser.parseKeyword("to") ||
      parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
    return failure();

  // Parse optional steps. When absent, default each step to 1; otherwise
  // parse a parenthesized list of constants as an affine map of SSA ids.
  AffineMapAttr stepsMapAttr;
  NamedAttrList stepsAttrs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> stepsMapOperands;
  if (failed(parser.parseOptionalKeyword("step"))) {
    SmallVector<int64_t, 4> steps(ivs.size(), 1);
    result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
                        builder.getI64ArrayAttr(steps));
  } else {
    if (failed(parser.parseAffineMapOfSSAIds(
            stepsMapOperands, stepsMapAttr,
            AffineParallelOp::getStepsAttrStrName(), stepsAttrs,
            OpAsmParser::Delimiter::Paren)))
      return failure();

    // Convert steps which are constant results of the map into an I64 array.
    SmallVector<int64_t, 4> steps;
    auto stepsMap = stepsMapAttr.getValue();
    for (const auto &result : stepsMap.getResults()) {
      auto constExpr = dyn_cast<AffineConstantExpr>(result);
      if (!constExpr)
        return parser.emitError(parser.getNameLoc(),
                                "steps must be constant integers");
      steps.push_back(constExpr.getValue());
    }
    result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
                        builder.getI64ArrayAttr(steps));
  }

  // Parse optional clause of the form: `reduce ("addf", "maxf")`.
  SmallVector<Attribute, 4> reductions;
  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
    auto parseAttributes = [&]() -> ParseResult {
      // Parse a single quoted string and convert it to its enum value.
      StringAttr attrVal;
      NamedAttrList attrStorage;
      auto loc = parser.getCurrentLocation();
      if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
                                attrStorage))
        return failure();
      std::optional<arith::AtomicRMWKind> reduction =
          arith::symbolizeAtomicRMWKind(attrVal.getValue());
      if (!reduction)
        return parser.emitError(loc, "invalid reduction value: ") << attrVal;
      reductions.push_back(
          builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
      return success();
    };
    if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                       parseAttributes))
      return failure();
  }
  result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
                      builder.getArrayAttr(reductions));

  // Parse return types of reductions (if any).
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Now parse the body, giving each induction variable index type.
  Region *body = result.addRegion();
  for (auto &iv : ivs)
    iv.type = indexType;
  if (parser.parseRegion(*body, ivs) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Add a terminator if none was parsed.
  AffineParallelOp::ensureTerminator(*body, builder, result.location);
  return success();
}
LogicalResult AffineYieldOp::verify() {
  auto *parentOp = (*this)->getParentOp();
  auto results = parentOp->getResults();
  auto operands = getOperands();

  if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
    return emitOpError() << "only terminates affine.if/for/parallel regions";
  if (parentOp->getNumResults() != getNumOperands())
    return emitOpError() << "parent of yield must have same number of "
                            "results as the yield operands";
  for (auto it : llvm::zip(results, operands)) {
    if (std::get<0>(it).getType() != std::get<1>(it).getType())
      return emitOpError() << "types mismatch between yield op and its parent";
  }

  return success();
}
void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
                               VectorType resultType, AffineMap map,
                               ValueRange operands) {
  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
  result.addOperands(operands);
  if (map)
    result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
  result.types.push_back(resultType);
}

void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
                               VectorType resultType, Value memref,
                               AffineMap map, ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(memref);
  result.addOperands(mapOperands);
  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
  result.types.push_back(resultType);
}

void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
                               VectorType resultType, Value memref,
                               ValueRange indices) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  int64_t rank = memrefType.getRank();
  // Create identity map for memrefs with at least one dimension or () -> ()
  // for zero-dimensional memrefs.
  auto map =
      rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
  build(builder, result, resultType, memref, map, indices);
}

void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
}
ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType memrefType;
  VectorType resultType;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  return failure(
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineVectorLoadOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(memrefType) || parser.parseComma() ||
      parser.parseType(resultType) ||
      parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
      parser.addTypeToList(resultType, result.types));
}

void AffineVectorLoadOp::print(OpAsmPrinter &p) {
  p << " " << getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType() << ", " << getType();
}

/// Verify common invariants of affine.vector_load and affine.vector_store.
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
                                          VectorType vectorType) {
  // Check that memref and vector element types match.
  if (memrefType.getElementType() != vectorType.getElementType())
    return op->emitOpError(
        "requires memref and vector types of the same elemental type");
  return success();
}

LogicalResult AffineVectorLoadOp::verify() {
  MemRefType memrefType = getMemRefType();
  if (failed(verifyMemoryOpIndexing(
          *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
          getMapOperands(), memrefType,
          /*numIndexOperands=*/getNumOperands() - 1)))
    return failure();

  if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
    return failure();

  return success();
}
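// Illustrative usage (sketch) of the custom syntax handled above:
//
//   %v = affine.vector_load %mem[%i + 3, %j] : memref<100x100xf32>, vector<8xf32>
//
// The memref and vector element types must agree, and the operands after the
// memref are the indices fed to the access map.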
void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
                                Value valueToStore, Value memref, AffineMap map,
                                ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(valueToStore);
  result.addOperands(memref);
  result.addOperands(mapOperands);
  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
}

// Use identity map.
void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
                                Value valueToStore, Value memref,
                                ValueRange indices) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  int64_t rank = memrefType.getRank();
  // Create identity map for memrefs with at least one dimension or () -> ()
  // for zero-dimensional memrefs.
  auto map =
      rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
  build(builder, result, valueToStore, memref, map, indices);
}

void AffineVectorStoreOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
}
ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  auto indexTy = parser.getBuilder().getIndexType();

  MemRefType memrefType;
  VectorType resultType;
  OpAsmParser::UnresolvedOperand storeValueInfo;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  return failure(
      parser.parseOperand(storeValueInfo) || parser.parseComma() ||
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineVectorStoreOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(memrefType) || parser.parseComma() ||
      parser.parseType(resultType) ||
      parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
      parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands));
}

void AffineVectorStoreOp::print(OpAsmPrinter &p) {
  p << " " << getValueToStore();
  p << ", " << getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType() << ", " << getValueToStore().getType();
}

LogicalResult AffineVectorStoreOp::verify() {
  MemRefType memrefType = getMemRefType();
  if (failed(verifyMemoryOpIndexing(
          *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
          getMapOperands(), memrefType,
          /*numIndexOperands=*/getNumOperands() - 2)))
    return failure();

  if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
    return failure();

  return success();
}
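// Illustrative usage (sketch):
//
//   affine.vector_store %v, %mem[%i, %j] : memref<100x100xf32>, vector<8xf32>
//
// Note the operand count check above subtracts 2: the stored value and the
// memref are not map operands.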
void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState,
                                     Value linearIndex, ValueRange dynamicBasis,
                                     ArrayRef<int64_t> staticBasis,
                                     bool hasOuterBound) {
  SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
                                              : staticBasis.size() + 1,
                                odsBuilder.getIndexType());
  build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
        staticBasis);
}

void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState,
                                     Value linearIndex, ValueRange basis,
                                     bool hasOuterBound) {
  if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
    hasOuterBound = false;
    basis = basis.drop_front();
  }
  SmallVector<Value> dynamicBasis;
  SmallVector<int64_t> staticBasis;
  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
                             staticBasis);
  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
        hasOuterBound);
}

void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState,
                                     Value linearIndex,
                                     ArrayRef<OpFoldResult> basis,
                                     bool hasOuterBound) {
  if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
    hasOuterBound = false;
    basis = basis.drop_front();
  }
  SmallVector<Value> dynamicBasis;
  SmallVector<int64_t> staticBasis;
  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
        hasOuterBound);
}

void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState,
                                     Value linearIndex, ArrayRef<int64_t> basis,
                                     bool hasOuterBound) {
  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
}

LogicalResult AffineDelinearizeIndexOp::verify() {
  ArrayRef<int64_t> staticBasis = getStaticBasis();
  if (getNumResults() != staticBasis.size() &&
      getNumResults() != staticBasis.size() + 1)
    return emitOpError("should return an index for each basis element and up "
                       "to one extra index");

  auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
    return emitOpError(
        "mismatch between dynamic and static basis (kDynamic marker but no "
        "corresponding dynamic basis entry) -- this can only happen due to an "
        "incorrect fold/rewrite");

  if (!llvm::all_of(staticBasis, [](int64_t v) {
        return v > 0 || ShapedType::isDynamic(v);
      }))
    return emitOpError("no basis element may be statically non-positive");

  return success();
}
/// Given a mixed basis of affine.delinearize_index/linearize_index, replace
/// constant SSA values with their constant integer value and return the new
/// static basis. In case no such candidate for replacement exists, this
/// utility returns std::nullopt.
static std::optional<SmallVector<int64_t>>
foldCstValueToCstAttrBasis(ArrayRef<OpFoldResult> mixedBasis,
                           MutableOperandRange mutableDynamicBasis,
                           ArrayRef<Attribute> dynamicBasis) {
  uint64_t dynamicBasisIndex = 0;
  for (OpFoldResult basis : dynamicBasis) {
    if (basis)
      mutableDynamicBasis.erase(dynamicBasisIndex);
    else
      ++dynamicBasisIndex;
  }

  // No constant SSA value exists.
  if (dynamicBasisIndex == dynamicBasis.size())
    return std::nullopt;

  SmallVector<int64_t> staticBasis;
  for (OpFoldResult basis : mixedBasis) {
    std::optional<int64_t> basisVal = getConstantIntValue(basis);
    if (!basisVal)
      staticBasis.push_back(ShapedType::kDynamic);
    else
      staticBasis.push_back(*basisVal);
  }

  return staticBasis;
}
LogicalResult
AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &result) {
  std::optional<SmallVector<int64_t>> maybeStaticBasis =
      foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
                                 adaptor.getDynamicBasis());
  if (maybeStaticBasis) {
    setStaticBasis(*maybeStaticBasis);
    return success();
  }
  // If we won't be doing any division or modulo (the one basis element is
  // purely advisory), simply return the input value.
  if (getNumResults() == 1) {
    result.push_back(getLinearIndex());
    return success();
  }

  if (adaptor.getLinearIndex() == nullptr)
    return failure();

  if (!adaptor.getDynamicBasis().empty())
    return failure();

  int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
  Type attrType = getLinearIndex().getType();

  ArrayRef<int64_t> staticBasis = getStaticBasis();
  if (hasOuterBound())
    staticBasis = staticBasis.drop_front();
  for (int64_t modulus : llvm::reverse(staticBasis)) {
    result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
    highPart = llvm::divideFloorSigned(highPart, modulus);
  }
  result.push_back(IntegerAttr::get(attrType, highPart));
  std::reverse(result.begin(), result.end());
  return success();
}
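// Worked example (illustrative): folding
//   %r:3 = affine.delinearize_index %c22 into (2, 3, 5)
// walks the basis back-to-front: 22 mod 5 = 2, 22 floordiv 5 = 4,
// then 4 mod 3 = 1, 4 floordiv 3 = 1, giving the results (1, 1, 2),
// and indeed 1*15 + 1*5 + 2 = 22.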
SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
  Builder builder(getContext());
  if (hasOuterBound()) {
    if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
      return getMixedValues(getStaticBasis().drop_front(),
                            getDynamicBasis().drop_front(), builder);

    return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
                          builder);
  }

  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
}

SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
  SmallVector<OpFoldResult> ret = getMixedBasis();
  if (!hasOuterBound())
    ret.insert(ret.begin(), OpFoldResult());
  return ret;
}
namespace {

/// Drops delinearization indices that correspond to unit-extent basis
/// elements, replacing them with a constant zero.
struct DropUnitExtentBasis
    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
    std::optional<Value> zero = std::nullopt;
    Location loc = delinearizeOp->getLoc();
    auto getZero = [&]() -> Value {
      if (!zero)
        zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
      return zero.value();
    };

    // Replace all indices corresponding to unit-extent basis with 0.
    // The remaining basis feeds a new `affine.delinearize_index` op.
    SmallVector<OpFoldResult> newBasis;
    for (auto [index, basis] :
         llvm::enumerate(delinearizeOp.getPaddedBasis())) {
      std::optional<int64_t> basisVal =
          basis ? getConstantIntValue(basis) : std::nullopt;
      if (basisVal && *basisVal == 1)
        replacements[index] = getZero();
      else
        newBasis.push_back(basis);
    }

    if (newBasis.size() == delinearizeOp.getNumResults())
      return rewriter.notifyMatchFailure(delinearizeOp,
                                         "no unit basis elements");

    if (!newBasis.empty()) {
      // This drops the leading null entry from the padded basis if there was
      // no outer bound.
      auto newDelinearizeOp = affine::AffineDelinearizeIndexOp::create(
          rewriter, loc, delinearizeOp.getLinearIndex(), newBasis);
      int newIndex = 0;
      // Map the new delinearized indices back to the values they replace.
      for (auto &replacement : replacements) {
        if (replacement)
          continue;
        replacement = newDelinearizeOp->getResult(newIndex++);
      }
    }

    rewriter.replaceOp(delinearizeOp, replacements);
    return success();
  }
};
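/// Illustrative before/after sketch for the pattern defined below
/// (CancelDelinearizeOfLinearizeDisjointExactTail): when a disjoint linearize
/// feeds a delinearize and the two share the same trailing basis, that tail
/// cancels, e.g.
///
///   %0 = affine.linearize_index disjoint [%x, %y, %z] by (3, 5, 7)
///   %1:3 = affine.delinearize_index %0 into (2, 5, 7)
///
/// can become
///
///   %0 = affine.linearize_index disjoint [%x] by (3)
///   %1 = affine.delinearize_index %0 into (2)
///
/// with %y and %z used directly in place of %1#1 and %1#2.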
struct CancelDelinearizeOfLinearizeDisjointExactTail
    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                PatternRewriter &rewriter) const override {
    auto linearizeOp = delinearizeOp.getLinearIndex()
                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
    if (!linearizeOp)
      return rewriter.notifyMatchFailure(delinearizeOp,
                                         "index doesn't come from linearize");

    if (!linearizeOp.getDisjoint())
      return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");

    ValueRange linearizeIns = linearizeOp.getMultiIndex();
    // Note: use the full basis so outer bounds aren't lost later.
    SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
    SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
    size_t numMatches = 0;
    for (auto [linSize, delinSize] : llvm::zip(
             llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
      if (linSize != delinSize)
        break;
      ++numMatches;
    }

    if (numMatches == 0)
      return rewriter.notifyMatchFailure(
          delinearizeOp, "final basis element doesn't match linearize");

    // The easy case: the basis elements match up completely.
    if (numMatches == linearizeBasis.size() &&
        numMatches == delinearizeBasis.size() &&
        linearizeIns.size() == delinearizeOp.getNumResults()) {
      rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
      return success();
    }

    Value newLinearize = affine::AffineLinearizeIndexOp::create(
        rewriter, linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
        ArrayRef<OpFoldResult>(linearizeBasis).drop_back(numMatches),
        linearizeOp.getDisjoint());
    auto newDelinearize = affine::AffineDelinearizeIndexOp::create(
        rewriter, delinearizeOp.getLoc(), newLinearize,
        ArrayRef<OpFoldResult>(delinearizeBasis).drop_back(numMatches),
        delinearizeOp.hasOuterBound());
    SmallVector<Value> mergedResults(newDelinearize.getResults());
    mergedResults.append(linearizeIns.take_back(numMatches).begin(),
                         linearizeIns.take_back(numMatches).end());
    rewriter.replaceOp(delinearizeOp, mergedResults);
    return success();
  }
};
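/// Illustrative before/after sketch for the pattern defined below
/// (SplitDelinearizeSpanningLastLinearizeArg): when the trailing delinearize
/// basis elements multiply out to exactly the last linearize basis element,
/// the delinearization can be split, e.g.
///
///   %0 = affine.linearize_index disjoint [%x, %y] by (4, 16)
///   %1:4 = affine.delinearize_index %0 into (2, 2, 4, 4)
///
/// can become
///
///   %1:2 = affine.delinearize_index %x into (2, 2)
///   %2:2 = affine.delinearize_index %y into (4, 4)
///
/// since 4 * 4 = 16 spans exactly the last linearize argument.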
struct SplitDelinearizeSpanningLastLinearizeArg final
    : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                PatternRewriter &rewriter) const override {
    auto linearizeOp = delinearizeOp.getLinearIndex()
                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
    if (!linearizeOp)
      return rewriter.notifyMatchFailure(delinearizeOp,
                                         "index doesn't come from linearize");

    if (!linearizeOp.getDisjoint())
      return rewriter.notifyMatchFailure(linearizeOp,
                                         "linearize isn't disjoint");

    int64_t target = linearizeOp.getStaticBasis().back();
    if (ShapedType::isDynamic(target))
      return rewriter.notifyMatchFailure(
          linearizeOp, "linearize ends with dynamic basis value");

    int64_t sizeToSplit = 1;
    size_t elemsToSplit = 0;
    ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
    for (int64_t basisElem : llvm::reverse(basis)) {
      if (ShapedType::isDynamic(basisElem))
        return rewriter.notifyMatchFailure(
            delinearizeOp, "dynamic basis element while scanning for split");
      sizeToSplit *= basisElem;
      elemsToSplit += 1;

      if (sizeToSplit > target)
        return rewriter.notifyMatchFailure(delinearizeOp,
                                           "overshot last argument size");
      if (sizeToSplit == target)
        break;
    }

    if (sizeToSplit < target)
      return rewriter.notifyMatchFailure(
          delinearizeOp, "product of known basis elements doesn't exceed last "
                         "linearize argument");

    if (elemsToSplit < 2)
      return rewriter.notifyMatchFailure(
          delinearizeOp,
          "need at least two elements to form the basis product");

    Value linearizeWithoutBack = affine::AffineLinearizeIndexOp::create(
        rewriter, linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
        linearizeOp.getDynamicBasis(), linearizeOp.getStaticBasis().drop_back(),
        linearizeOp.getDisjoint());
    auto delinearizeWithoutSplitPart = affine::AffineDelinearizeIndexOp::create(
        rewriter, delinearizeOp.getLoc(), linearizeWithoutBack,
        delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
        delinearizeOp.hasOuterBound());
    auto delinearizeBack = affine::AffineDelinearizeIndexOp::create(
        rewriter, delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
        basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
    SmallVector<Value> results = llvm::to_vector(
        llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
                            delinearizeBack.getResults()));
    rewriter.replaceOp(delinearizeOp, results);

    return success();
  }
};
} // namespace
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns
      .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
              DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
          context);
}
void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                   OperationState &odsState,
                                   ValueRange multiIndex, ValueRange basis,
                                   bool disjoint) {
  if (!basis.empty() && basis.front() == Value())
    basis = basis.drop_front();
  SmallVector<Value> dynamicBasis;
  SmallVector<int64_t> staticBasis;
  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
                             staticBasis);
  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
}

void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                   OperationState &odsState,
                                   ValueRange multiIndex,
                                   ArrayRef<OpFoldResult> basis,
                                   bool disjoint) {
  if (!basis.empty() && basis.front() == OpFoldResult())
    basis = basis.drop_front();
  SmallVector<Value> dynamicBasis;
  SmallVector<int64_t> staticBasis;
  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
}

void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                   OperationState &odsState,
                                   ValueRange multiIndex,
                                   ArrayRef<int64_t> basis, bool disjoint) {
  build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
}
LogicalResult AffineLinearizeIndexOp::verify() {
  size_t numIndexes = getMultiIndex().size();
  size_t numBasisElems = getStaticBasis().size();
  if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
    return emitOpError("should be passed a basis element for each index except "
                       "possibly the first");

  auto dynamicMarkersCount =
      llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
    return emitOpError(
        "mismatch between dynamic and static basis (kDynamic marker but no "
        "corresponding dynamic basis entry) -- this can only happen due to an "
        "incorrect fold/rewrite");

  return success();
}
OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
  std::optional<SmallVector<int64_t>> maybeStaticBasis =
      foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
                                 adaptor.getDynamicBasis());
  if (maybeStaticBasis) {
    setStaticBasis(*maybeStaticBasis);
    return getResult();
  }
  // No indices linearizes to zero.
  if (getMultiIndex().empty())
    return IntegerAttr::get(getResult().getType(), 0);

  // One single index linearizes to itself.
  if (getMultiIndex().size() == 1)
    return getMultiIndex().front();

  if (llvm::is_contained(adaptor.getMultiIndex(), nullptr))
    return nullptr;

  if (!adaptor.getDynamicBasis().empty())
    return nullptr;

  int64_t result = 0;
  int64_t stride = 1;
  for (auto [length, indexAttr] :
       llvm::zip_first(llvm::reverse(getStaticBasis()),
                       llvm::reverse(adaptor.getMultiIndex()))) {
    result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
    stride = stride * length;
  }
  // Without an outer bound, the leading index isn't covered by the zip above;
  // add it in with the accumulated stride.
  if (!hasOuterBound())
    result = result +
             cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() *
                 stride;

  return getAsIndexOpFoldResult(getContext(), result);
}

SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
  Builder builder(getContext());
  if (hasOuterBound()) {
    if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
      return getMixedValues(getStaticBasis().drop_front(),
                            getDynamicBasis().drop_front(), builder);

    return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
                          builder);
  }

  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
}

SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
  SmallVector<OpFoldResult> ret = getMixedBasis();
  if (!hasOuterBound())
    ret.insert(ret.begin(), OpFoldResult());
  return ret;
}
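/// Illustrative sketch for the pattern defined below
/// (DropLinearizeUnitComponentsIfDisjointOrZero): a component whose basis
/// element is 1 contributes nothing to the result when its index is known to
/// be 0 (or when `disjoint` guarantees it is in range), so
///
///   %0 = affine.linearize_index [%i, %c0, %j] by (4, 1, 8)
///
/// can be rewritten to
///
///   %0 = affine.linearize_index [%i, %j] by (4, 8)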
namespace {
struct DropLinearizeUnitComponentsIfDisjointOrZero final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    ValueRange multiIndex = op.getMultiIndex();
    size_t numIndices = multiIndex.size();
    SmallVector<Value> newIndices;
    newIndices.reserve(numIndices);
    SmallVector<OpFoldResult> newBasis;
    newBasis.reserve(numIndices);

    if (!op.hasOuterBound()) {
      newIndices.push_back(multiIndex.front());
      multiIndex = multiIndex.drop_front();
    }

    SmallVector<OpFoldResult> basis = op.getMixedBasis();
    for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
      std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
      if (!basisEntry || *basisEntry != 1) {
        newIndices.push_back(index);
        newBasis.push_back(basisElem);
        continue;
      }

      // Only drop a unit-extent component if `disjoint` guarantees the index
      // is in range (and so must be 0) or if it is a known constant 0.
      std::optional<int64_t> indexValue = getConstantIntValue(index);
      if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
        newIndices.push_back(index);
        newBasis.push_back(basisElem);
        continue;
      }
    }
    if (newIndices.size() == numIndices)
      return rewriter.notifyMatchFailure(op,
                                         "no unit basis entries to replace");

    if (newIndices.size() == 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
      return success();
    }
    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
        op, newIndices, newBasis, op.getDisjoint());
    return success();
  }
};
/// Return the product of `terms`, creating an `affine.apply` if any of them
/// are non-constant values. If any of `terms` is `nullptr`, return `nullptr`.
static OpFoldResult computeProduct(Location loc, OpBuilder &builder,
                                   ArrayRef<OpFoldResult> terms) {
  int64_t nDynamic = 0;
  SmallVector<Value> dynamicPart;
  AffineExpr result = builder.getAffineConstantExpr(1);
  for (OpFoldResult term : terms) {
    if (!term)
      return term;
    std::optional<int64_t> maybeConst = getConstantIntValue(term);
    if (maybeConst) {
      result = result * builder.getAffineConstantExpr(*maybeConst);
    } else {
      dynamicPart.push_back(cast<Value>(term));
      result = result * builder.getAffineSymbolExpr(nDynamic++);
    }
  }
  if (auto constant = dyn_cast<AffineConstantExpr>(result))
    return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
  return AffineApplyOp::create(builder, loc, result, dynamicPart).getResult();
}
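/// Illustrative before/after sketch for the pattern defined below
/// (CancelLinearizeOfDelinearizePortion): a run of delinearize results that is
/// re-linearized with the same basis can be collapsed into one merged element,
/// e.g.
///
///   %0:4 = affine.delinearize_index %x into (3, 14, 7, 14)
///   %1 = affine.linearize_index [%0#1, %0#2, %y] by (14, 7, 3)
///
/// can become
///
///   %0:3 = affine.delinearize_index %x into (3, 98, 14)
///   %1 = affine.linearize_index [%0#1, %y] by (98, 3)
///
/// where 98 = 14 * 7 is the product of the merged basis elements.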
struct CancelLinearizeOfDelinearizePortion final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

private:
  // A `Match` means that `length` inputs to the linearize operation starting
  // at `linStart` can be cancelled with `length` outputs of the delinearize
  // starting from `delinStart`.
  struct Match {
    AffineDelinearizeIndexOp delinearize;
    unsigned linStart = 0;
    unsigned delinStart = 0;
    unsigned length = 0;
  };

public:
  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Match> matches;

    const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
    ArrayRef<OpFoldResult> linBasisRef = linBasis;

    ValueRange multiIndex = linearizeOp.getMultiIndex();
    unsigned numLinArgs = multiIndex.size();
    unsigned linArgIdx = 0;
    // Only replace one run from the same delinearize op per pattern
    // invocation, lest we run into invalidation issues.
    llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize;
    while (linArgIdx < numLinArgs) {
      auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
      if (!asResult) {
        linArgIdx++;
        continue;
      }

      auto delinearizeOp =
          dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
      if (!delinearizeOp) {
        linArgIdx++;
        continue;
      }

      // The bounds of the leading result/argument may be unspecified; the
      // match can still start if both are at the front or the linearization
      // is `disjoint`.
      unsigned delinArgIdx = asResult.getResultNumber();
      SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
      OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
      OpFoldResult firstLinBound = linBasis[linArgIdx];
      bool boundsMatch = firstDelinBound == firstLinBound;
      bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
      bool knownByDisjoint =
          linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
      if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
        linArgIdx++;
        continue;
      }

      unsigned j = 1;
      unsigned numDelinOuts = delinearizeOp.getNumResults();
      for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
           ++j) {
        if (multiIndex[linArgIdx + j] !=
            delinearizeOp.getResult(delinArgIdx + j))
          break;
        if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
          break;
      }
      // If there are multiple matches against the same delinearize_index,
      // only rewrite the first one found; subsequent pattern invocations
      // handle the rest.
      if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
        linArgIdx++;
        continue;
      }
      matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
      linArgIdx += j;
    }

    if (matches.empty())
      return rewriter.notifyMatchFailure(
          linearizeOp, "no run of delinearize outputs to deal with");

    // Record delinearize replacements so they can be applied after creating
    // the new linearization, which might use outputs being replaced.
    SmallVector<SmallVector<Value>> delinearizeReplacements;

    SmallVector<Value> newIndex;
    newIndex.reserve(numLinArgs);
    SmallVector<OpFoldResult> newBasis;
    newBasis.reserve(numLinArgs);
    unsigned prevMatchEnd = 0;
    for (Match m : matches) {
      unsigned gap = m.linStart - prevMatchEnd;
      llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
      llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
      // Update here so we don't forget this during early continues.
      prevMatchEnd = m.linStart + m.length;

      PatternRewriter::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(m.delinearize);

      ArrayRef<OpFoldResult> basisToMerge =
          linBasisRef.slice(m.linStart, m.length);
      // Use the slice of the linearize basis to pick up bounds inferred from
      // `disjoint`.
      OpFoldResult newSize =
          computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);

      // Trivial case: we can skip past the delinearize altogether.
      if (m.length == m.delinearize.getNumResults()) {
        newIndex.push_back(m.delinearize.getLinearIndex());
        newBasis.push_back(newSize);
        // Pad out the set of replacements so we don't touch this one.
        delinearizeReplacements.push_back(SmallVector<Value>());
        continue;
      }

      SmallVector<Value> newDelinResults;
      SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
      newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
                          newDelinBasis.begin() + m.delinStart + m.length);
      newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
      auto newDelinearize = AffineDelinearizeIndexOp::create(
          rewriter, m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
          newDelinBasis);

      // Other uses of the merged indices may remain, so create a residual
      // affine.delinearize_index that splits the merged output back into its
      // component parts.
      Value combinedElem = newDelinearize.getResult(m.delinStart);
      auto residualDelinearize = AffineDelinearizeIndexOp::create(
          rewriter, m.delinearize.getLoc(), combinedElem, basisToMerge);

      // Swap the uses of the unaffected delinearize outputs to the new
      // delinearization so the old op can be removed if this linearize_index
      // is the only user of the merged results.
      llvm::append_range(newDelinResults,
                         newDelinearize.getResults().take_front(m.delinStart));
      llvm::append_range(newDelinResults, residualDelinearize.getResults());
      llvm::append_range(
          newDelinResults,
          newDelinearize.getResults().drop_front(m.delinStart + 1));

      delinearizeReplacements.push_back(newDelinResults);
      newIndex.push_back(combinedElem);
      newBasis.push_back(newSize);
    }
    llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
    llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
    rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
        linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());

    for (auto [m, newResults] :
         llvm::zip_equal(matches, delinearizeReplacements)) {
      if (newResults.empty())
        continue;
      rewriter.replaceOp(m.delinearize, newResults);
    }

    return success();
  }
};
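/// Illustrative sketch for the pattern defined below (DropLinearizeLeadingZero):
/// a leading zero contributes nothing to the linearized value, so
///
///   %0 = affine.linearize_index [%c0, %i, %j] by (3, 5, 7)
///
/// can be rewritten to
///
///   %0 = affine.linearize_index [%i, %j] by (5, 7)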
struct DropLinearizeLeadingZero final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    Value leadingIdx = op.getMultiIndex().front();
    if (!matchPattern(leadingIdx, m_Zero()))
      return failure();

    if (op.getMultiIndex().size() == 1) {
      rewriter.replaceOp(op, leadingIdx);
      return success();
    }

    SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
    ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
    if (op.hasOuterBound())
      newMixedBasis = newMixedBasis.drop_front();

    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
        op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
    return success();
  }
};
} // namespace

void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
               DropLinearizeUnitComponentsIfDisjointOrZero>(context);
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"