25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallBitVector.h"
27 #include "llvm/ADT/SmallVectorExtras.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Support/DebugLog.h"
30 #include "llvm/Support/LogicalResult.h"
31 #include "llvm/Support/MathExtras.h"
38 using llvm::divideCeilSigned;
39 using llvm::divideFloorSigned;
42 #define DEBUG_TYPE "affine-ops"
44 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
51 if (auto arg = dyn_cast<BlockArgument>(value))
52 return arg.getParentRegion() == region;
75 if (llvm::isa<BlockArgument>(value))
76 return legalityCheck(mapping.lookup(value), dest);
83 bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
94 return llvm::all_of(values, [&](Value v) {
101 template <typename OpTy>
104 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
105 AffineWriteOpInterface>::value,
106 "only ops with affine read/write interface are supported");
113 dimOperands, src, dest, mapping,
117 symbolOperands, src, dest, mapping,
134 op.getMapOperands(), src, dest, mapping,
139 op.getMapOperands(), src, dest, mapping,
166 if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
179 if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
180 if (iface.hasNoEffect())
188 .Case<AffineApplyOp, AffineReadOpInterface,
189 AffineWriteOpInterface>([&](auto op) {
214 isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
218 bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
226 void AffineDialect::initialize() {
229 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
231 addInterfaces<AffineInlinerInterface>();
232 declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
241 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
242 return ub::PoisonOp::create(builder, loc, type, poison);
243 return arith::ConstantOp::materialize(builder, value, type, loc);
251 if (auto arg = dyn_cast<BlockArgument>(value)) {
267 while (auto *parentOp = curOp->getParentOp()) {
278 if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(parentOp))
303 auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
331 if (auto applyOp = dyn_cast<AffineApplyOp>(op))
332 return applyOp.isValidDim(region);
335 if (isa<AffineDelinearizeIndexOp, AffineLinearizeIndexOp>(op))
336 return llvm::all_of(op->getOperands(),
337 [&](Value arg) { return ::isValidDim(arg, region); });
340 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
348 template <typename AnyMemRefDefOp>
351 MemRefType memRefType = memrefDefOp.getType();
354 if (index >= memRefType.getRank()) {
359 if (!memRefType.isDynamicDim(index))
362 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
363 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
375 if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
383 if (!index.has_value())
387 Operation *op = dimOp.getShapedValue().getDefiningOp();
388 while (auto castOp = dyn_cast<memref::CastOp>(op)) {
390 if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
392 op = castOp.getSource().getDefiningOp();
397 int64_t i = index.value();
399 .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
401 .Default([](Operation *) { return false; });
435 if (parentRegion == region)
476 if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
477 return affine::isValidSymbol(operand, region);
483 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
501 printer << '(' << operands.take_front(numDims) << ')';
502 if (operands.size() > numDims)
503 printer << '[' << operands.drop_front(numDims) << ']';
513 numDims = opInfos.size();
527 template <typename OpTy>
532 for (auto operand : operands) {
533 if (opIt++ < numDims) {
535 return op.emitOpError("operand cannot be used as a dimension id");
537 return op.emitOpError("operand cannot be used as a symbol");
548 return AffineValueMap(getAffineMap(), getOperands(), getResult());
555 AffineMapAttr mapAttr;
561 auto map = mapAttr.getValue();
563 if (map.getNumDims() != numDims ||
564 numDims + map.getNumSymbols() != result.operands.size()) {
566 "dimension or symbol index mismatch");
569 result.types.append(map.getNumResults(), indexTy);
574 p << " " << getMapAttr();
576 getAffineMap().getNumDims(), p);
587 "operand count and affine map dimension and symbol count must match");
591 return emitOpError("mapping must produce one value");
597 for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
599 return emitError("dimensional operand cannot be used as a symbol");
608 return llvm::all_of(getOperands(),
616 return llvm::all_of(getOperands(),
623 return llvm::all_of(getOperands(),
630 return llvm::all_of(getOperands(), [&](Value operand) {
636 auto map = getAffineMap();
639 auto expr = map.getResult(0);
640 if (auto dim = dyn_cast<AffineDimExpr>(expr))
641 return getOperand(dim.getPosition());
642 if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
643 return getOperand(map.getNumDims() + sym.getPosition());
647 bool hasPoison = false;
649 map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
669 auto dimExpr = dyn_cast<AffineDimExpr>(e);
679 Value operand = operands[dimExpr.getPosition()];
680 int64_t operandDivisor = 1;
684 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
685 operandDivisor = forOp.getStepAsInt();
687 uint64_t lbLargestKnownDivisor =
688 forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
689 operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
692 return operandDivisor;
699 if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
700 int64_t constVal = constExpr.getValue();
701 return constVal >= 0 && constVal < k;
703 auto dimExpr = dyn_cast<AffineDimExpr>(e);
706 Value operand = operands[dimExpr.getPosition()];
710 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
711 forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
727 auto bin = dyn_cast<AffineBinaryOpExpr>(e);
735 quotientTimesDiv = llhs;
741 quotientTimesDiv = rlhs;
751 if (forOp && forOp.hasConstantLowerBound())
752 return forOp.getConstantLowerBound();
759 if (!forOp || !forOp.hasConstantUpperBound())
764 if (forOp.hasConstantLowerBound()) {
765 return forOp.getConstantUpperBound() - 1 -
766 (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
767 forOp.getStepAsInt();
769 return forOp.getConstantUpperBound() - 1;
780 constLowerBounds.reserve(operands.size());
781 constUpperBounds.reserve(operands.size());
782 for (Value operand : operands) {
787 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
788 return constExpr.getValue();
803 constLowerBounds.reserve(operands.size());
804 constUpperBounds.reserve(operands.size());
805 for (Value operand : operands) {
810 std::optional<int64_t> lowerBound;
811 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
812 lowerBound = constExpr.getValue();
815 constLowerBounds, constUpperBounds,
826 auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
837 binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
845 lhs = binExpr.getLHS();
846 rhs = binExpr.getRHS();
847 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
851 int64_t rhsConstVal = rhsConst.getValue();
853 if (rhsConstVal <= 0)
858 std::optional<int64_t> lhsLbConst =
860 std::optional<int64_t> lhsUbConst =
862 if (lhsLbConst && lhsUbConst) {
863 int64_t lhsLbConstVal = *lhsLbConst;
864 int64_t lhsUbConstVal = *lhsUbConst;
868 divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
869 divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
871 divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
877 divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
878 divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
885 lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
897 if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
898 if (rhsConstVal % divisor == 0 &&
900 expr = quotientTimesDiv.floorDiv(rhsConst);
901 } else if (divisor % rhsConstVal == 0 &&
903 expr = rem % rhsConst;
929 if (operands.empty())
935 constLowerBounds.reserve(operands.size());
936 constUpperBounds.reserve(operands.size());
937 for (Value operand : operands) {
951 if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
952 lowerBounds.push_back(constExpr.getValue());
953 upperBounds.push_back(constExpr.getValue());
955 lowerBounds.push_back(
957 constLowerBounds, constUpperBounds,
959 upperBounds.push_back(
961 constLowerBounds, constUpperBounds,
970 unsigned i = exprEn.index();
972 if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
977 if (!upperBounds[i]) {
978 irredundantExprs.push_back(e);
984 auto otherLowerBound = en.value();
985 unsigned pos = en.index();
986 if (pos == i || !otherLowerBound)
988 if (*otherLowerBound > *upperBounds[i])
990 if (*otherLowerBound < *upperBounds[i])
995 if (upperBounds[pos] && lowerBounds[i] &&
996 lowerBounds[i] == upperBounds[i] &&
997 otherLowerBound == *upperBounds[pos] && i < pos)
1001 irredundantExprs.push_back(e);
1003 if (!lowerBounds[i]) {
1004 irredundantExprs.push_back(e);
1009 auto otherUpperBound = en.value();
1010 unsigned pos = en.index();
1011 if (pos == i || !otherUpperBound)
1013 if (*otherUpperBound < *lowerBounds[i])
1015 if (*otherUpperBound > *lowerBounds[i])
1017 if (lowerBounds[pos] && upperBounds[i] &&
1018 lowerBounds[i] == upperBounds[i] &&
1019 otherUpperBound == lowerBounds[pos] && i < pos)
1023 irredundantExprs.push_back(e);
1035 static void LLVM_ATTRIBUTE_UNUSED
1037 assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1043 newResults.push_back(expr);
1066 LDBG() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`";
1067 AffineMap affineMinMap = minOp.getAffineMap();
1070 for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
1074 ValueBoundsConstraintSet::ComparisonOperator::LT,
1076 minOp.getOperands())))
1085 auto it = llvm::find(dims, dim);
1086 if (it == dims.end()) {
1087 unmappedDims.push_back(i);
1094 auto it = llvm::find(syms, sym);
1095 if (it == syms.end()) {
1096 unmappedSyms.push_back(i);
1109 if (llvm::any_of(unmappedDims,
1110 [&](unsigned i) { return expr.isFunctionOfDim(i); }) ||
1111 llvm::any_of(unmappedSyms,
1112 [&](unsigned i) { return expr.isFunctionOfSymbol(i); }))
1118 repl[dimOrSym.ceilDiv(convertedExpr)] = c1;
1120 repl[(dimOrSym + convertedExpr - 1).floorDiv(convertedExpr)] = c1;
1125 return success(*map != initialMap);
1141 unsigned dimOrSymbolPosition,
1144 bool replaceAffineMin) {
1146 bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1147 unsigned pos = isDimReplacement ? dimOrSymbolPosition
1148 : dimOrSymbolPosition - dims.size();
1149 Value &v = isDimReplacement ? dims[pos] : syms[pos];
1153 if (auto minOp = v.getDefiningOp<AffineMinOp>(); minOp && replaceAffineMin) {
1169 AffineMap composeMap = affineApply.getAffineMap();
1170 assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
1172 affineApply.getMapOperands().end());
1186 dims.append(composeDims.begin(), composeDims.end());
1187 syms.append(composeSyms.begin(), composeSyms.end());
1188 *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
1198 bool composeAffineMin = false) {
1218 for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1231 unsigned nDims = 0, nSyms = 0;
1233 dimReplacements.reserve(dims.size());
1234 symReplacements.reserve(syms.size());
1235 for (auto *container : {&dims, &syms}) {
1236 bool isDim = (container == &dims);
1237 auto &repls = isDim ? dimReplacements : symReplacements;
1239 Value v = en.value();
1243 "map is function of unexpected expr@pos");
1249 operands->push_back(v);
1262 while (llvm::any_of(*operands, [](Value v) {
1268 if (composeAffineMin && llvm::any_of(*operands, [](Value v) {
1278 bool composeAffineMin) {
1283 return AffineApplyOp::create(b, loc, map, valueOperands);
1289 bool composeAffineMin) {
1294 operands, composeAffineMin);
1301 bool composeAffineMin = false) {
1307 for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
1315 llvm::append_range(dims,
1317 llvm::append_range(symbols,
1324 operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1331 bool composeAffineMin) {
1332 assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1342 AffineApplyOp applyOp =
1347 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1352 if (failed(applyOp->fold(constOperands, foldResults)) ||
1353 foldResults.empty()) {
1355 listener->notifyOperationInserted(applyOp, {});
1356 return applyOp.getResult();
1360 return llvm::getSingleElement(foldResults);
1370 operands, composeAffineMin);
1376 bool composeAffineMin) {
1377 return llvm::map_to_vector(
1378 llvm::seq<unsigned>(0, map.getNumResults()), [&](unsigned i) {
1379 return makeComposedFoldedAffineApply(b, loc, map.getSubMap({i}),
1380 operands, composeAffineMin);
1384 template <typename OpTy>
1390 return OpTy::create(b, loc, b.getIndexType(), map, valueOperands);
1396 return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
1399 template <typename OpTy>
1411 auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1415 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1420 if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1421 foldResults.empty()) {
1423 listener->notifyOperationInserted(minMaxOp, {});
1424 return minMaxOp.getResult();
1428 return llvm::getSingleElement(foldResults);
1435 return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1442 return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
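// canonicalizePromotedSymbols (below): any operand currently bound to a
// dimension that also satisfies the affine symbol rules is moved to the
// symbol operand list, and the map or set is rewritten with the matching
// dim-to-symbol remapping.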
1447 template <class MapOrSet>
1450 if (!mapOrSet || operands->empty())
1453 assert(mapOrSet->getNumInputs() == operands->size() &&
1454 "map/set inputs must match number of operands");
1456 auto *context = mapOrSet->getContext();
1458 resultOperands.reserve(operands->size());
1460 remappedSymbols.reserve(operands->size());
1461 unsigned nextDim = 0;
1462 unsigned nextSym = 0;
1463 unsigned oldNumSyms = mapOrSet->getNumSymbols();
1465 for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1466 if (i < mapOrSet->getNumDims()) {
1470 remappedSymbols.push_back((*operands)[i]);
1473 resultOperands.push_back((*operands)[i]);
1476 resultOperands.push_back((*operands)[i]);
1480 resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1481 *operands = resultOperands;
1482 *mapOrSet = mapOrSet->replaceDimsAndSymbols(
1483 dimRemapping, {}, nextDim, oldNumSyms + nextSym);
1485 assert(mapOrSet->getNumInputs() == operands->size() &&
1486 "map/set inputs must match number of operands");
1495 template <class MapOrSet>
1498 if (!mapOrSet || operands.empty())
1501 unsigned numOperands = operands.size();
1503 assert(mapOrSet.getNumInputs() == numOperands &&
1504 "map/set inputs must match number of operands");
1506 auto *context = mapOrSet.getContext();
1508 resultOperands.reserve(numOperands);
1510 remappedDims.reserve(numOperands);
1512 symOperands.reserve(mapOrSet.getNumSymbols());
1513 unsigned nextSym = 0;
1514 unsigned nextDim = 0;
1515 unsigned oldNumDims = mapOrSet.getNumDims();
1517 resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);
1518 for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
1521 symRemapping[i - oldNumDims] =
1523 remappedDims.push_back(operands[i]);
1526 symOperands.push_back(operands[i]);
1530 append_range(resultOperands, remappedDims);
1531 append_range(resultOperands, symOperands);
1532 operands = resultOperands;
1533 mapOrSet = mapOrSet.replaceDimsAndSymbols(
1534 {}, symRemapping, oldNumDims + nextDim, nextSym);
1536 assert(mapOrSet.getNumInputs() == operands.size() &&
1537 "map/set inputs must match number of operands");
1541 template <class MapOrSet>
1544 static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1545 "Argument must be either of AffineMap or IntegerSet type");
1547 if (!mapOrSet || operands->empty())
1550 assert(mapOrSet->getNumInputs() == operands->size() &&
1551 "map/set inputs must match number of operands");
1553 canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1554 legalizeDemotedDims<MapOrSet>(*mapOrSet, *operands);
1557 llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1558 llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1560 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
1561 usedDims[dimExpr.getPosition()] = true;
1562 else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
1563 usedSyms[symExpr.getPosition()] = true;
1566 auto *context = mapOrSet->getContext();
1569 resultOperands.reserve(operands->size());
1571 llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1573 unsigned nextDim = 0;
1574 for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1577 auto it = seenDims.find((*operands)[i]);
1578 if (it == seenDims.end()) {
1580 resultOperands.push_back((*operands)[i]);
1581 seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1583 dimRemapping[i] = it->second;
1587 llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1589 unsigned nextSym = 0;
1590 for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1596 IntegerAttr operandCst;
1597 if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
1604 auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1605 if (it == seenSymbols.end()) {
1607 resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1608 seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1611 symRemapping[i] = it->second;
1614 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1616 *operands = resultOperands;
1621 canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1626 canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
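// SimplifyAffineOp<AffineOpTy>: canonicalization pattern that composes and
// canonicalizes the affine map and map operands of the supported affine ops,
// rebuilding the op only when the map or operand list actually changed.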
1633 template <typename AffineOpTy>
1642 LogicalResult matchAndRewrite(AffineOpTy affineOp,
1645 llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1646 AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1647 AffineVectorStoreOp, AffineVectorLoadOp>::value,
1648 "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1650 auto map = affineOp.getAffineMap();
1652 auto oldOperands = affineOp.getMapOperands();
1657 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1658 resultOperands.begin()))
1661 replaceAffineOp(rewriter, affineOp, map, resultOperands);
1669 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1676 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1680 prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1681 prefetch.getLocalityHint(), prefetch.getIsDataCache());
1684 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1688 store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1691 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1695 vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1699 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1703 vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1708 template <
typename AffineOpTy>
1709 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1718 results.add<SimplifyAffineOp<AffineApplyOp>>(context);
1753 Value elementsPerStride) {
1755 build(builder, state, srcMemRef, srcMap, srcIndices, destMemRef, dstMap,
1756 destIndices, tagMemRef, tagMap, tagIndices, numElements, stride,
1758 auto result = dyn_cast<AffineDmaStartOp>(builder.create(state));
1759 assert(result && "builder didn't return the right type");
1768 Value elementsPerStride) {
1769 return create(builder, builder.getLoc(), srcMemRef, srcMap, srcIndices,
1770 destMemRef, dstMap, destIndices, tagMemRef, tagMap, tagIndices,
1771 numElements, stride, elementsPerStride);
1775 p << " " << getSrcMemRef() << '[';
1777 p << "], " << getDstMemRef() << '[';
1779 p << "], " << getTagMemRef() << '[';
1784 p << ", " << getNumElementsPerStride();
1786 p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1787 << getTagMemRefType();
1799 AffineMapAttr srcMapAttr;
1802 AffineMapAttr dstMapAttr;
1805 AffineMapAttr tagMapAttr;
1820 getSrcMapAttrStrName(),
1824 getDstMapAttrStrName(),
1828 getTagMapAttrStrName(),
1837 if (!strideInfo.empty() && strideInfo.size() != 2) {
1839 "expected two stride related operands");
1841 bool isStrided = strideInfo.size() == 2;
1846 if (types.size() != 3)
1864 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1865 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1866 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1868 "memref operand count not equal to map.numInputs");
1872 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1873 if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1874 return emitOpError("expected DMA source to be of memref type");
1875 if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1876 return emitOpError("expected DMA destination to be of memref type");
1877 if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1878 return emitOpError("expected DMA tag to be of memref type");
1880 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1881 getDstMap().getNumInputs() +
1882 getTagMap().getNumInputs();
1883 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1884 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1885 return emitOpError("incorrect number of operands");
1889 for (auto idx : getSrcIndices()) {
1890 if (!idx.getType().isIndex())
1891 return emitOpError("src index to dma_start must have 'index' type");
1894 "src index must be a valid dimension or symbol identifier");
1896 for (auto idx : getDstIndices()) {
1897 if (!idx.getType().isIndex())
1898 return emitOpError("dst index to dma_start must have 'index' type");
1901 "dst index must be a valid dimension or symbol identifier");
1903 for (auto idx : getTagIndices()) {
1904 if (!idx.getType().isIndex())
1905 return emitOpError("tag index to dma_start must have 'index' type");
1908 "tag index must be a valid dimension or symbol identifier");
1919 void AffineDmaStartOp::getEffects(
1947 Value numElements) {
1949 build(builder, state, tagMemRef, tagMap, tagIndices, numElements);
1950 auto result = dyn_cast<AffineDmaWaitOp>(builder.create(state));
1951 assert(result && "builder didn't return the right type");
1958 Value numElements) {
1959 return create(builder, builder.getLoc(), tagMemRef, tagMap, tagIndices,
1964 p << " " << getTagMemRef() << '[';
1969 p << " : " << getTagMemRef().getType();
1980 AffineMapAttr tagMapAttr;
1989 getTagMapAttrStrName(),
1998 if (!llvm::isa<MemRefType>(type))
2000 "expected tag to be of memref type");
2002 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
2004 "tag memref operand count != to map.numInputs");
2008 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
2009 if (!llvm::isa<MemRefType>(getOperand(0).getType()))
2010 return emitOpError("expected DMA tag to be of memref type");
2012 for (auto idx : getTagIndices()) {
2013 if (!idx.getType().isIndex())
2014 return emitOpError("index to dma_wait must have 'index' type");
2017 "index must be a valid dimension or symbol identifier");
2028 void AffineDmaWaitOp::getEffects(
2044 ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
2045 assert(((!lbMap && lbOperands.empty()) ||
2047 "lower bound operand count does not match the affine map");
2048 assert(((!ubMap && ubOperands.empty()) ||
2050 "upper bound operand count does not match the affine map");
2051 assert(step > 0 && "step has to be a positive integer constant");
2057 getOperandSegmentSizeAttr(),
2059 static_cast<int32_t>(ubOperands.size()),
2060 static_cast<int32_t>(iterArgs.size())}));
2062 for (Value val : iterArgs)
2084 Value inductionVar =
2086 for (Value val : iterArgs)
2087 bodyBlock->addArgument(val.getType(), val.getLoc());
2092 if (iterArgs.empty() && !bodyBuilder) {
2093 ensureTerminator(*bodyRegion, builder, result.location);
2094 } else if (bodyBuilder) {
2097 bodyBuilder(builder, result.location, inductionVar,
2103 int64_t ub, int64_t step, ValueRange iterArgs,
2104 BodyBuilderFn bodyBuilder) {
2107 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
2111 LogicalResult AffineForOp::verifyRegions() {
2114 auto *body = getBody();
2115 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
2116 return emitOpError("expected body to have a single index argument for the "
2117 "induction variable");
2121 if (getLowerBoundMap().getNumInputs() > 0)
2123 getLowerBoundMap().getNumDims())))
2126 if (getUpperBoundMap().getNumInputs() > 0)
2128 getUpperBoundMap().getNumDims())))
2130 if (getLowerBoundMap().getNumResults() < 1)
2131 return emitOpError("expected lower bound map to have at least one result");
2132 if (getUpperBoundMap().getNumResults() < 1)
2133 return emitOpError("expected upper bound map to have at least one result");
2135 unsigned opNumResults = getNumResults();
2136 if (opNumResults == 0)
2142 if (getNumIterOperands() != opNumResults)
2144 "mismatch between the number of loop-carried values and results");
2145 if (getNumRegionIterArgs() != opNumResults)
2147 "mismatch between the number of basic block args and results");
2157 bool failedToParsedMinMax =
2161 auto boundAttrStrName =
2162 isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
2163 : AffineForOp::getUpperBoundMapAttrName(result.name);
2170 if (!boundOpInfos.empty()) {
2172 if (boundOpInfos.size() > 1)
2174 "expected only one loop bound operand");
2199 if (auto affineMapAttr = dyn_cast<AffineMapAttr>(boundAttr)) {
2200 unsigned currentNumOperands = result.operands.size();
2205 auto map = affineMapAttr.getValue();
2209 "dim operand count and affine map dim count must match");
2211 unsigned numDimAndSymbolOperands =
2212 result.operands.size() - currentNumOperands;
2213 if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
2216 "symbol operand count and affine map symbol count must match");
2222 return p.emitError(attrLoc, "lower loop bound affine map with "
2223 "multiple results requires 'max' prefix");
2225 return p.emitError(attrLoc, "upper loop bound affine map with multiple "
2226 "results requires 'min' prefix");
2232 if (auto integerAttr = dyn_cast<IntegerAttr>(boundAttr)) {
2242 "expected valid affine map representation for loop bounds");
2254 int64_t numOperands = result.operands.size();
2257 int64_t numLbOperands = result.operands.size() - numOperands;
2260 numOperands = result.operands.size();
2263 int64_t numUbOperands = result.operands.size() - numOperands;
2268 getStepAttrName(result.name),
2272 IntegerAttr stepAttr;
2274 getStepAttrName(result.name).data(),
2278 if (stepAttr.getValue().isNegative())
2281 "expected step to be representable as a positive signed integer");
2289 regionArgs.push_back(inductionVariable);
2297 for (auto argOperandType :
2298 llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2299 Type type = std::get<2>(argOperandType);
2300 std::get<0>(argOperandType).type = type;
2308 getOperandSegmentSizeAttr(),
2310 static_cast<int32_t>(numUbOperands),
2311 static_cast<int32_t>(operands.size())}));
2315 if (regionArgs.size() != result.types.size() + 1)
2318 "mismatch between the number of loop-carried values and results");
2322 AffineForOp::ensureTerminator(*body, builder, result.location);
2344 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2345 p << constExpr.getValue();
2353 if (isa<AffineSymbolExpr>(expr)) {
2369 unsigned AffineForOp::getNumIterOperands() {
2370 AffineMap lbMap = getLowerBoundMapAttr().getValue();
2371 AffineMap ubMap = getUpperBoundMapAttr().getValue();
2376 std::optional<MutableArrayRef<OpOperand>>
2377 AffineForOp::getYieldedValuesMutable() {
2378 return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2390 if (getStepAsInt() != 1)
2391 p << " step " << getStepAsInt();
2393 bool printBlockTerminators = false;
2394 if (getNumIterOperands() > 0) {
2396 auto regionArgs = getRegionIterArgs();
2397 auto operands = getInits();
2399 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2400 p << std::get<0>(it) << " = " << std::get<1>(it);
2402 p << ") -> (" << getResultTypes() << ")";
2403 printBlockTerminators = true;
2408 printBlockTerminators);
2410 (*this)->getAttrs(),
2411 {getLowerBoundMapAttrName(getOperation()->getName()),
2412 getUpperBoundMapAttrName(getOperation()->getName()),
2413 getStepAttrName(getOperation()->getName()),
2414 getOperandSegmentSizeAttr()});
2419 auto foldLowerOrUpperBound = [&forOp](bool lower) {
2423 auto boundOperands =
2424 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2425 for (auto operand : boundOperands) {
2428 operandConstants.push_back(operandCst);
2432 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2434 "bound maps should have at least one result");
2440 assert(!foldedResults.empty() && "bounds should have at least one result");
2441 auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2442 for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2443 auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2444 maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2445 : llvm::APIntOps::smin(maxOrMin, foldedResult);
2447 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2448 : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2453 bool folded = false;
2454 if (!forOp.hasConstantLowerBound())
2455 folded |= succeeded(foldLowerOrUpperBound(true));
2458 if (!forOp.hasConstantUpperBound())
2459 folded |= succeeded(foldLowerOrUpperBound(false));
2460 return success(folded);
2468 auto lbMap = forOp.getLowerBoundMap();
2469 auto ubMap = forOp.getUpperBoundMap();
2470 auto prevLbMap = lbMap;
2471 auto prevUbMap = ubMap;
2484 if (lbMap == prevLbMap && ubMap == prevUbMap)
2487 if (lbMap != prevLbMap)
2488 forOp.setLowerBound(lbOperands, lbMap);
2489 if (ubMap != prevUbMap)
2490 forOp.setUpperBound(ubOperands, ubMap);
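// Returns the trip count when both bounds are constant and the step is
// positive; e.g. lb = 0, ub = 10, step = 3 yields (10 - 0 + 3 - 1) / 3 = 4.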
2496 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2497 int64_t step = forOp.getStepAsInt();
2498 if (!forOp.hasConstantBounds() || step <= 0)
2499 return std::nullopt;
2500 int64_t lb = forOp.getConstantLowerBound();
2501 int64_t ub = forOp.getConstantUpperBound();
2502 return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
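// Empty-loop folder (registered below as AffineForEmptyLoopFolder): a loop
// whose body holds only the terminator is replaced by its init values when
// the trip count is zero, or by the forwarded iter args / loop-invariant
// yielded values when that substitution is provably equivalent.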
2510 LogicalResult matchAndRewrite(AffineForOp forOp,
2513 if (!llvm::hasSingleElement(*forOp.getBody()))
2515 if (forOp.getNumResults() == 0)
2517 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2518 if (tripCount == 0) {
2521 rewriter.replaceOp(forOp, forOp.getInits());
2525 auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2526 auto iterArgs = forOp.getRegionIterArgs();
2527 bool hasValDefinedOutsideLoop = false;
2528 bool iterArgsNotInOrder = false;
2529 for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2530 Value val = yieldOp.getOperand(i);
2531 auto *iterArgIt = llvm::find(iterArgs, val);
2534 if (val == forOp.getInductionVar())
2536 if (iterArgIt == iterArgs.end()) {
2538 assert(forOp.isDefinedOutsideOfLoop(val) &&
2539 "must be defined outside of the loop");
2540 hasValDefinedOutsideLoop = true;
2541 replacements.push_back(val);
2543 unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2545 iterArgsNotInOrder = true;
2546 replacements.push_back(forOp.getInits()[pos]);
2551 if (!tripCount.has_value() &&
2552 (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2556 if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2558 rewriter.replaceOp(forOp, replacements);
2566 results.add<AffineForEmptyLoopFolder>(context);
2570 assert((point.isParent() || point == getRegion()) && "invalid region point");
2577 void AffineForOp::getSuccessorRegions(
2579 assert((point.isParent() || point == getRegion()) && "expected loop region");
2584 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2585 if (point.isParent() && tripCount.has_value()) {
2586 if (tripCount.value() > 0) {
2587 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2590 if (tripCount.value() == 0) {
2598 if (!point.isParent() && tripCount == 1) {
2605 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2611 return getTrivialConstantTripCount(op) == 0;
2614 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2624 results.assign(getInits().begin(), getInits().end());
2627 return success(folded);
2640 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2641 getLowerBoundOperandsMutable().assign(lbOperands);
2642 setLowerBoundMap(map);
2647 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2648 getUpperBoundOperandsMutable().assign(ubOperands);
2649 setUpperBoundMap(map);
2652 bool AffineForOp::hasConstantLowerBound() {
2653 return getLowerBoundMap().isSingleConstant();
2656 bool AffineForOp::hasConstantUpperBound() {
2657 return getUpperBoundMap().isSingleConstant();
2660 int64_t AffineForOp::getConstantLowerBound() {
2661 return getLowerBoundMap().getSingleConstantResult();
2664 int64_t AffineForOp::getConstantUpperBound() {
2665 return getUpperBoundMap().getSingleConstantResult();
2668 void AffineForOp::setConstantLowerBound(int64_t value) {
2672 void AffineForOp::setConstantUpperBound(int64_t value) {
2676 AffineForOp::operand_range AffineForOp::getControlOperands() {
2681 bool AffineForOp::matchingBoundOperandList() {
2682 auto lbMap = getLowerBoundMap();
2683 auto ubMap = getUpperBoundMap();
2689 for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2691 if (getOperand(i) != getOperand(numOperands + i))
2699 std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2703 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2704 if (!hasConstantLowerBound())
2705 return std::nullopt;
2708 OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2711 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2717 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2718 if (!hasConstantUpperBound())
2722 OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2725 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2727 bool replaceInitOperandUsesInLoop,
2732 auto inits = llvm::to_vector(getInits());
2733 inits.append(newInitOperands.begin(), newInitOperands.end());
2734 AffineForOp newLoop = AffineForOp::create(
2739 auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2741 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2746 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2747 assert(newInitOperands.size() == newYieldedValues.size() &&
2748 "expected as many new yield values as new iter operands");
2750 yieldOp.getOperandsMutable().append(newYieldedValues);
2755 rewriter.mergeBlocks(getBody(), newLoop.getBody(),
2756 newLoop.getBody()->getArguments().take_front(
2757 getBody()->getNumArguments()));
2759 if (replaceInitOperandUsesInLoop) {
2762 for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
2773 newLoop->getResults().take_front(getNumResults()));
2774 return cast<LoopLikeOpInterface>(newLoop.getOperation());
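// Returns the AffineForOp that owns `val` as its induction variable, or a
// null AffineForOp when `val` is not an affine.for induction variable.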
2802 auto ivArg = dyn_cast<BlockArgument>(val);
2803 if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2804 return AffineForOp();
2806 ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2808 return forOp.getInductionVar() == val ? forOp : AffineForOp();
2809 return AffineForOp();
2813 auto ivArg = dyn_cast<BlockArgument>(val);
2814 if (!ivArg || !ivArg.getOwner())
2817 auto parallelOp = dyn_cast_if_present<AffineParallelOp>(containingOp);
2818 if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2827 ivs->reserve(forInsts.size());
2828 for (auto forInst : forInsts)
2829 ivs->push_back(forInst.getInductionVar());
2834 ivs.reserve(affineOps.size());
2837 if (auto forOp = dyn_cast<AffineForOp>(op))
2838 ivs.push_back(forOp.getInductionVar());
2839 else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2840 for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2841 ivs.push_back(parallelOp.getBody()->getArgument(i));
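// Builds a perfect nest of affine.for loops, one per (lb, ub, step) triple;
// the innermost loop invokes `bodyBuilderFn` with all induction variables and
// is terminated with an affine.yield.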
2847 template <typename BoundListTy, typename LoopCreatorTy>
2852 LoopCreatorTy &&loopCreatorFn) {
2853 assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
2854 assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2866 ivs.reserve(lbs.size());
2867 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
2873 if (i == e - 1 && bodyBuilderFn) {
2875 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2877 AffineYieldOp::create(nestedBuilder, nestedLoc);
2882 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
2890 int64_t ub, int64_t step,
2891 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2892 return AffineForOp::create(builder, loc, lb, ub, step,
2900 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2903 if (lbConst && ubConst)
2905 ubConst.value(), step, bodyBuilderFn);
2936 LogicalResult matchAndRewrite(AffineIfOp ifOp,
2938 if (ifOp.getElseRegion().empty() ||
2939 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2954 LogicalResult matchAndRewrite(AffineIfOp op,
2957 auto isTriviallyFalse = [](IntegerSet iSet) {
2958 return iSet.isEmptyIntegerSet();
2962 return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2963 iSet.getConstraint(0) == 0);
2966 IntegerSet affineIfConditions = op.getIntegerSet();
2968 if (isTriviallyFalse(affineIfConditions)) {
2972 if (op.getNumResults() == 0 && !op.hasElse()) {
2978 blockToMove = op.getElseBlock();
2979 } else if (isTriviallyTrue(affineIfConditions)) {
2980 blockToMove = op.getThenBlock();
2998 rewriter.eraseOp(blockToMoveTerminator);
3006 void AffineIfOp::getSuccessorRegions(
3015 if (getElseRegion().empty()) {
3016 regions.push_back(getResults());
3032 auto conditionAttr =
3033 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
3035 return emitOpError("requires an integer set attribute named 'condition'");
3038 IntegerSet condition = conditionAttr.getValue();
3040 return emitOpError("operand count and condition integer set dimension and "
3041 "symbol count must match");
3053 IntegerSetAttr conditionAttr;
3056 AffineIfOp::getConditionAttrStrName(),
3062 auto set = conditionAttr.getValue();
3063 if (set.getNumDims() != numDims)
3066 "dim operand count and integer set dim count must match");
3067 if (numDims + set.getNumSymbols() != result.operands.size())
3070 "symbol operand count and integer set symbol count must match");
3084 AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
3091 AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
3103 auto conditionAttr =
3104 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
3105 p << " " << conditionAttr;
3107 conditionAttr.getValue().getNumDims(), p);
3114 auto &elseRegion = this->getElseRegion();
3115 if (!elseRegion.empty()) {
3124 getConditionAttrStrName());
3129 ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
3133 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
3139 (*this)->setOperands(operands);
3144 bool withElseRegion) {
3145 assert(resultTypes.empty() || withElseRegion);
3154 if (resultTypes.empty())
3155 AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
3158 if (withElseRegion) {
3160 if (resultTypes.empty())
3161 AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
3167 AffineIfOp::build(builder, result, {}, set, args,
3176 bool composeAffineMin = false) {
3183 if (llvm::none_of(operands,
3194 auto set = getIntegerSet();
3200 if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
3203 setConditional(set, operands);
3209 results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
3218 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
3222 auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
3223 result.types.push_back(memrefType.getElementType());
3228 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3231 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3233 result.types.push_back(memrefType.getElementType());
3238 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3239 int64_t rank = memrefType.getRank();
3244 build(builder, result, memref, map, indices);
3253 AffineMapAttr mapAttr;
3258 AffineLoadOp::getMapAttrStrName(),
3269 if (AffineMapAttr mapAttr =
3270 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3274 {getMapAttrStrName()});
3280 template <typename AffineMemOpTy>
3281 static LogicalResult
3284 MemRefType memrefType, unsigned numIndexOperands) {
3287 return op->emitOpError("affine map num results must equal memref rank");
3289 return op->emitOpError("expects as many subscripts as affine map inputs");
3291 for (auto idx : mapOperands) {
3292 if (!idx.getType().isIndex())
3293 return op->emitOpError("index to load must have 'index' type");
3303 if (getType() != memrefType.getElementType())
3304 return emitOpError("result type must match element type of memref");
3307 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3308 getMapOperands(), memrefType,
3309 getNumOperands() - 1)))
3317 results.add<SimplifyAffineOp<AffineLoadOp>>(context);
3326 auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3333 auto global = dyn_cast_or_null<memref::GlobalOp>(
3340 dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3344 if (auto splatAttr = dyn_cast<SplatElementsAttr>(cstAttr))
3345 return splatAttr.getSplatValue<Attribute>();
3347 if (!getAffineMap().isConstant())
3349 auto indices = llvm::to_vector<4>(
3350 llvm::map_range(getAffineMap().getConstantResults(),
3351 [](int64_t v) -> uint64_t { return v; }));
3352 return cstAttr.getValues<Attribute>()[indices];
3362 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3373 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3374 int64_t rank = memrefType.getRank();
3379 build(builder, result, valueToStore, memref, map, indices);
3388 AffineMapAttr mapAttr;
3393 mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3404 p << " " << getValueToStore();
3406 if (AffineMapAttr mapAttr =
3407 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3411 {getMapAttrStrName()});
3418 if (getValueToStore().getType() != memrefType.getElementType())
3420 "value to store must have the same type as memref element type");
3423 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3424 getMapOperands(), memrefType,
3425 getNumOperands() - 2)))
3433 results.add<SimplifyAffineOp<AffineStoreOp>>(context);
3436 LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3446 template <typename T>
3449 if (op.getNumOperands() !=
3450 op.getMap().getNumDims() + op.getMap().getNumSymbols())
3451 return op.emitOpError(
3452 "operand count and affine map dimension and symbol count must match");
3454 if (op.getMap().getNumResults() == 0)
3455 return op.emitOpError("affine map expect at least one result");
3459 template <typename T>
3461 p << ' ' << op->getAttr(T::getMapAttrStrName());
3462 auto operands = op.getOperands();
3463 unsigned numDims = op.getMap().getNumDims();
3464 p << '(' << operands.take_front(numDims) << ')';
3466 if (operands.size() != numDims)
3467 p << '[' << operands.drop_front(numDims) << ']';
3469 {T::getMapAttrStrName()});
3472 template <typename T>
3479 AffineMapAttr mapAttr;
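// Shared folder for affine.min/affine.max: partially constant-folds the map,
// forwards a single-symbol identity map to its operand, and when every result
// folds to a constant returns the smallest (min) or largest (max) of them.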
3495 template <typename T>
3497 static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3498 "expected affine min or max op");
3504 auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3506 if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3507 return op.getOperand(0);
3510 if (results.empty()) {
3512 if (foldedMap == op.getMap())
3515 return op.getResult();
3519 auto resultIt = std::is_same<T, AffineMinOp>::value
3520 ? llvm::min_element(results)
3521 : llvm::max_element(results);
3522 if (resultIt == results.end())
3528 template <typename T>
3534 AffineMap oldMap = affineOp.getAffineMap();
3540 if (!llvm::is_contained(newExprs, expr))
3541 newExprs.push_back(expr);
3571 template <typename T>
3577 AffineMap oldMap = affineOp.getAffineMap();
3579 affineOp.getMapOperands().take_front(oldMap.getNumDims());
3581 affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3583 auto newDimOperands = llvm::to_vector<8>(dimOperands);
3584 auto newSymOperands = llvm::to_vector<8>(symOperands);
3592 if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
3593 Value symValue = symOperands[symExpr.getPosition()];
3595 producerOps.push_back(producerOp);
3598 } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
3599 Value dimValue = dimOperands[dimExpr.getPosition()];
3601 producerOps.push_back(producerOp);
3608 newExprs.push_back(expr);
3611 if (producerOps.empty())
3618 for (T producerOp : producerOps) {
3619 AffineMap producerMap = producerOp.getAffineMap();
3620 unsigned numProducerDims = producerMap.getNumDims();
3625 producerOp.getMapOperands().take_front(numProducerDims);
3627 producerOp.getMapOperands().take_back(numProducerSyms);
3628 newDimOperands.append(dimValues.begin(), dimValues.end());
3629 newSymOperands.append(symValues.begin(), symValues.end());
3633 newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3634 .shiftSymbols(numProducerSyms, numUsedSyms));
3637 numUsedDims += numProducerDims;
3638 numUsedSyms += numProducerSyms;
3644 llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3663 if (!resultExpr.isPureAffine())
3668 if (failed(flattenResult))
3681 if (llvm::is_sorted(flattenedExprs))
3686 llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
3687 llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
3688 return flattenedExprs[lhs] < flattenedExprs[rhs];
3691 for (unsigned idx : resultPermutation)
3712 template <typename T>
3718 AffineMap map = affineOp.getAffineMap();
3726 template <typename T>
3732 if (affineOp.getMap().getNumResults() != 1)
3735 affineOp.getOperands());
3763 return parseAffineMinMaxOp<AffineMinOp>(parser, result);
3791 return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
3810 IntegerAttr hintInfo;
3812 StringRef readOrWrite, cacheType;
3814 AffineMapAttr mapAttr;
3818 AffinePrefetchOp::getMapAttrStrName(),
3824 AffinePrefetchOp::getLocalityHintAttrStrName(),
3834 if (readOrWrite != "read" && readOrWrite != "write")
3836 "rw specifier has to be 'read' or 'write'");
3837 result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
3840 if (cacheType != "data" && cacheType != "instr")
3842 "cache type has to be 'data' or 'instr'");
3844 result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
3851 p << " " << getMemref() << '[';
3852 AffineMapAttr mapAttr =
3853 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3856 p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
3857 << "locality<" << getLocalityHint() << ">, "
3858 << (getIsDataCache() ? "data" : "instr");
3860 (*this)->getAttrs(),
3861 {getMapAttrStrName(), getLocalityHintAttrStrName(),
3862 getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3867 auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3871 return emitOpError("affine.prefetch affine map num results must equal"
3874 return emitOpError("too few operands");
3876 if (getNumOperands() != 1)
3877 return emitOpError("too few operands");
3881 for (auto idx : getMapOperands()) {
3884 "index must be a valid dimension or symbol identifier");
3892 results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
3895 LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
3910 auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3914 build(builder, result, resultTypes, reductions, lbs, {}, ubs,
3924 assert(llvm::all_of(lbMaps,
3926 return m.getNumDims() == lbMaps[0].getNumDims() &&
3929 "expected all lower bounds maps to have the same number of dimensions "
3931 assert(llvm::all_of(ubMaps,
3933 return m.getNumDims() == ubMaps[0].getNumDims() &&
3936 "expected all upper bounds maps to have the same number of dimensions "
3938 assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3939 "expected lower bound maps to have as many inputs as lower bound "
3941 assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3942 "expected upper bound maps to have as many inputs as upper bound "
3950 for (arith::AtomicRMWKind reduction : reductions)
3951 reductionAttrs.push_back(
3963 groups.reserve(groups.size() + maps.size());
3964 exprs.reserve(maps.size());
3969 return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3975 AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3976 AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
3994 for (unsigned i = 0, e = steps.size(); i < e; ++i)
3996 if (resultTypes.empty())
3997 ensureTerminator(*bodyRegion, builder, result.location);
4001 return {&getRegion()};
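// AffineParallelOp operand layout: lower bound operands come first, followed
// by upper bound operands; the per-induction-variable bound maps are sliced
// out of the combined maps using the group-size attributes.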
4004 unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
4006 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
4007 return getOperands().take_front(getLowerBoundsMap().getNumInputs());
4010 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
4011 return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
4014 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
4015 auto values = getLowerBoundsGroups().getValues<int32_t>();
4017 for (unsigned i = 0; i < pos; ++i)
4019 return getLowerBoundsMap().getSliceMap(start, values[pos]);
4022 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
4023 auto values = getUpperBoundsGroups().getValues<int32_t>();
4025 for (unsigned i = 0; i < pos; ++i)
4027 return getUpperBoundsMap().getSliceMap(start, values[pos]);
4031 return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
4035 return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
4038 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
4039 if (hasMinMaxBounds())
4040 return std::nullopt;
4045 AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
4048 for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
4049 auto expr = rangesValueMap.getResult(i);
4050 auto cst = dyn_cast<AffineConstantExpr>(expr);
4052 return std::nullopt;
4053 out.push_back(cst.getValue());
4058 Block *AffineParallelOp::getBody() { return &getRegion().front(); }
4060 OpBuilder AffineParallelOp::getBodyBuilder() {
4061 return OpBuilder(getBody(), std::prev(getBody()->end()));
4066 "operands to map must match number of inputs");
4068 auto ubOperands = getUpperBoundsOperands();
4071 newOperands.append(ubOperands.begin(), ubOperands.end());
4072 (*this)->setOperands(newOperands);
4079 "operands to map must match number of inputs");
4082 newOperands.append(ubOperands.begin(), ubOperands.end());
4083 (*this)->setOperands(newOperands);
4089 setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
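// A reduction kind is compatible with a result type when float kinds (addf,
// mulf, minimumf, maximumf) get a FloatType, integer kinds get an IntegerType,
// and signed/unsigned min/max additionally require matching signedness.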
4094 arith::AtomicRMWKind op) {
4096 case arith::AtomicRMWKind::addf:
4097 return isa<FloatType>(resultType);
4098 case arith::AtomicRMWKind::addi:
4099 return isa<IntegerType>(resultType);
4100 case arith::AtomicRMWKind::assign:
4102 case arith::AtomicRMWKind::mulf:
4103 return isa<FloatType>(resultType);
4104 case arith::AtomicRMWKind::muli:
4105 return isa<IntegerType>(resultType);
4106 case arith::AtomicRMWKind::maximumf:
4107 return isa<FloatType>(resultType);
4108 case arith::AtomicRMWKind::minimumf:
4109 return isa<FloatType>(resultType);
4110 case arith::AtomicRMWKind::maxs: {
4111 auto intType = dyn_cast<IntegerType>(resultType);
4112 return intType && intType.isSigned();
4114 case arith::AtomicRMWKind::mins: {
4115 auto intType = dyn_cast<IntegerType>(resultType);
4116 return intType && intType.isSigned();
4118 case arith::AtomicRMWKind::maxu: {
4119 auto intType = dyn_cast<IntegerType>(resultType);
4120 return intType && intType.isUnsigned();
4122 case arith::AtomicRMWKind::minu: {
4123 auto intType = dyn_cast<IntegerType>(resultType);
4124 return intType && intType.isUnsigned();
4126 case arith::AtomicRMWKind::ori:
4127 return isa<IntegerType>(resultType);
4128 case arith::AtomicRMWKind::andi:
4129 return isa<IntegerType>(resultType);
4136 auto numDims = getNumDims();
4139 getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
4140 return emitOpError() << "the number of region arguments ("
4141 << getBody()->getNumArguments()
4142 << ") and the number of map groups for lower ("
4143 << getLowerBoundsGroups().getNumElements()
4144 << ") and upper bound ("
4145 << getUpperBoundsGroups().getNumElements()
4146 << "), and the number of steps (" << getSteps().size()
4147 << ") must all match";
4150 unsigned expectedNumLBResults = 0;
4151 for (APInt v : getLowerBoundsGroups()) {
4152 unsigned results = v.getZExtValue();
4154 return emitOpError()
4155 << "expected lower bound map to have at least one result";
4156 expectedNumLBResults += results;
4158 if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
4159 return emitOpError() << "expected lower bounds map to have "
4160 << expectedNumLBResults << " results";
4161 unsigned expectedNumUBResults = 0;
4162 for (APInt v : getUpperBoundsGroups()) {
4163 unsigned results = v.getZExtValue();
4165 return emitOpError()
4166 << "expected upper bound map to have at least one result";
4167 expectedNumUBResults += results;
4169 if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
4170 return emitOpError() << "expected upper bounds map to have "
4171 << expectedNumUBResults << " results";
4173 if (getReductions().size() != getNumResults())
4174 return emitOpError("a reduction must be specified for each output");
4180 auto intAttr = dyn_cast<IntegerAttr>(attr);
4181 if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
4182 return emitOpError("invalid reduction attribute");
4183 auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
4185 return emitOpError("result type cannot match reduction attribute");
4191 getLowerBoundsMap().getNumDims())))
4195 getUpperBoundsMap().getNumDims())))
4200 LogicalResult AffineValueMap::canonicalize() {
4202 auto newMap = getAffineMap();
4204 if (newMap == getAffineMap() && newOperands == operands)
4206 reset(newMap, newOperands);
4219 if (!lbCanonicalized && !ubCanonicalized)
4222 if (lbCanonicalized)
4224 if (ubCanonicalized)
4230 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
4242 StringRef keyword) {
4245 ValueRange dimOperands = operands.take_front(numDims);
4246 ValueRange symOperands = operands.drop_front(numDims);
4248 for (llvm::APInt groupSize : group) {
4252 unsigned size = groupSize.getZExtValue();
4257 p << keyword << '(';
4267 p << " (" << getBody()->getArguments() << ") = (";
4269 getLowerBoundsOperands(), "max");
4272 getUpperBoundsOperands(), "min");
4275 bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4278 llvm::interleaveComma(steps, p);
4281 if (getNumResults()) {
4283 llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4284 arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4285 llvm::cast<IntegerAttr>(attr).getInt());
4286 p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4288 p << ") -> (" << getResultTypes() << ")";
4295 (*this)->getAttrs(),
4296 {AffineParallelOp::getReductionsAttrStrName(),
4297 AffineParallelOp::getLowerBoundsMapAttrStrName(),
4298 AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4299 AffineParallelOp::getUpperBoundsMapAttrStrName(),
4300 AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4301 AffineParallelOp::getStepsAttrStrName()});
4314 "expected operands to be dim or symbol expression");
4317 for (const auto &list : operands) {
4321 for (Value operand : valueOperands) {
4322 unsigned pos = std::distance(uniqueOperands.begin(),
4323 llvm::find(uniqueOperands, operand));
4324 if (pos == uniqueOperands.size())
4325 uniqueOperands.push_back(operand);
4326 replacements.push_back(
4336 enum class MinMaxKind { Min, Max };
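// Parsing of affine.parallel bounds: lower bounds may use a "max" group and
// upper bounds a "min" group; all parsed maps are flattened into one combined
// map, with a groups attribute recording how many results belong to each
// induction variable.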
4360 const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4362 StringRef mapName = kind == MinMaxKind::Min
4363 ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4364 : AffineParallelOp::getLowerBoundsMapAttrStrName();
4365 StringRef groupsName =
4366 kind == MinMaxKind::Min
4367 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4368 : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4385 auto parseOperands = [&]() {
4387 kind == MinMaxKind::Min ?
"min" :
"max"))) {
4388 mapOperands.clear();
4395 llvm::append_range(flatExprs, map.getValue().getResults());
4397 auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4399 auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4401 flatDimOperands.append(map.getValue().getNumResults(), dims);
4402 flatSymOperands.append(map.getValue().getNumResults(), syms);
4403 numMapsPerGroup.push_back(map.getValue().getNumResults());
4406 flatSymOperands.emplace_back(),
4407 flatExprs.emplace_back())))
4409 numMapsPerGroup.push_back(1);
  // Shift the parsed expressions so that they refer to the combined dim/symbol
  // numbering across all groups.
  unsigned totalNumDims = 0;
  unsigned totalNumSyms = 0;
  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
    unsigned numDims = flatDimOperands[i].size();
    unsigned numSyms = flatSymOperands[i].size();
    flatExprs[i] = flatExprs[i]
                       .shiftDims(numDims, totalNumDims)
                       .shiftSymbols(numSyms, totalNumSyms);
    totalNumDims += numDims;
    totalNumSyms += numSyms;
  }

  // Deduplicate map operands.
  SmallVector<Value> dimOperands, symOperands;
  SmallVector<AffineExpr> dimRplacements, symRepacements;
  if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
                                    dimRplacements, AffineExprKind::DimId) ||
      deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
                                    symRepacements, AffineExprKind::SymbolId))
    return failure();

  result.operands.append(dimOperands.begin(), dimOperands.end());
  result.operands.append(symOperands.begin(), symOperands.end());

  auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
                                parser.getContext());
  flatMap = flatMap.replaceDimsAndSymbols(
      dimRplacements, symRepacements, dimOperands.size(), symOperands.size());

  result.addAttribute(mapName, AffineMapAttr::get(flatMap));
  result.addAttribute(groupsName,
                      parser.getBuilder().getI32TensorAttr(numMapsPerGroup));
  return success();
}
  // Parse the optional `step (...)` clause; steps default to 1.
  AffineMapAttr stepsMapAttr;
  NamedAttrList stepsAttrs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> stepsMapOperands;
  if (failed(parser.parseOptionalKeyword("step"))) {
    SmallVector<int64_t, 4> steps(ivs.size(), 1);
    result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
                        builder.getI64ArrayAttr(steps));
  } else {
    if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
                                      AffineParallelOp::getStepsAttrStrName(),
                                      stepsAttrs,
                                      OpAsmParser::Delimiter::Paren))
      return failure();

    // Convert the parsed map of constant expressions into an I64ArrayAttr.
    SmallVector<int64_t, 4> steps;
    auto stepsMap = stepsMapAttr.getValue();
    for (const auto &result : stepsMap.getResults()) {
      auto constExpr = dyn_cast<AffineConstantExpr>(result);
      if (!constExpr)
        return parser.emitError(parser.getNameLoc(),
                                "steps must be constant integers");
      steps.push_back(constExpr.getValue());
    }
    result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
                        builder.getI64ArrayAttr(steps));
  }
  // Parse the optional `reduce ("<kind>", ...)` clause, where each quoted
  // string must name an arith::AtomicRMWKind.
  auto parseAttributes = [&]() -> ParseResult {
    StringAttr attrVal;
    NamedAttrList attrStorage;
    auto loc = parser.getCurrentLocation();
    if (parser.parseAttribute(attrVal, "reduce", attrStorage))
      return failure();
    std::optional<arith::AtomicRMWKind> reduction =
        arith::symbolizeAtomicRMWKind(attrVal.getValue());
    if (!reduction)
      return parser.emitError(loc, "invalid reduction value: ") << attrVal;
    reductions.push_back(
        builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
    return success();
  };
  if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
    return failure();

  result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
                      builder.getArrayAttr(reductions));
  // Now parse the body region, using the parsed ivs as its index-typed
  // arguments.
  auto indexType = builder.getIndexType();
  for (auto &iv : ivs)
    iv.type = indexType;
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, ivs) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  AffineParallelOp::ensureTerminator(*body, builder, result.location);
  return success();
}
LogicalResult AffineYieldOp::verify() {
  auto *parentOp = (*this)->getParentOp();
  auto results = parentOp->getResults();
  auto operands = getOperands();

  if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
    return emitOpError() << "only terminates affine.if/for/parallel regions";
  if (parentOp->getNumResults() != getNumOperands())
    return emitOpError() << "parent of yield must have same number of "
                            "results as the yield operands";
  for (auto it : llvm::zip(results, operands)) {
    if (std::get<0>(it).getType() != std::get<1>(it).getType())
      return emitOpError() << "types mismatch between yield op and its parent";
  }

  return success();
}
void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
                               VectorType resultType, AffineMap map,
                               ValueRange operands) {
  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
  result.addOperands(operands);
  if (map)
    result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
  result.types.push_back(resultType);
}

void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
                               VectorType resultType, Value memref,
                               AffineMap map, ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(memref);
  result.addOperands(mapOperands);
  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
  result.types.push_back(resultType);
}

void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
                               VectorType resultType, Value memref,
                               ValueRange indices) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  int64_t rank = memrefType.getRank();
  // Create an identity map for memrefs with at least one dimension or () -> ()
  // for zero-dimensional memrefs.
  auto map =
      rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
  build(builder, result, resultType, memref, map, indices);
}
void AffineVectorLoadOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
}
  MemRefType memrefType;
  VectorType resultType;
  AffineMapAttr mapAttr;
                                    AffineVectorLoadOp::getMapAttrStrName(),
void AffineVectorLoadOp::print(OpAsmPrinter &p) {
  p << " " << getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType() << ", " << getType();
}
/// Verify common invariants of affine.vector_load and affine.vector_store.
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
                                          VectorType vectorType) {
  // The memref element type and the vector element type must match.
  if (memrefType.getElementType() != vectorType.getElementType())
    return op->emitOpError(
        "requires memref and vector types of the same elemental type");
  return success();
}

LogicalResult AffineVectorLoadOp::verify() {
  MemRefType memrefType = getMemRefType();
  if (failed(verifyMemoryOpIndexing(
          *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
          getMapOperands(), memrefType,
          /*numIndexOperands=*/getNumOperands() - 1)))
    return failure();

  if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
    return failure();

  return success();
}
void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
                                Value valueToStore, Value memref, AffineMap map,
                                ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(valueToStore);
  result.addOperands(memref);
  result.addOperands(mapOperands);
  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
}

// Overload that uses an identity map built from the memref's rank.
void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
                                Value valueToStore, Value memref,
                                ValueRange indices) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  int64_t rank = memrefType.getRank();
  auto map =
      rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
  build(builder, result, valueToStore, memref, map, indices);
}
void AffineVectorStoreOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
}
  MemRefType memrefType;
  VectorType resultType;
  AffineMapAttr mapAttr;
                                    AffineVectorStoreOp::getMapAttrStrName(),
void AffineVectorStoreOp::print(OpAsmPrinter &p) {
  p << " " << getValueToStore();
  p << ", " << getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType() << ", " << getValueToStore().getType();
}
LogicalResult AffineVectorStoreOp::verify() {
  MemRefType memrefType = getMemRefType();
  if (failed(verifyMemoryOpIndexing(
          *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
          getMapOperands(), memrefType,
          /*numIndexOperands=*/getNumOperands() - 2)))
    return failure();

  if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
    return failure();

  return success();
}
//===----------------------------------------------------------------------===//
// AffineDelinearizeIndexOp
//===----------------------------------------------------------------------===//

void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState,
                                     Value linearIndex, ValueRange dynamicBasis,
                                     ArrayRef<int64_t> staticBasis,
                                     bool hasOuterBound) {
  SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
                                              : staticBasis.size() + 1,
                                linearIndex.getType());
  build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
        staticBasis);
}

void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState,
                                     Value linearIndex, ValueRange basis,
                                     bool hasOuterBound) {
  if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
    hasOuterBound = false;
    basis = basis.drop_front();
  }
  SmallVector<Value> dynamicBasis;
  SmallVector<int64_t> staticBasis;
  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
                             staticBasis);
  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
        hasOuterBound);
}

void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState,
                                     Value linearIndex,
                                     ArrayRef<OpFoldResult> basis,
                                     bool hasOuterBound) {
  if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
    hasOuterBound = false;
    basis = basis.drop_front();
  }
  SmallVector<Value> dynamicBasis;
  SmallVector<int64_t> staticBasis;
  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
        hasOuterBound);
}

void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState,
                                     Value linearIndex, ArrayRef<int64_t> basis,
                                     bool hasOuterBound) {
  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
}
LogicalResult AffineDelinearizeIndexOp::verify() {
  ArrayRef<int64_t> staticBasis = getStaticBasis();
  if (getNumResults() != staticBasis.size() &&
      getNumResults() != staticBasis.size() + 1)
    return emitOpError("should return an index for each basis element and up "
                       "to one extra index");

  auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
    return emitOpError(
        "mismatch between dynamic and static basis (kDynamic marker but no "
        "corresponding dynamic basis entry) -- this can only happen due to an "
        "incorrect fold/rewrite");

  if (!llvm::all_of(staticBasis, [](int64_t v) {
        return v > 0 || ShapedType::isDynamic(v);
      }))
    return emitOpError("no basis element may be statically non-positive");

  return success();
}
/// Given the mixed basis of an affine.delinearize_index or
/// affine.linearize_index op, replace constant SSA basis values with their
/// constant integer value and return the resulting static basis, or
/// std::nullopt if no such replacement is possible.
static std::optional<SmallVector<int64_t>>
foldCstValueToCstAttrBasis(ArrayRef<OpFoldResult> mixedBasis,
                           MutableOperandRange mutableDynamicBasis,
                           ArrayRef<Attribute> dynamicBasis) {
  uint64_t dynamicBasisIndex = 0;
  for (OpFoldResult basis : dynamicBasis) {
    if (basis) {
      // This dynamic basis entry folded to a constant: drop the operand.
      mutableDynamicBasis.erase(dynamicBasisIndex);
    } else {
      ++dynamicBasisIndex;
    }
  }

  // No constant SSA value exists among the dynamic basis entries.
  if (dynamicBasisIndex == dynamicBasis.size())
    return std::nullopt;

  SmallVector<int64_t> staticBasis;
  for (OpFoldResult basis : mixedBasis) {
    std::optional<int64_t> basisVal = getConstantIntValue(basis);
    if (!basisVal)
      staticBasis.push_back(ShapedType::kDynamic);
    else
      staticBasis.push_back(*basisVal);
  }

  return staticBasis;
}
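
// Constant folding for affine.delinearize_index, implemented in `fold` below,
// peels results off with repeated modulo/floor-division. Schematically
// (illustrative example, not taken from a test):
//
//   %0:3 = affine.delinearize_index %c22 into (2, 3, 5) : index, index, index
//
// folds to the constants (1, 1, 2), since 22 = (1 * 3 + 1) * 5 + 2.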
LogicalResult
AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &result) {
  std::optional<SmallVector<int64_t>> maybeStaticBasis =
      foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
                                 adaptor.getDynamicBasis());
  if (maybeStaticBasis) {
    setStaticBasis(*maybeStaticBasis);
    return success();
  }

  // A single-result delinearization simply forwards its input.
  if (getNumResults() == 1) {
    result.push_back(getLinearIndex());
    return success();
  }

  if (adaptor.getLinearIndex() == nullptr)
    return failure();

  if (!adaptor.getDynamicBasis().empty())
    return failure();

  // Constant-fold by repeated modulo/floor-division, innermost basis first.
  int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
  Type attrType = getLinearIndex().getType();

  ArrayRef<int64_t> staticBasis = getStaticBasis();
  if (hasOuterBound())
    staticBasis = staticBasis.drop_front();
  for (int64_t modulus : llvm::reverse(staticBasis)) {
    result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
    highPart = llvm::divideFloorSigned(highPart, modulus);
  }
  result.push_back(IntegerAttr::get(attrType, highPart));
  std::reverse(result.begin(), result.end());
  return success();
}
SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
  Builder builder(getContext());
  if (hasOuterBound()) {
    if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
      return getMixedValues(getStaticBasis().drop_front(),
                            getDynamicBasis().drop_front(), builder);

    return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
                          builder);
  }

  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
}

SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
  SmallVector<OpFoldResult> ret = getMixedBasis();
  if (!hasOuterBound())
    ret.insert(ret.begin(), OpFoldResult());
  return ret;
}
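
/// Pattern: drop every basis element that is statically known to be 1. The
/// corresponding result is replaced by a constant 0, and the remaining basis
/// elements (if any) are delinearized by a smaller affine.delinearize_index.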
struct DropUnitExtentBasis
    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
    std::optional<Value> zero = std::nullopt;
    Location loc = delinearizeOp->getLoc();
    auto getZero = [&]() -> Value {
      if (!zero)
        zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
      return zero.value();
    };

    // Replace all indices corresponding to unit-extent basis elements with 0.
    // The remaining basis elements form a new delinearize op.
    SmallVector<OpFoldResult> newBasis;
    for (auto [index, basis] :
         llvm::enumerate(delinearizeOp.getPaddedBasis())) {
      std::optional<int64_t> basisVal =
          basis ? getConstantIntValue(basis) : std::nullopt;
      if (basisVal && *basisVal == 1)
        replacements[index] = getZero();
      else
        newBasis.push_back(basis);
    }

    if (newBasis.size() == delinearizeOp.getNumResults())
      return rewriter.notifyMatchFailure(delinearizeOp,
                                         "no unit basis elements");

    if (!newBasis.empty()) {
      auto newDelinearizeOp = affine::AffineDelinearizeIndexOp::create(
          rewriter, loc, delinearizeOp.getLinearIndex(), newBasis);
      int newIndex = 0;
      // Map the new delinearized indices back to the results they replace.
      for (auto &replacement : replacements) {
        if (replacement)
          continue;
        replacement = newDelinearizeOp->getResult(newIndex++);
      }
    }

    rewriter.replaceOp(delinearizeOp, replacements);
    return success();
  }
};
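
/// Pattern: if the operand of an affine.delinearize_index is produced by an
/// affine.linearize_index marked `disjoint`, and the two ops share the same
/// trailing basis elements, those trailing components cancel: the matching
/// linearize inputs are forwarded directly, and only the non-matching prefix
/// (if any) is re-linearized and re-delinearized.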
struct CancelDelinearizeOfLinearizeDisjointExactTail
    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                PatternRewriter &rewriter) const override {
    auto linearizeOp = delinearizeOp.getLinearIndex()
                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
    if (!linearizeOp)
      return rewriter.notifyMatchFailure(delinearizeOp,
                                         "index doesn't come from linearize");

    if (!linearizeOp.getDisjoint())
      return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");

    ValueRange linearizeIns = linearizeOp.getMultiIndex();
    SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
    SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
    size_t numMatches = 0;
    for (auto [linSize, delinSize] : llvm::zip(
             llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
      if (linSize != delinSize)
        break;
      ++numMatches;
    }

    if (numMatches == 0)
      return rewriter.notifyMatchFailure(
          delinearizeOp, "final basis element doesn't match linearize");

    // If the bases match exactly, forward the linearize inputs directly.
    if (numMatches == linearizeBasis.size() &&
        numMatches == delinearizeBasis.size() &&
        linearizeIns.size() == delinearizeOp.getNumResults()) {
      rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
      return success();
    }

    Value newLinearize = affine::AffineLinearizeIndexOp::create(
        rewriter, linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
        ArrayRef<OpFoldResult>(linearizeBasis).drop_back(numMatches),
        linearizeOp.getDisjoint());
    auto newDelinearize = affine::AffineDelinearizeIndexOp::create(
        rewriter, delinearizeOp.getLoc(), newLinearize,
        ArrayRef<OpFoldResult>(delinearizeBasis).drop_back(numMatches),
        delinearizeOp.hasOuterBound());
    SmallVector<Value> mergedResults(newDelinearize.getResults());
    mergedResults.append(linearizeIns.take_back(numMatches).begin(),
                         linearizeIns.take_back(numMatches).end());
    rewriter.replaceOp(delinearizeOp, mergedResults);
    return success();
  }
};
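
/// Pattern: when a disjoint affine.linearize_index feeds a delinearization
/// whose trailing basis elements multiply up to exactly the last (static)
/// linearize basis element, split the delinearization in two: the last
/// linearize argument is delinearized by that trailing sub-basis on its own,
/// and the remaining arguments are linearized/delinearized without it.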
struct SplitDelinearizeSpanningLastLinearizeArg final
    : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                PatternRewriter &rewriter) const override {
    auto linearizeOp = delinearizeOp.getLinearIndex()
                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
    if (!linearizeOp)
      return rewriter.notifyMatchFailure(delinearizeOp,
                                         "index doesn't come from linearize");

    if (!linearizeOp.getDisjoint())
      return rewriter.notifyMatchFailure(linearizeOp,
                                         "linearize isn't disjoint");

    int64_t target = linearizeOp.getStaticBasis().back();
    if (ShapedType::isDynamic(target))
      return rewriter.notifyMatchFailure(
          linearizeOp, "linearize ends with dynamic basis value");

    int64_t sizeToSplit = 1;
    size_t elemsToSplit = 0;
    ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
    for (int64_t basisElem : llvm::reverse(basis)) {
      if (ShapedType::isDynamic(basisElem))
        return rewriter.notifyMatchFailure(
            delinearizeOp, "dynamic basis element while scanning for split");
      sizeToSplit *= basisElem;
      elemsToSplit += 1;

      if (sizeToSplit > target)
        return rewriter.notifyMatchFailure(delinearizeOp,
                                           "overshot last argument size");
      if (sizeToSplit == target)
        break;
    }

    if (sizeToSplit < target)
      return rewriter.notifyMatchFailure(
          delinearizeOp, "product of known basis elements doesn't exceed last "
                         "linearize argument");

    if (elemsToSplit < 2)
      return rewriter.notifyMatchFailure(
          delinearizeOp,
          "need at least two elements to form the basis product");

    Value linearizeWithoutBack = affine::AffineLinearizeIndexOp::create(
        rewriter, linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
        linearizeOp.getDynamicBasis(), linearizeOp.getStaticBasis().drop_back(),
        linearizeOp.getDisjoint());
    auto delinearizeWithoutSplitPart = affine::AffineDelinearizeIndexOp::create(
        rewriter, delinearizeOp.getLoc(), linearizeWithoutBack,
        delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
        delinearizeOp.hasOuterBound());
    auto delinearizeBack = affine::AffineDelinearizeIndexOp::create(
        rewriter, delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
        basis.take_back(elemsToSplit),
        /*hasOuterBound=*/true);
    SmallVector<Value> results = llvm::to_vector(
        llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
                            delinearizeBack.getResults()));
    rewriter.replaceOp(delinearizeOp, results);

    return success();
  }
};
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns
      .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
              DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
          context);
}
//===----------------------------------------------------------------------===//
// AffineLinearizeIndexOp
//===----------------------------------------------------------------------===//

void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                   OperationState &odsState,
                                   ValueRange multiIndex, ValueRange basis,
                                   bool disjoint) {
  if (!basis.empty() && basis.front() == Value())
    basis = basis.drop_front();
  SmallVector<Value> dynamicBasis;
  SmallVector<int64_t> staticBasis;
  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
                             staticBasis);
  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
}

void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                   OperationState &odsState,
                                   ValueRange multiIndex,
                                   ArrayRef<OpFoldResult> basis,
                                   bool disjoint) {
  if (!basis.empty() && basis.front() == OpFoldResult())
    basis = basis.drop_front();
  SmallVector<Value> dynamicBasis;
  SmallVector<int64_t> staticBasis;
  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
}

void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                   OperationState &odsState,
                                   ValueRange multiIndex,
                                   ArrayRef<int64_t> basis, bool disjoint) {
  build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
}
LogicalResult AffineLinearizeIndexOp::verify() {
  size_t numIndexes = getMultiIndex().size();
  size_t numBasisElems = getStaticBasis().size();
  if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
    return emitOpError("should be passed a basis element for each index except "
                       "possibly the first");

  auto dynamicMarkersCount =
      llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
    return emitOpError(
        "mismatch between dynamic and static basis (kDynamic marker but no "
        "corresponding dynamic basis entry) -- this can only happen due to an "
        "incorrect fold/rewrite");

  return success();
}
OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
  std::optional<SmallVector<int64_t>> maybeStaticBasis =
      foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
                                 adaptor.getDynamicBasis());
  if (maybeStaticBasis) {
    setStaticBasis(*maybeStaticBasis);
    return getResult();
  }

  // No indices linearizes to zero.
  if (getMultiIndex().empty())
    return IntegerAttr::get(getResult().getType(), 0);

  // A single index linearizes to itself.
  if (getMultiIndex().size() == 1)
    return getMultiIndex().front();

  if (llvm::is_contained(adaptor.getMultiIndex(), nullptr))
    return nullptr;

  if (!adaptor.getDynamicBasis().empty())
    return nullptr;

  int64_t result = 0;
  int64_t stride = 1;
  for (auto [length, indexAttr] :
       llvm::zip_first(llvm::reverse(getStaticBasis()),
                       llvm::reverse(adaptor.getMultiIndex()))) {
    result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
    stride = stride * length;
  }
  // The outermost index has no bound of its own when there's no outer bound.
  if (!hasOuterBound())
    result = result +
             cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;

  return IntegerAttr::get(getResult().getType(), result);
}
SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
  Builder builder(getContext());
  if (hasOuterBound()) {
    if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
      return getMixedValues(getStaticBasis().drop_front(),
                            getDynamicBasis().drop_front(), builder);

    return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
                          builder);
  }

  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
}

SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
  SmallVector<OpFoldResult> ret = getMixedBasis();
  if (!hasOuterBound())
    ret.insert(ret.begin(), OpFoldResult());
  return ret;
}
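
/// Pattern: drop (index, basis) pairs whose basis element is statically 1 when
/// the index is known to contribute nothing -- either because it is a constant
/// 0 or because the op is marked `disjoint` (which bounds the index by the
/// unit basis element). If nothing remains, the result is the constant 0.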
struct DropLinearizeUnitComponentsIfDisjointOrZero final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    ValueRange multiIndex = op.getMultiIndex();
    size_t numIndices = multiIndex.size();
    SmallVector<Value> newIndices;
    newIndices.reserve(numIndices);
    SmallVector<OpFoldResult> newBasis;
    newBasis.reserve(numIndices);

    // With no outer bound, the leading index has no basis element; keep it.
    if (!op.hasOuterBound()) {
      newIndices.push_back(multiIndex.front());
      multiIndex = multiIndex.drop_front();
    }

    SmallVector<OpFoldResult> basis = op.getMixedBasis();
    for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
      std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
      if (!basisEntry || *basisEntry != 1) {
        newIndices.push_back(index);
        newBasis.push_back(basisElem);
        continue;
      }

      std::optional<int64_t> indexValue = getConstantIntValue(index);
      if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
        newIndices.push_back(index);
        newBasis.push_back(basisElem);
        continue;
      }
    }
    if (newIndices.size() == numIndices)
      return rewriter.notifyMatchFailure(op,
                                         "no unit basis entries to replace");

    if (newIndices.size() == 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
      return success();
    }
    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
        op, newIndices, newBasis, op.getDisjoint());
    return success();
  }
};
/// Return the product of `terms`, creating an `affine.apply` if any of them
/// are non-constant values. If any of `terms` is `nullptr`, return `nullptr`.
static OpFoldResult computeProduct(Location loc, OpBuilder &builder,
                                   ArrayRef<OpFoldResult> terms) {
  int64_t nDynamic = 0;
  SmallVector<Value> dynamicPart;
  AffineExpr result = builder.getAffineConstantExpr(1);
  for (OpFoldResult term : terms) {
    if (!term)
      return term;
    std::optional<int64_t> maybeConst = getConstantIntValue(term);
    if (maybeConst) {
      result = result * builder.getAffineConstantExpr(*maybeConst);
    } else {
      dynamicPart.push_back(cast<Value>(term));
      result = result * builder.getAffineSymbolExpr(nDynamic++);
    }
  }
  if (auto constant = dyn_cast<AffineConstantExpr>(result))
    return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
  return AffineApplyOp::create(builder, loc, result, dynamicPart).getResult();
}
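
/// Pattern: if a contiguous run of affine.linearize_index arguments consists
/// of results of one affine.delinearize_index with matching basis elements,
/// that run linearizes back to (a slice of) the delinearized value. The run is
/// collapsed into a single merged index: the delinearization is rebuilt so the
/// merged components become one output (with a residual delinearization
/// recreated for any other users), and the linearization then consumes that
/// merged value directly.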
struct CancelLinearizeOfDelinearizePortion final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

private:
  // A contiguous run of delinearize results that appears, in order, as
  // arguments of the linearize op.
  struct Match {
    AffineDelinearizeIndexOp delinearize;
    unsigned linStart = 0;
    unsigned delinStart = 0;
    unsigned length = 0;
  };

public:
  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Match> matches;

    const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
    ArrayRef<OpFoldResult> linBasisRef = linBasis;

    ValueRange multiIndex = linearizeOp.getMultiIndex();
    unsigned numLinArgs = multiIndex.size();
    unsigned linArgIdx = 0;
    // Only replace one run per delinearize op.
    llvm::SmallPtrSet<Operation *, 4> alreadyMatchedDelinearize;
    while (linArgIdx < numLinArgs) {
      auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
      if (!asResult) {
        linArgIdx++;
        continue;
      }

      auto delinearizeOp =
          dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
      if (!delinearizeOp) {
        linArgIdx++;
        continue;
      }

      // The two bases must agree at the first element of the candidate run,
      // or the match must be justified some other way.
      SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
      OpFoldResult firstLinBound = linBasisRef[linArgIdx];
      unsigned delinArgIdx = asResult.getResultNumber();
      OpFoldResult firstDelinBound = delinBasis[delinArgIdx];

      bool boundsMatch = firstDelinBound == firstLinBound;
      bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
      bool knownByDisjoint =
          linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
      if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
        linArgIdx++;
        continue;
      }

      // Measure the length of the matching run.
      unsigned j = 0;
      unsigned numDelinOuts = delinearizeOp.getNumResults();
      for (; j + linArgIdx < numLinArgs &&
             j + delinArgIdx < numDelinOuts;
           ++j) {
        if (multiIndex[linArgIdx + j] !=
            delinearizeOp.getResult(delinArgIdx + j))
          break;
        if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
          break;
      }
      // If there are no usable matches, move on.
      if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
        linArgIdx++;
        continue;
      }
      matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
      linArgIdx += j;
    }

    if (matches.empty())
      return rewriter.notifyMatchFailure(
          linearizeOp, "no run of delinearize outputs to deal with");

    // Record the delinearize replacements so they can be applied after the new
    // linearization is created, since it may use outputs being replaced.
    SmallVector<SmallVector<Value>> delinearizeReplacements;

    SmallVector<Value> newIndex;
    newIndex.reserve(numLinArgs);
    SmallVector<OpFoldResult> newBasis;
    newBasis.reserve(numLinArgs);
    unsigned prevMatchEnd = 0;
    for (Match m : matches) {
      unsigned gap = m.linStart - prevMatchEnd;
      llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
      llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));

      prevMatchEnd = m.linStart + m.length;

      PatternRewriter::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(m.delinearize);

      ArrayRef<OpFoldResult> basisToMerge =
          linBasisRef.slice(m.linStart, m.length);
      OpFoldResult newSize =
          computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);

      // If the run covers the whole delinearize, skip past it entirely.
      if (m.length == m.delinearize.getNumResults()) {
        newIndex.push_back(m.delinearize.getLinearIndex());
        newBasis.push_back(newSize);
        delinearizeReplacements.push_back(SmallVector<Value>());
        continue;
      }

      // Otherwise, merge the matched components into a single output of a
      // rebuilt delinearization ...
      SmallVector<Value> newDelinResults;
      SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
      newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
                          newDelinBasis.begin() + m.delinStart + m.length);
      newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
      auto newDelinearize = AffineDelinearizeIndexOp::create(
          rewriter, m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
          newDelinBasis);

      // ... and recreate a residual delinearization of the merged value for
      // any other users of the original results.
      Value combinedElem = newDelinearize.getResult(m.delinStart);
      auto residualDelinearize = AffineDelinearizeIndexOp::create(
          rewriter, m.delinearize.getLoc(), combinedElem, basisToMerge);

      llvm::append_range(newDelinResults,
                         newDelinearize.getResults().take_front(m.delinStart));
      llvm::append_range(newDelinResults, residualDelinearize.getResults());
      llvm::append_range(
          newDelinResults,
          newDelinearize.getResults().drop_front(m.delinStart + 1));

      delinearizeReplacements.push_back(newDelinResults);
      newIndex.push_back(combinedElem);
      newBasis.push_back(newSize);
    }
    llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
    llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));

    rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
        linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());

    for (auto [m, newResults] :
         llvm::zip_equal(matches, delinearizeReplacements)) {
      if (newResults.empty())
        continue;
      rewriter.replaceOp(m.delinearize, newResults);
    }

    return success();
  }
};
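
/// Pattern: a leading zero index contributes nothing to a linearization, so it
/// can be dropped along with the corresponding (outer-bound) basis element.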
struct DropLinearizeLeadingZero final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    Value leadingIdx = op.getMultiIndex().front();
    if (!matchPattern(leadingIdx, m_Zero()))
      return failure();

    if (op.getMultiIndex().size() == 1) {
      rewriter.replaceOp(op, leadingIdx);
      return success();
    }

    SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
    ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
    if (op.hasOuterBound())
      newMixedBasis = newMixedBasis.drop_front();

    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
        op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
    return success();
  }
};
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
               DropLinearizeUnitComponentsIfDisjointOrZero>(context);
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"