25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallBitVector.h"
27#include "llvm/ADT/SmallVectorExtras.h"
28#include "llvm/ADT/TypeSwitch.h"
29#include "llvm/Support/DebugLog.h"
30#include "llvm/Support/LogicalResult.h"
31#include "llvm/Support/MathExtras.h"
38using llvm::divideCeilSigned;
39using llvm::divideFloorSigned;
42#define DEBUG_TYPE "affine-ops"
44#include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
51 if (
auto arg = dyn_cast<BlockArgument>(value))
52 return arg.getParentRegion() == region;
75 if (llvm::isa<BlockArgument>(value))
76 return legalityCheck(mapping.
lookup(value), dest);
83 bool isDimLikeOp = isa<ShapedDimOpInterface>(value.
getDefiningOp());
94 return llvm::all_of(values, [&](
Value v) {
101template <
typename OpTy>
104 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
105 AffineWriteOpInterface>::value,
106 "only ops with affine read/write interface are supported");
113 dimOperands, src, dest, mapping,
117 symbolOperands, src, dest, mapping,
134 op.getMapOperands(), src, dest, mapping,
139 op.getMapOperands(), src, dest, mapping,
150struct AffineInlinerInterface :
public DialectInlinerInterface {
151 using DialectInlinerInterface::DialectInlinerInterface;
162 IRMapping &valueMapping)
const final {
166 if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
177 for (Operation &op : srcBlock) {
179 if (
auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
180 if (iface.hasNoEffect())
187 llvm::TypeSwitch<Operation *, bool>(&op)
188 .Case<AffineApplyOp, AffineReadOpInterface,
189 AffineWriteOpInterface>([&](
auto op) {
192 .Default([](Operation *) {
206 bool isLegalToInline(Operation *op, Region *region,
bool wouldBeCloned,
207 IRMapping &valueMapping)
const final {
212 Operation *parentOp = region->getParentOp();
213 return parentOp->
hasTrait<OpTrait::AffineScope>() ||
214 isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
218 bool shouldAnalyzeRecursively(Operation *op)
const final {
return true; }
226void AffineDialect::initialize() {
229#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
231 addInterfaces<AffineInlinerInterface>();
232 declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
241 if (
auto poison = dyn_cast<ub::PoisonAttr>(value))
242 return ub::PoisonOp::create(builder, loc, type, poison);
243 return arith::ConstantOp::materialize(builder, value, type, loc);
251 if (
auto arg = dyn_cast<BlockArgument>(value)) {
267 while (
auto *parentOp = curOp->getParentOp()) {
269 return curOp->getParentRegion();
278 if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(parentOp))
303 auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
331 if (
auto applyOp = dyn_cast<AffineApplyOp>(op))
332 return applyOp.isValidDim(region);
335 if (isa<AffineDelinearizeIndexOp, AffineLinearizeIndexOp>(op))
336 return llvm::all_of(op->getOperands(),
337 [&](
Value arg) { return ::isValidDim(arg, region); });
340 if (
auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
348template <
typename AnyMemRefDefOp>
351 MemRefType memRefType = memrefDefOp.getType();
354 if (
index >= memRefType.getRank()) {
359 if (!memRefType.isDynamicDim(
index))
362 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(
index);
363 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
375 if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
383 if (!
index.has_value())
387 Operation *op = dimOp.getShapedValue().getDefiningOp();
388 while (
auto castOp = dyn_cast<memref::CastOp>(op)) {
390 if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
392 op = castOp.getSource().getDefiningOp();
399 .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
401 .Default([](
Operation *) {
return false; });
435 if (parentRegion == region)
476 if (
isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](
Value operand) {
477 return affine::isValidSymbol(operand, region);
483 if (
auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
501 printer <<
'(' << operands.take_front(numDims) <<
')';
502 if (operands.size() > numDims)
503 printer <<
'[' << operands.drop_front(numDims) <<
']';
513 numDims = opInfos.size();
527template <
typename OpTy>
532 for (
auto operand : operands) {
533 if (opIt++ < numDims) {
535 return op.emitOpError(
"operand cannot be used as a dimension id");
537 return op.emitOpError(
"operand cannot be used as a symbol");
548 return AffineValueMap(getAffineMap(), getOperands(), getResult());
555 AffineMapAttr mapAttr;
561 auto map = mapAttr.getValue();
563 if (map.getNumDims() != numDims ||
564 numDims + map.getNumSymbols() !=
result.operands.size()) {
566 "dimension or symbol index mismatch");
569 result.types.append(map.getNumResults(), indexTy);
574 p <<
" " << getMapAttr();
576 getAffineMap().getNumDims(), p);
580LogicalResult AffineApplyOp::verify() {
587 "operand count and affine map dimension and symbol count must match");
591 return emitOpError(
"mapping must produce one value");
597 for (
Value operand : getMapOperands().drop_front(affineMap.
getNumDims())) {
599 return emitError(
"dimensional operand cannot be used as a symbol");
607bool AffineApplyOp::isValidDim() {
608 return llvm::all_of(getOperands(),
615bool AffineApplyOp::isValidDim(
Region *region) {
616 return llvm::all_of(getOperands(),
617 [&](
Value op) { return ::isValidDim(op, region); });
622bool AffineApplyOp::isValidSymbol() {
623 return llvm::all_of(getOperands(),
629bool AffineApplyOp::isValidSymbol(
Region *region) {
630 return llvm::all_of(getOperands(), [&](
Value operand) {
636 auto map = getAffineMap();
639 auto expr = map.getResult(0);
640 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
641 return getOperand(dim.getPosition());
642 if (
auto sym = dyn_cast<AffineSymbolExpr>(expr))
643 return getOperand(map.getNumDims() + sym.getPosition());
647 bool hasPoison =
false;
649 map.constantFold(adaptor.getMapOperands(),
result, &hasPoison);
669 auto dimExpr = dyn_cast<AffineDimExpr>(e);
679 Value operand = operands[dimExpr.getPosition()];
684 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
685 operandDivisor = forOp.getStepAsInt();
687 uint64_t lbLargestKnownDivisor =
688 forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
689 operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
692 return operandDivisor;
699 if (
auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
700 int64_t constVal = constExpr.getValue();
701 return constVal >= 0 && constVal < k;
703 auto dimExpr = dyn_cast<AffineDimExpr>(e);
706 Value operand = operands[dimExpr.getPosition()];
710 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
711 forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
727 auto bin = dyn_cast<AffineBinaryOpExpr>(e);
735 quotientTimesDiv = llhs;
741 quotientTimesDiv = rlhs;
751 if (forOp && forOp.hasConstantLowerBound())
752 return forOp.getConstantLowerBound();
759 if (!forOp || !forOp.hasConstantUpperBound())
764 if (forOp.hasConstantLowerBound()) {
765 return forOp.getConstantUpperBound() - 1 -
766 (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
767 forOp.getStepAsInt();
769 return forOp.getConstantUpperBound() - 1;
780 constLowerBounds.reserve(operands.size());
781 constUpperBounds.reserve(operands.size());
782 for (
Value operand : operands) {
787 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr))
788 return constExpr.getValue();
803 constLowerBounds.reserve(operands.size());
804 constUpperBounds.reserve(operands.size());
805 for (
Value operand : operands) {
810 std::optional<int64_t> lowerBound;
811 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
812 lowerBound = constExpr.getValue();
815 constLowerBounds, constUpperBounds,
826 auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
837 binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
845 lhs = binExpr.getLHS();
846 rhs = binExpr.getRHS();
847 auto rhsConst = dyn_cast<AffineConstantExpr>(
rhs);
851 int64_t rhsConstVal = rhsConst.getValue();
853 if (rhsConstVal <= 0)
858 std::optional<int64_t> lhsLbConst =
860 std::optional<int64_t> lhsUbConst =
862 if (lhsLbConst && lhsUbConst) {
863 int64_t lhsLbConstVal = *lhsLbConst;
864 int64_t lhsUbConstVal = *lhsUbConst;
868 divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
869 divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
871 divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
877 divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
878 divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
885 lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
898 if (rhsConstVal % divisor == 0 &&
900 expr = quotientTimesDiv.
floorDiv(rhsConst);
901 }
else if (divisor % rhsConstVal == 0 &&
903 expr =
rem % rhsConst;
929 if (operands.empty())
935 constLowerBounds.reserve(operands.size());
936 constUpperBounds.reserve(operands.size());
937 for (
Value operand : operands) {
951 if (
auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
952 lowerBounds.push_back(constExpr.getValue());
953 upperBounds.push_back(constExpr.getValue());
955 lowerBounds.push_back(
957 constLowerBounds, constUpperBounds,
959 upperBounds.push_back(
961 constLowerBounds, constUpperBounds,
968 for (
auto exprEn : llvm::enumerate(map.
getResults())) {
970 unsigned i = exprEn.index();
972 if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
977 if (!upperBounds[i]) {
978 irredundantExprs.push_back(e);
983 if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](
const auto &en) {
984 auto otherLowerBound = en.value();
985 unsigned pos = en.index();
986 if (pos == i || !otherLowerBound)
988 if (*otherLowerBound > *upperBounds[i])
990 if (*otherLowerBound < *upperBounds[i])
995 if (upperBounds[pos] && lowerBounds[i] &&
996 lowerBounds[i] == upperBounds[i] &&
997 otherLowerBound == *upperBounds[pos] && i < pos)
1001 irredundantExprs.push_back(e);
1003 if (!lowerBounds[i]) {
1004 irredundantExprs.push_back(e);
1008 if (!llvm::any_of(llvm::enumerate(upperBounds), [&](
const auto &en) {
1009 auto otherUpperBound = en.value();
1010 unsigned pos = en.index();
1011 if (pos == i || !otherUpperBound)
1013 if (*otherUpperBound < *lowerBounds[i])
1015 if (*otherUpperBound > *lowerBounds[i])
1017 if (lowerBounds[pos] && upperBounds[i] &&
1018 lowerBounds[i] == upperBounds[i] &&
1019 otherUpperBound == lowerBounds[pos] && i < pos)
1023 irredundantExprs.push_back(e);
1037 assert(map.
getNumInputs() == operands.size() &&
"invalid operands for map");
1043 newResults.push_back(expr);
1066 LDBG() <<
"replaceAffineMinBoundingBoxExpression: `" << minOp <<
"`";
1067 AffineMap affineMinMap = minOp.getAffineMap();
1070 for (
unsigned i = 0, e = affineMinMap.
getNumResults(); i < e; ++i) {
1076 minOp.getOperands())))
1084 for (
auto [i, dim] : llvm::enumerate(minOp.getDimOperands())) {
1085 auto it = llvm::find(dims, dim);
1086 if (it == dims.end()) {
1087 unmappedDims.push_back(i);
1093 for (
auto [i, sym] : llvm::enumerate(minOp.getSymbolOperands())) {
1094 auto it = llvm::find(syms, sym);
1095 if (it == syms.end()) {
1096 unmappedSyms.push_back(i);
1109 if (llvm::any_of(unmappedDims,
1110 [&](
unsigned i) {
return expr.isFunctionOfDim(i); }) ||
1111 llvm::any_of(unmappedSyms,
1112 [&](
unsigned i) {
return expr.isFunctionOfSymbol(i); }))
1118 repl[dimOrSym.
ceilDiv(convertedExpr)] = c1;
1120 repl[(dimOrSym + convertedExpr - 1).floorDiv(convertedExpr)] = c1;
1125 return success(*map != initialMap);
1134 AffineExpr e,
const llvm::SmallDenseSet<AffineExpr, 4> &exprsToRemove,
1136 auto binOp = dyn_cast<AffineBinaryOpExpr>(e);
1147 llvm::SmallDenseSet<AffineExpr, 4> ourTracker(exprsToRemove);
1152 if (!ourTracker.erase(thisTerm)) {
1153 toPreserve.push_back(thisTerm);
1157 auto nextBinOp = dyn_cast_if_present<AffineBinaryOpExpr>(nextTerm);
1159 thisTerm = nextTerm;
1162 thisTerm = nextBinOp.getRHS();
1163 nextTerm = nextBinOp.getLHS();
1166 if (!ourTracker.empty())
1171 for (
AffineExpr preserved : llvm::reverse(toPreserve))
1172 newExpr = newExpr + preserved;
1173 replacementsMap.insert({e, newExpr});
1191 AffineDelinearizeIndexOp delinOp,
Value resultToReplace,
AffineMap *map,
1193 if (!delinOp.getDynamicBasis().empty())
1195 if (resultToReplace != delinOp.getMultiIndex().back())
1200 for (
auto [pos, dim] : llvm::enumerate(dims)) {
1201 auto asResult = dyn_cast_if_present<OpResult>(dim);
1204 if (asResult.getOwner() == delinOp.getOperation())
1207 for (
auto [pos, sym] : llvm::enumerate(syms)) {
1208 auto asResult = dyn_cast_if_present<OpResult>(sym);
1211 if (asResult.getOwner() == delinOp.getOperation())
1214 if (llvm::is_contained(resToExpr,
AffineExpr()))
1217 bool isDimReplacement = llvm::all_of(resToExpr, llvm::IsaPred<AffineDimExpr>);
1219 llvm::SmallDenseSet<AffineExpr, 4> expectedExprs;
1222 for (
auto [binding, size] : llvm::zip(
1223 llvm::reverse(resToExpr), llvm::reverse(delinOp.getStaticBasis()))) {
1227 if (resToExpr.size() != delinOp.getStaticBasis().size())
1228 expectedExprs.insert(resToExpr[0] * stride);
1237 if (replacements.empty())
1241 if (isDimReplacement)
1242 dims.push_back(delinOp.getLinearIndex());
1244 syms.push_back(delinOp.getLinearIndex());
1245 *map = origMap.
replace(replacements, dims.size(), syms.size());
1249 if (
auto d = dyn_cast<AffineDimExpr>(e)) {
1250 unsigned pos = d.getPosition();
1252 dims[pos] =
nullptr;
1254 if (
auto s = dyn_cast<AffineSymbolExpr>(e)) {
1255 unsigned pos = s.getPosition();
1257 syms[pos] =
nullptr;
1276 unsigned dimOrSymbolPosition,
1279 bool replaceAffineMin) {
1281 bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1282 unsigned pos = isDimReplacement ? dimOrSymbolPosition
1283 : dimOrSymbolPosition - dims.size();
1284 Value &v = isDimReplacement ? dims[pos] : syms[pos];
1288 if (
auto minOp = v.
getDefiningOp<AffineMinOp>(); minOp && replaceAffineMin) {
1295 if (
auto delinOp = v.
getDefiningOp<affine::AffineDelinearizeIndexOp>()) {
1309 AffineMap composeMap = affineApply.getAffineMap();
1310 assert(composeMap.
getNumResults() == 1 &&
"affine.apply with >1 results");
1312 affineApply.getMapOperands().end());
1326 dims.append(composeDims.begin(), composeDims.end());
1327 syms.append(composeSyms.begin(), composeSyms.end());
1328 *map = map->
replace(toReplace, replacementExpr, dims.size(), syms.size());
1338 bool composeAffineMin =
false) {
1357 bool changed =
false;
1358 for (
unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1371 unsigned nDims = 0, nSyms = 0;
1373 dimReplacements.reserve(dims.size());
1374 symReplacements.reserve(syms.size());
1375 for (
auto *container : {&dims, &syms}) {
1376 bool isDim = (container == &dims);
1377 auto &repls = isDim ? dimReplacements : symReplacements;
1378 for (
const auto &en : llvm::enumerate(*container)) {
1379 Value v = en.value();
1383 "map is function of unexpected expr@pos");
1389 operands->push_back(v);
1402 while (llvm::any_of(*operands, [](
Value v) {
1408 if (composeAffineMin && llvm::any_of(*operands, [](
Value v) {
1418 bool composeAffineMin) {
1423 return AffineApplyOp::create(
b, loc, map, valueOperands);
1429 bool composeAffineMin) {
1434 operands, composeAffineMin);
1441 bool composeAffineMin =
false) {
1447 for (
unsigned i : llvm::seq<unsigned>(0, map.
getNumResults())) {
1455 llvm::append_range(dims,
1457 llvm::append_range(symbols,
1464 operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1471 bool composeAffineMin) {
1472 assert(map.
getNumResults() == 1 &&
"building affine.apply with !=1 result");
1482 AffineApplyOp applyOp =
1487 for (
unsigned i = 0, e = constOperands.size(); i != e; ++i)
1492 if (failed(applyOp->fold(constOperands, foldResults)) ||
1493 foldResults.empty()) {
1495 listener->notifyOperationInserted(applyOp, {});
1496 return applyOp.getResult();
1500 return llvm::getSingleElement(foldResults);
1510 operands, composeAffineMin);
1516 bool composeAffineMin) {
1517 return llvm::map_to_vector(
1518 llvm::seq<unsigned>(0, map.
getNumResults()), [&](
unsigned i) {
1519 return makeComposedFoldedAffineApply(b, loc, map.getSubMap({i}),
1520 operands, composeAffineMin);
1524template <
typename OpTy>
1530 return OpTy::create(
b, loc,
b.getIndexType(), map, valueOperands);
1539template <
typename OpTy>
1555 for (
unsigned i = 0, e = constOperands.size(); i != e; ++i)
1560 if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1561 foldResults.empty()) {
1563 listener->notifyOperationInserted(minMaxOp, {});
1564 return minMaxOp.getResult();
1568 return llvm::getSingleElement(foldResults);
1587template <
class MapOrSet>
1590 if (!mapOrSet || operands->empty())
1593 assert(mapOrSet->getNumInputs() == operands->size() &&
1594 "map/set inputs must match number of operands");
1596 auto *context = mapOrSet->getContext();
1598 resultOperands.reserve(operands->size());
1600 remappedSymbols.reserve(operands->size());
1601 unsigned nextDim = 0;
1602 unsigned nextSym = 0;
1603 unsigned oldNumSyms = mapOrSet->getNumSymbols();
1605 for (
unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1606 if (i < mapOrSet->getNumDims()) {
1610 remappedSymbols.push_back((*operands)[i]);
1613 resultOperands.push_back((*operands)[i]);
1616 resultOperands.push_back((*operands)[i]);
1620 resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1621 *operands = resultOperands;
1622 *mapOrSet = mapOrSet->replaceDimsAndSymbols(
1623 dimRemapping, {}, nextDim, oldNumSyms + nextSym);
1625 assert(mapOrSet->getNumInputs() == operands->size() &&
1626 "map/set inputs must match number of operands");
1635template <
class MapOrSet>
1638 if (!mapOrSet || operands.empty())
1641 unsigned numOperands = operands.size();
1643 assert(mapOrSet.getNumInputs() == numOperands &&
1644 "map/set inputs must match number of operands");
1646 auto *context = mapOrSet.getContext();
1648 resultOperands.reserve(numOperands);
1650 remappedDims.reserve(numOperands);
1652 symOperands.reserve(mapOrSet.getNumSymbols());
1653 unsigned nextSym = 0;
1654 unsigned nextDim = 0;
1655 unsigned oldNumDims = mapOrSet.getNumDims();
1657 resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);
1658 for (
unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
1661 symRemapping[i - oldNumDims] =
1663 remappedDims.push_back(operands[i]);
1666 symOperands.push_back(operands[i]);
1670 append_range(resultOperands, remappedDims);
1671 append_range(resultOperands, symOperands);
1672 operands = resultOperands;
1673 mapOrSet = mapOrSet.replaceDimsAndSymbols(
1674 {}, symRemapping, oldNumDims + nextDim, nextSym);
1676 assert(mapOrSet.getNumInputs() == operands.size() &&
1677 "map/set inputs must match number of operands");
1681template <
class MapOrSet>
1684 static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1685 "Argument must be either of AffineMap or IntegerSet type");
1687 if (!mapOrSet || operands->empty())
1690 assert(mapOrSet->getNumInputs() == operands->size() &&
1691 "map/set inputs must match number of operands");
1697 llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1698 llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1700 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr))
1701 usedDims[dimExpr.getPosition()] =
true;
1702 else if (
auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
1703 usedSyms[symExpr.getPosition()] =
true;
1706 auto *context = mapOrSet->getContext();
1709 resultOperands.reserve(operands->size());
1711 llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1713 unsigned nextDim = 0;
1714 for (
unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1717 auto it = seenDims.find((*operands)[i]);
1718 if (it == seenDims.end()) {
1720 resultOperands.push_back((*operands)[i]);
1721 seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1723 dimRemapping[i] = it->second;
1727 llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1729 unsigned nextSym = 0;
1730 for (
unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1736 IntegerAttr operandCst;
1737 if (
matchPattern((*operands)[i + mapOrSet->getNumDims()],
1744 auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1745 if (it == seenSymbols.end()) {
1747 resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1748 seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1751 symRemapping[i] = it->second;
1754 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1756 *operands = resultOperands;
1773template <
typename AffineOpTy>
1782 LogicalResult matchAndRewrite(AffineOpTy affineOp,
1785 llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1786 AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1787 AffineVectorStoreOp, AffineVectorLoadOp>::value,
1788 "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1790 auto map = affineOp.getAffineMap();
1792 auto oldOperands = affineOp.getMapOperands();
1797 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1798 resultOperands.begin()))
1801 replaceAffineOp(rewriter, affineOp, map, resultOperands);
1809void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1816void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1820 prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1821 prefetch.getLocalityHint(), prefetch.getIsDataCache());
1824void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1828 store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1831void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1835 vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1839void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1843 vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1848template <
typename AffineOpTy>
1849void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1858 results.
add<SimplifyAffineOp<AffineApplyOp>>(context);
1873 result.addOperands(srcMemRef);
1874 result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
1875 result.addOperands(srcIndices);
1876 result.addOperands(destMemRef);
1877 result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
1878 result.addOperands(destIndices);
1879 result.addOperands(tagMemRef);
1880 result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1881 result.addOperands(tagIndices);
1882 result.addOperands(numElements);
1884 result.addOperands({stride, elementsPerStride});
1889 p <<
" " << getSrcMemRef() <<
'[';
1891 p <<
"], " << getDstMemRef() <<
'[';
1893 p <<
"], " << getTagMemRef() <<
'[';
1897 p <<
", " << getStride();
1898 p <<
", " << getNumElementsPerStride();
1900 p <<
" : " << getSrcMemRefType() <<
", " << getDstMemRefType() <<
", "
1901 << getTagMemRefType();
1910ParseResult AffineDmaStartOp::parse(
OpAsmParser &parser,
1913 AffineMapAttr srcMapAttr;
1916 AffineMapAttr dstMapAttr;
1919 AffineMapAttr tagMapAttr;
1934 getSrcMapAttrStrName(),
1938 getDstMapAttrStrName(),
1942 getTagMapAttrStrName(),
1951 if (!strideInfo.empty() && strideInfo.size() != 2) {
1953 "expected two stride related operands");
1955 bool isStrided = strideInfo.size() == 2;
1960 if (types.size() != 3)
1978 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1979 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1980 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1982 "memref operand count not equal to map.numInputs");
1986LogicalResult AffineDmaStartOp::verify() {
1987 if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).
getType()))
1988 return emitOpError(
"expected DMA source to be of memref type");
1989 if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).
getType()))
1990 return emitOpError(
"expected DMA destination to be of memref type");
1991 if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).
getType()))
1992 return emitOpError(
"expected DMA tag to be of memref type");
1994 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1995 getDstMap().getNumInputs() +
1996 getTagMap().getNumInputs();
1997 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1998 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1999 return emitOpError(
"incorrect number of operands");
2003 for (
auto idx : getSrcIndices()) {
2004 if (!idx.getType().isIndex())
2005 return emitOpError(
"src index to dma_start must have 'index' type");
2008 "src index must be a valid dimension or symbol identifier");
2010 for (
auto idx : getDstIndices()) {
2011 if (!idx.getType().isIndex())
2012 return emitOpError(
"dst index to dma_start must have 'index' type");
2015 "dst index must be a valid dimension or symbol identifier");
2017 for (
auto idx : getTagIndices()) {
2018 if (!idx.getType().isIndex())
2019 return emitOpError(
"tag index to dma_start must have 'index' type");
2022 "tag index must be a valid dimension or symbol identifier");
2027LogicalResult AffineDmaStartOp::fold(FoldAdaptor adaptor,
2033void AffineDmaStartOp::getEffects(
2052 result.addOperands(tagMemRef);
2053 result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
2054 result.addOperands(tagIndices);
2055 result.addOperands(numElements);
2059 p <<
" " << getTagMemRef() <<
'[';
2064 p <<
" : " << getTagMemRef().getType();
2072ParseResult AffineDmaWaitOp::parse(
OpAsmParser &parser,
2075 AffineMapAttr tagMapAttr;
2084 getTagMapAttrStrName(),
2093 if (!llvm::isa<MemRefType>(type))
2095 "expected tag to be of memref type");
2097 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
2099 "tag memref operand count != to map.numInputs");
2103LogicalResult AffineDmaWaitOp::verify() {
2104 if (!llvm::isa<MemRefType>(getOperand(0).
getType()))
2105 return emitOpError(
"expected DMA tag to be of memref type");
2107 for (
auto idx : getTagIndices()) {
2108 if (!idx.getType().isIndex())
2109 return emitOpError(
"index to dma_wait must have 'index' type");
2112 "index must be a valid dimension or symbol identifier");
2117LogicalResult AffineDmaWaitOp::fold(FoldAdaptor adaptor,
2123void AffineDmaWaitOp::getEffects(
2139 ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
2140 assert(((!lbMap && lbOperands.empty()) ||
2142 "lower bound operand count does not match the affine map");
2143 assert(((!ubMap && ubOperands.empty()) ||
2145 "upper bound operand count does not match the affine map");
2146 assert(step > 0 &&
"step has to be a positive integer constant");
2148 OpBuilder::InsertionGuard guard(builder);
2152 getOperandSegmentSizeAttr(),
2154 static_cast<int32_t>(ubOperands.size()),
2155 static_cast<int32_t>(iterArgs.size())}));
2157 for (Value val : iterArgs)
2158 result.addTypes(val.getType());
2165 result.addAttribute(getLowerBoundMapAttrName(
result.name),
2166 AffineMapAttr::get(lbMap));
2167 result.addOperands(lbOperands);
2170 result.addAttribute(getUpperBoundMapAttrName(
result.name),
2171 AffineMapAttr::get(ubMap));
2172 result.addOperands(ubOperands);
2174 result.addOperands(iterArgs);
2177 Region *bodyRegion =
result.addRegion();
2179 Value inductionVar =
2181 for (Value val : iterArgs)
2182 bodyBlock->
addArgument(val.getType(), val.getLoc());
2187 if (iterArgs.empty() && !bodyBuilder) {
2188 ensureTerminator(*bodyRegion, builder,
result.location);
2189 }
else if (bodyBuilder) {
2190 OpBuilder::InsertionGuard guard(builder);
2192 bodyBuilder(builder,
result.location, inductionVar,
2199 BodyBuilderFn bodyBuilder) {
2202 return build(builder,
result, {}, lbMap, {}, ubMap, step, iterArgs,
2206LogicalResult AffineForOp::verifyRegions() {
2208 if (getStepAsInt() <= 0)
2209 return emitOpError(
"expected step to be a positive integer, got ")
2214 auto *body = getBody();
2215 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
2216 return emitOpError(
"expected body to have a single index argument for the "
2217 "induction variable");
2221 if (getLowerBoundMap().getNumInputs() > 0)
2223 getLowerBoundMap().getNumDims())))
2226 if (getUpperBoundMap().getNumInputs() > 0)
2228 getUpperBoundMap().getNumDims())))
2230 if (getLowerBoundMap().getNumResults() < 1)
2231 return emitOpError(
"expected lower bound map to have at least one result");
2232 if (getUpperBoundMap().getNumResults() < 1)
2233 return emitOpError(
"expected upper bound map to have at least one result");
2235 unsigned opNumResults = getNumResults();
2236 if (opNumResults == 0)
2242 if (getNumIterOperands() != opNumResults)
2244 "mismatch between the number of loop-carried values and results");
2245 if (getNumRegionIterArgs() != opNumResults)
2247 "mismatch between the number of basic block args and results");
2257 bool failedToParsedMinMax =
2261 auto boundAttrStrName =
2262 isLower ? AffineForOp::getLowerBoundMapAttrName(
result.name)
2263 : AffineForOp::getUpperBoundMapAttrName(
result.name);
2270 if (!boundOpInfos.empty()) {
2272 if (boundOpInfos.size() > 1)
2274 "expected only one loop bound operand");
2286 result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
2299 if (
auto affineMapAttr = dyn_cast<AffineMapAttr>(boundAttr)) {
2300 unsigned currentNumOperands =
result.operands.size();
2305 auto map = affineMapAttr.getValue();
2306 if (map.getNumDims() != numDims)
2309 "dim operand count and affine map dim count must match");
2311 unsigned numDimAndSymbolOperands =
2312 result.operands.size() - currentNumOperands;
2313 if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
2316 "symbol operand count and affine map symbol count must match");
2320 if (map.getNumResults() > 1 && failedToParsedMinMax) {
2322 return p.
emitError(attrLoc,
"lower loop bound affine map with "
2323 "multiple results requires 'max' prefix");
2325 return p.
emitError(attrLoc,
"upper loop bound affine map with multiple "
2326 "results requires 'min' prefix");
2332 if (
auto integerAttr = dyn_cast<IntegerAttr>(boundAttr)) {
2333 result.attributes.pop_back();
2342 "expected valid affine map representation for loop bounds");
2347 OpAsmParser::Argument inductionVariable;
2354 int64_t numOperands =
result.operands.size();
2357 int64_t numLbOperands =
result.operands.size() - numOperands;
2360 numOperands =
result.operands.size();
2363 int64_t numUbOperands =
result.operands.size() - numOperands;
2368 getStepAttrName(
result.name),
2372 IntegerAttr stepAttr;
2374 getStepAttrName(
result.name).data(),
2378 if (!stepAttr.getValue().isStrictlyPositive())
2381 "expected step to be representable as a positive signed integer");
2385 SmallVector<OpAsmParser::Argument, 4> regionArgs;
2386 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2389 regionArgs.push_back(inductionVariable);
2397 for (
auto argOperandType :
2398 llvm::zip(llvm::drop_begin(regionArgs), operands,
result.types)) {
2399 Type type = std::get<2>(argOperandType);
2400 std::get<0>(argOperandType).type = type;
2408 getOperandSegmentSizeAttr(),
2410 static_cast<int32_t>(numUbOperands),
2411 static_cast<int32_t>(operands.size())}));
2414 Region *body =
result.addRegion();
2415 if (regionArgs.size() !=
result.types.size() + 1)
2418 "mismatch between the number of loop-carried values and results");
2422 AffineForOp::ensureTerminator(*body, builder,
result.location);
2444 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2445 p << constExpr.getValue();
2453 if (isa<AffineSymbolExpr>(expr)) {
2469unsigned AffineForOp::getNumIterOperands() {
2470 AffineMap lbMap = getLowerBoundMapAttr().getValue();
2471 AffineMap ubMap = getUpperBoundMapAttr().getValue();
2476std::optional<MutableArrayRef<OpOperand>>
2477AffineForOp::getYieldedValuesMutable() {
2478 return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2490 if (getStepAsInt() != 1)
2491 p <<
" step " << getStepAsInt();
2493 bool printBlockTerminators =
false;
2494 if (getNumIterOperands() > 0) {
2496 auto regionArgs = getRegionIterArgs();
2497 auto operands = getInits();
2499 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](
auto it) {
2500 p << std::get<0>(it) <<
" = " << std::get<1>(it);
2502 p <<
") -> (" << getResultTypes() <<
")";
2503 printBlockTerminators =
true;
2508 printBlockTerminators);
2510 (*this)->getAttrs(),
2511 {getLowerBoundMapAttrName(getOperation()->getName()),
2512 getUpperBoundMapAttrName(getOperation()->getName()),
2513 getStepAttrName(getOperation()->getName()),
2514 getOperandSegmentSizeAttr()});
2519 auto foldLowerOrUpperBound = [&forOp](
bool lower) {
2523 auto boundOperands =
2524 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2525 for (
auto operand : boundOperands) {
2528 operandConstants.push_back(operandCst);
2532 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2534 "bound maps should have at least one result");
2536 if (failed(boundMap.
constantFold(operandConstants, foldedResults)))
2540 assert(!foldedResults.empty() &&
"bounds should have at least one result");
2541 auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2542 for (
unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2543 auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2544 maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2545 : llvm::APIntOps::smin(maxOrMin, foldedResult);
2547 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2548 : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2553 bool folded =
false;
2554 if (!forOp.hasConstantLowerBound())
2555 folded |= succeeded(foldLowerOrUpperBound(
true));
2558 if (!forOp.hasConstantUpperBound())
2559 folded |= succeeded(foldLowerOrUpperBound(
false));
2565 int64_t step = forOp.getStepAsInt();
2566 if (!forOp.hasConstantBounds() || step <= 0)
2567 return std::nullopt;
2568 int64_t lb = forOp.getConstantLowerBound();
2569 int64_t ub = forOp.getConstantUpperBound();
2570 return ub - lb <= 0 ? 0 : (
ub - lb + step - 1) / step;
2575 if (!llvm::hasSingleElement(*forOp.getBody()))
2577 if (forOp.getNumResults() == 0)
2580 if (tripCount == 0) {
2583 return forOp.getInits();
2586 auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2587 auto iterArgs = forOp.getRegionIterArgs();
2588 bool hasValDefinedOutsideLoop =
false;
2589 bool iterArgsNotInOrder =
false;
2590 for (
unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2591 Value val = yieldOp.getOperand(i);
2595 if (val == forOp.getInductionVar())
2597 if (iterArgIt == iterArgs.end()) {
2599 assert(forOp.isDefinedOutsideOfLoop(val) &&
2600 "must be defined outside of the loop");
2601 hasValDefinedOutsideLoop =
true;
2602 replacements.push_back(val);
2604 unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2606 iterArgsNotInOrder =
true;
2607 replacements.push_back(forOp.getInits()[pos]);
2612 if (!tripCount.has_value() &&
2613 (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2617 if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2619 return llvm::to_vector_of<OpFoldResult>(replacements);
2627 auto lbMap = forOp.getLowerBoundMap();
2628 auto ubMap = forOp.getUpperBoundMap();
2629 auto prevLbMap = lbMap;
2630 auto prevUbMap = ubMap;
2643 if (lbMap == prevLbMap && ubMap == prevUbMap)
2646 if (lbMap != prevLbMap)
2647 forOp.setLowerBound(lbOperands, lbMap);
2648 if (ubMap != prevUbMap)
2649 forOp.setUpperBound(ubOperands, ubMap);
2658LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2668 results.assign(getInits().begin(), getInits().end());
2672 if (!foldResults.empty()) {
2673 results.assign(foldResults);
2682 "invalid region point");
2689void AffineForOp::getSuccessorRegions(
2694 "expected loop region");
2700 if (tripCount.has_value()) {
2704 if (tripCount == 1) {
2705 regions.push_back(RegionSuccessor(getOperation()));
2711 if (tripCount.value() > 0) {
2712 regions.push_back(RegionSuccessor(&getRegion()));
2715 if (tripCount.value() == 0) {
2716 regions.push_back(RegionSuccessor(getOperation()));
2724 regions.push_back(RegionSuccessor(&getRegion()));
2725 regions.push_back(RegionSuccessor(getOperation()));
2730 return getResults();
2731 return getRegionIterArgs();
2744 assert(map.
getNumResults() >= 1 &&
"bound map has at least one result");
2745 getLowerBoundOperandsMutable().assign(lbOperands);
2746 setLowerBoundMap(map);
2751 assert(map.
getNumResults() >= 1 &&
"bound map has at least one result");
2752 getUpperBoundOperandsMutable().assign(ubOperands);
2753 setUpperBoundMap(map);
2756bool AffineForOp::hasConstantLowerBound() {
2757 return getLowerBoundMap().isSingleConstant();
2760bool AffineForOp::hasConstantUpperBound() {
2761 return getUpperBoundMap().isSingleConstant();
2764int64_t AffineForOp::getConstantLowerBound() {
2765 return getLowerBoundMap().getSingleConstantResult();
2768int64_t AffineForOp::getConstantUpperBound() {
2769 return getUpperBoundMap().getSingleConstantResult();
2772void AffineForOp::setConstantLowerBound(
int64_t value) {
2776void AffineForOp::setConstantUpperBound(
int64_t value) {
2780AffineForOp::operand_range AffineForOp::getControlOperands() {
2785bool AffineForOp::matchingBoundOperandList() {
2786 auto lbMap = getLowerBoundMap();
2787 auto ubMap = getUpperBoundMap();
2793 for (
unsigned i = 0, e = lbMap.
getNumInputs(); i < e; i++) {
2795 if (getOperand(i) != getOperand(numOperands + i))
2803std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2804 return SmallVector<Value>{getInductionVar()};
2807std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2808 if (!hasConstantLowerBound())
2809 return std::nullopt;
2811 return SmallVector<OpFoldResult>{
2812 OpFoldResult(
b.getI64IntegerAttr(getConstantLowerBound()))};
2815std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2817 return SmallVector<OpFoldResult>{
2818 OpFoldResult(
b.getI64IntegerAttr(getStepAsInt()))};
2821std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2822 if (!hasConstantUpperBound())
2825 return SmallVector<OpFoldResult>{
2826 OpFoldResult(
b.getI64IntegerAttr(getConstantUpperBound()))};
2829std::optional<APInt> AffineForOp::getStaticTripCount() {
2831 int64_t step = getStepAsInt();
2833 return std::nullopt;
2835 if (hasConstantBounds()) {
2836 int64_t lb = getConstantLowerBound();
2837 int64_t ub = getConstantUpperBound();
2838 int64_t loopSpan = ub - lb;
2841 return APInt(64, llvm::divideCeilSigned(loopSpan, step));
2844 auto lbMap = getLowerBoundMap();
2845 auto ubMap = getUpperBoundMap();
2847 return std::nullopt;
2854 SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
2857 lbSplatExpr, context);
2860 AffineValueMap tripCountValueMap;
2864 std::optional<uint64_t> tripCount;
2865 for (
unsigned i = 0, e = tripCountValueMap.
getNumResults(); i < e; ++i) {
2867 if (
auto constExpr = llvm::dyn_cast<AffineConstantExpr>(expr)) {
2868 uint64_t value = constExpr.getValue();
2869 if (tripCount.has_value())
2870 tripCount = std::min(*tripCount, value);
2874 return std::nullopt;
2878 if (tripCount.has_value())
2879 return APInt(64, *tripCount);
2881 return std::nullopt;
2884FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2886 bool replaceInitOperandUsesInLoop,
2889 OpBuilder::InsertionGuard g(rewriter);
2891 auto inits = llvm::to_vector(getInits());
2892 inits.append(newInitOperands.begin(), newInitOperands.end());
2893 AffineForOp newLoop = AffineForOp::create(
2898 auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2899 ArrayRef<BlockArgument> newIterArgs =
2900 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2902 OpBuilder::InsertionGuard g(rewriter);
2904 SmallVector<Value> newYieldedValues =
2905 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2906 assert(newInitOperands.size() == newYieldedValues.size() &&
2907 "expected as many new yield values as new iter operands");
2909 yieldOp.getOperandsMutable().append(newYieldedValues);
2914 rewriter.
mergeBlocks(getBody(), newLoop.getBody(),
2915 newLoop.getBody()->getArguments().take_front(
2916 getBody()->getNumArguments()));
2918 if (replaceInitOperandUsesInLoop) {
2921 for (
auto it : llvm::zip(newInitOperands, newIterArgs)) {
2923 [&](OpOperand &use) {
2925 return newLoop->isProperAncestor(user);
2932 newLoop->getResults().take_front(getNumResults()));
2933 return cast<LoopLikeOpInterface>(newLoop.getOperation());
2961 auto ivArg = dyn_cast<BlockArgument>(val);
2962 if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2963 return AffineForOp();
2965 ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2967 return forOp.getInductionVar() == val ? forOp : AffineForOp();
2968 return AffineForOp();
2972 auto ivArg = dyn_cast<BlockArgument>(val);
2973 if (!ivArg || !ivArg.getOwner())
2976 auto parallelOp = dyn_cast_if_present<AffineParallelOp>(containingOp);
2977 if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2986 ivs->reserve(forInsts.size());
2987 for (
auto forInst : forInsts)
2988 ivs->push_back(forInst.getInductionVar());
2993 ivs.reserve(affineOps.size());
2996 if (
auto forOp = dyn_cast<AffineForOp>(op))
2997 ivs.push_back(forOp.getInductionVar());
2998 else if (
auto parallelOp = dyn_cast<AffineParallelOp>(op))
2999 for (
size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
3000 ivs.push_back(parallelOp.getBody()->getArgument(i));
3006template <
typename BoundListTy,
typename LoopCreatorTy>
3011 LoopCreatorTy &&loopCreatorFn) {
3012 assert(lbs.size() == ubs.size() &&
"Mismatch in number of arguments");
3013 assert(lbs.size() == steps.size() &&
"Mismatch in number of arguments");
3025 ivs.reserve(lbs.size());
3026 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
3032 if (i == e - 1 && bodyBuilderFn) {
3034 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
3036 AffineYieldOp::create(nestedBuilder, nestedLoc);
3041 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
3050 AffineForOp::BodyBuilderFn bodyBuilderFn) {
3051 return AffineForOp::create(builder, loc, lb,
ub, step,
3059 AffineForOp::BodyBuilderFn bodyBuilderFn) {
3062 if (lbConst && ubConst)
3064 ubConst.value(), step, bodyBuilderFn);
3095 LogicalResult matchAndRewrite(AffineIfOp ifOp,
3097 if (ifOp.getElseRegion().empty() ||
3098 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
3111 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
3113 LogicalResult matchAndRewrite(AffineIfOp op,
3114 PatternRewriter &rewriter)
const override {
3116 auto isTriviallyFalse = [](IntegerSet iSet) {
3117 return iSet.isEmptyIntegerSet();
3120 auto isTriviallyTrue = [](IntegerSet iSet) {
3121 return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
3122 iSet.getConstraint(0) == 0);
3125 IntegerSet affineIfConditions = op.getIntegerSet();
3127 if (isTriviallyFalse(affineIfConditions)) {
3131 if (op.getNumResults() == 0 && !op.hasElse()) {
3137 blockToMove = op.getElseBlock();
3138 }
else if (isTriviallyTrue(affineIfConditions)) {
3139 blockToMove = op.getThenBlock();
3143 Operation *blockToMoveTerminator = blockToMove->
getTerminator();
3157 rewriter.
eraseOp(blockToMoveTerminator);
3165void AffineIfOp::getSuccessorRegions(
3173 if (getElseRegion().empty()) {
3188 return getResults();
3189 if (successor == &getThenRegion())
3190 return getThenRegion().getArguments();
3191 if (successor == &getElseRegion())
3192 return getElseRegion().getArguments();
3193 llvm_unreachable(
"invalid region successor");
3196LogicalResult AffineIfOp::verify() {
3199 auto conditionAttr =
3200 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
3202 return emitOpError(
"requires an integer set attribute named 'condition'");
3205 IntegerSet condition = conditionAttr.getValue();
3207 return emitOpError(
"operand count and condition integer set dimension and "
3208 "symbol count must match");
3220 IntegerSetAttr conditionAttr;
3223 AffineIfOp::getConditionAttrStrName(),
3229 auto set = conditionAttr.getValue();
3230 if (set.getNumDims() != numDims)
3233 "dim operand count and integer set dim count must match");
3234 if (numDims + set.getNumSymbols() !=
result.operands.size())
3237 "symbol operand count and integer set symbol count must match");
3244 result.regions.reserve(2);
3251 AffineIfOp::ensureTerminator(*thenRegion, parser.
getBuilder(),
3258 AffineIfOp::ensureTerminator(*elseRegion, parser.
getBuilder(),
3270 auto conditionAttr =
3271 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
3272 p <<
" " << conditionAttr;
3274 conditionAttr.getValue().getNumDims(), p);
3281 auto &elseRegion = this->getElseRegion();
3282 if (!elseRegion.
empty()) {
3291 getConditionAttrStrName());
3296 ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
3300void AffineIfOp::setIntegerSet(
IntegerSet newSet) {
3301 (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
3306 (*this)->setOperands(operands);
3311 bool withElseRegion) {
3312 assert(resultTypes.empty() || withElseRegion);
3315 result.addTypes(resultTypes);
3316 result.addOperands(args);
3317 result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
3321 if (resultTypes.empty())
3322 AffineIfOp::ensureTerminator(*thenRegion, builder,
result.location);
3325 if (withElseRegion) {
3327 if (resultTypes.empty())
3328 AffineIfOp::ensureTerminator(*elseRegion, builder,
result.location);
3334 AffineIfOp::build(builder,
result, {}, set, args,
3343 bool composeAffineMin =
false) {
3350 if (llvm::none_of(operands,
3361 auto set = getIntegerSet();
3367 if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
3370 setConditional(set, operands);
3376 results.
add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
3385 assert(operands.size() == 1 + map.
getNumInputs() &&
"inconsistent operands");
3386 result.addOperands(operands);
3388 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3389 auto memrefType = llvm::cast<MemRefType>(operands[0].
getType());
3390 result.types.push_back(memrefType.getElementType());
3395 assert(map.
getNumInputs() == mapOperands.size() &&
"inconsistent index info");
3397 result.addOperands(mapOperands);
3398 auto memrefType = llvm::cast<MemRefType>(
memref.getType());
3399 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3400 result.types.push_back(memrefType.getElementType());
3405 auto memrefType = llvm::cast<MemRefType>(
memref.getType());
3406 int64_t rank = memrefType.getRank();
3420 AffineMapAttr mapAttr;
3425 AffineLoadOp::getMapAttrStrName(),
3436 if (AffineMapAttr mapAttr =
3437 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3441 {getMapAttrStrName()});
3447template <
typename AffineMemOpTy>
3451 MemRefType memrefType,
unsigned numIndexOperands) {
3454 return op->emitOpError(
"affine map num results must equal memref rank");
3456 return op->emitOpError(
"expects as many subscripts as affine map inputs");
3458 for (
auto idx : mapOperands) {
3459 if (!idx.getType().isIndex())
3460 return op->emitOpError(
"index to load must have 'index' type");
3468LogicalResult AffineLoadOp::verify() {
3470 if (
getType() != memrefType.getElementType())
3471 return emitOpError(
"result type must match element type of memref");
3474 *
this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3475 getMapOperands(), memrefType,
3476 getNumOperands() - 1)))
3484 results.
add<SimplifyAffineOp<AffineLoadOp>>(context);
3493 auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3498 getGlobalOp, getGlobalOp.getNameAttr());
3504 dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3508 if (
auto splatAttr = dyn_cast<SplatElementsAttr>(cstAttr))
3509 return splatAttr.getSplatValue<
Attribute>();
3511 if (!getAffineMap().isConstant())
3514 llvm::map_to_vector<4>(getAffineMap().getConstantResults(),
3515 [](
int64_t v) -> uint64_t {
return v; });
3526 assert(map.
getNumInputs() == mapOperands.size() &&
"inconsistent index info");
3527 result.addOperands(valueToStore);
3529 result.addOperands(mapOperands);
3530 result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
3537 auto memrefType = llvm::cast<MemRefType>(
memref.getType());
3538 int64_t rank = memrefType.getRank();
3552 AffineMapAttr mapAttr;
3557 mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3568 p <<
" " << getValueToStore();
3570 if (AffineMapAttr mapAttr =
3571 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3575 {getMapAttrStrName()});
3579LogicalResult AffineStoreOp::verify() {
3582 if (getValueToStore().
getType() != memrefType.getElementType())
3584 "value to store must have the same type as memref element type");
3587 *
this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3588 getMapOperands(), memrefType,
3589 getNumOperands() - 2)))
3597 results.
add<SimplifyAffineOp<AffineStoreOp>>(context);
3600LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3610template <
typename T>
3613 if (op.getNumOperands() !=
3614 op.getMap().getNumDims() + op.getMap().getNumSymbols())
3615 return op.emitOpError(
3616 "operand count and affine map dimension and symbol count must match");
3618 if (op.getMap().getNumResults() == 0)
3619 return op.emitOpError(
"affine map expect at least one result");
3623template <
typename T>
3625 p <<
' ' << op->getAttr(T::getMapAttrStrName());
3626 auto operands = op.getOperands();
3627 unsigned numDims = op.getMap().getNumDims();
3628 p <<
'(' << operands.take_front(numDims) <<
')';
3630 if (operands.size() != numDims)
3631 p <<
'[' << operands.drop_front(numDims) <<
']';
3633 {T::getMapAttrStrName()});
3636template <
typename T>
3643 AffineMapAttr mapAttr;
3659template <
typename T>
3661 static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3662 "expected affine min or max op");
3668 auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3670 if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3671 return op.getOperand(0);
3674 if (results.empty()) {
3676 if (foldedMap == op.getMap())
3678 op->setAttr(
"map", AffineMapAttr::get(foldedMap));
3679 return op.getResult();
3683 auto resultIt = std::is_same<T, AffineMinOp>::value
3684 ? llvm::min_element(results)
3685 : llvm::max_element(results);
3686 if (resultIt == results.end())
3688 return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
3692template <
typename T>
3698 AffineMap oldMap = affineOp.getAffineMap();
3704 if (!llvm::is_contained(newExprs, expr))
3705 newExprs.push_back(expr);
3735template <
typename T>
3741 AffineMap oldMap = affineOp.getAffineMap();
3743 affineOp.getMapOperands().take_front(oldMap.
getNumDims());
3745 affineOp.getMapOperands().take_back(oldMap.
getNumSymbols());
3747 auto newDimOperands = llvm::to_vector<8>(dimOperands);
3748 auto newSymOperands = llvm::to_vector<8>(symOperands);
3756 if (
auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
3757 Value symValue = symOperands[symExpr.getPosition()];
3759 producerOps.push_back(producerOp);
3762 }
else if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
3763 Value dimValue = dimOperands[dimExpr.getPosition()];
3765 producerOps.push_back(producerOp);
3772 newExprs.push_back(expr);
3775 if (producerOps.empty())
3782 for (T producerOp : producerOps) {
3783 AffineMap producerMap = producerOp.getAffineMap();
3784 unsigned numProducerDims = producerMap.
getNumDims();
3789 producerOp.getMapOperands().take_front(numProducerDims);
3791 producerOp.getMapOperands().take_back(numProducerSyms);
3792 newDimOperands.append(dimValues.begin(), dimValues.end());
3793 newSymOperands.append(symValues.begin(), symValues.end());
3797 newExprs.push_back(expr.
shiftDims(numProducerDims, numUsedDims)
3801 numUsedDims += numProducerDims;
3802 numUsedSyms += numProducerSyms;
3808 llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3827 if (!resultExpr.isPureAffine())
3832 if (failed(flattenResult))
3845 if (llvm::is_sorted(flattenedExprs))
3850 llvm::to_vector(llvm::seq<unsigned>(0, map.
getNumResults()));
3851 llvm::sort(resultPermutation, [&](
unsigned lhs,
unsigned rhs) {
3852 return flattenedExprs[
lhs] < flattenedExprs[
rhs];
3855 for (
unsigned idx : resultPermutation)
3876template <
typename T>
3882 AffineMap map = affineOp.getAffineMap();
3890template <
typename T>
3896 if (affineOp.getMap().getNumResults() != 1)
3899 affineOp.getOperands());
3967ParseResult AffinePrefetchOp::parse(
OpAsmParser &parser,
3974 IntegerAttr hintInfo;
3976 StringRef readOrWrite, cacheType;
3978 AffineMapAttr mapAttr;
3982 AffinePrefetchOp::getMapAttrStrName(),
3988 AffinePrefetchOp::getLocalityHintAttrStrName(),
3998 if (readOrWrite !=
"read" && readOrWrite !=
"write")
4000 "rw specifier has to be 'read' or 'write'");
4001 result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
4004 if (cacheType !=
"data" && cacheType !=
"instr")
4006 "cache type has to be 'data' or 'instr'");
4008 result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
4015 p <<
" " << getMemref() <<
'[';
4016 AffineMapAttr mapAttr =
4017 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
4020 p <<
']' <<
", " << (getIsWrite() ?
"write" :
"read") <<
", " <<
"locality<"
4021 << getLocalityHint() <<
">, " << (getIsDataCache() ?
"data" :
"instr");
4023 (*this)->getAttrs(),
4024 {getMapAttrStrName(), getLocalityHintAttrStrName(),
4025 getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
4029LogicalResult AffinePrefetchOp::verify() {
4030 auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
4034 return emitOpError(
"affine.prefetch affine map num results must equal"
4039 if (getNumOperands() != 1)
4044 for (
auto idx : getMapOperands()) {
4047 "index must be a valid dimension or symbol identifier");
4055 results.
add<SimplifyAffineOp<AffinePrefetchOp>>(context);
4058LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
4073 auto ubs = llvm::map_to_vector<4>(ranges, [&](
int64_t value) {
4077 build(builder,
result, resultTypes, reductions, lbs, {}, ubs,
4087 assert(llvm::all_of(lbMaps,
4089 return m.
getNumDims() == lbMaps[0].getNumDims() &&
4092 "expected all lower bounds maps to have the same number of dimensions "
4094 assert(llvm::all_of(ubMaps,
4096 return m.
getNumDims() == ubMaps[0].getNumDims() &&
4099 "expected all upper bounds maps to have the same number of dimensions "
4101 assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
4102 "expected lower bound maps to have as many inputs as lower bound "
4104 assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
4105 "expected upper bound maps to have as many inputs as upper bound "
4109 result.addTypes(resultTypes);
4113 for (arith::AtomicRMWKind reduction : reductions)
4114 reductionAttrs.push_back(
4116 result.addAttribute(getReductionsAttrStrName(),
4126 groups.reserve(groups.size() + maps.size());
4127 exprs.reserve(maps.size());
4132 return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
4138 AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
4139 AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
4140 result.addAttribute(getLowerBoundsMapAttrStrName(),
4141 AffineMapAttr::get(lbMap));
4142 result.addAttribute(getLowerBoundsGroupsAttrStrName(),
4144 result.addAttribute(getUpperBoundsMapAttrStrName(),
4145 AffineMapAttr::get(ubMap));
4146 result.addAttribute(getUpperBoundsGroupsAttrStrName(),
4149 result.addOperands(lbArgs);
4150 result.addOperands(ubArgs);
4153 auto *bodyRegion =
result.addRegion();
4157 for (
unsigned i = 0, e = steps.size(); i < e; ++i)
4159 if (resultTypes.empty())
4160 ensureTerminator(*bodyRegion, builder,
result.location);
4164 return {&getRegion()};
4167unsigned AffineParallelOp::getNumDims() {
return getSteps().size(); }
4169AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
4170 return getOperands().take_front(getLowerBoundsMap().getNumInputs());
4173AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
4174 return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
4177AffineMap AffineParallelOp::getLowerBoundMap(
unsigned pos) {
4178 auto values = getLowerBoundsGroups().getValues<int32_t>();
4180 for (
unsigned i = 0; i < pos; ++i)
4182 return getLowerBoundsMap().getSliceMap(start, values[pos]);
4185AffineMap AffineParallelOp::getUpperBoundMap(
unsigned pos) {
4186 auto values = getUpperBoundsGroups().getValues<int32_t>();
4188 for (
unsigned i = 0; i < pos; ++i)
4190 return getUpperBoundsMap().getSliceMap(start, values[pos]);
4194 return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
4198 return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
4201std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
4202 if (hasMinMaxBounds())
4203 return std::nullopt;
4211 for (
unsigned i = 0, e = rangesValueMap.
getNumResults(); i < e; ++i) {
4212 auto expr = rangesValueMap.
getResult(i);
4213 auto cst = dyn_cast<AffineConstantExpr>(expr);
4215 return std::nullopt;
4216 out.push_back(cst.getValue());
4221Block *AffineParallelOp::getBody() {
return &getRegion().
front(); }
4223OpBuilder AffineParallelOp::getBodyBuilder() {
4224 return OpBuilder(getBody(), std::prev(getBody()->end()));
4229 "operands to map must match number of inputs");
4231 auto ubOperands = getUpperBoundsOperands();
4234 newOperands.append(ubOperands.begin(), ubOperands.end());
4235 (*this)->setOperands(newOperands);
4237 setLowerBoundsMapAttr(AffineMapAttr::get(map));
4242 "operands to map must match number of inputs");
4245 newOperands.append(ubOperands.begin(), ubOperands.end());
4246 (*this)->setOperands(newOperands);
4248 setUpperBoundsMapAttr(AffineMapAttr::get(map));
4257 arith::AtomicRMWKind op) {
4259 case arith::AtomicRMWKind::addf:
4260 return isa<FloatType>(resultType);
4261 case arith::AtomicRMWKind::addi:
4262 return isa<IntegerType>(resultType);
4263 case arith::AtomicRMWKind::assign:
4265 case arith::AtomicRMWKind::mulf:
4266 return isa<FloatType>(resultType);
4267 case arith::AtomicRMWKind::muli:
4268 return isa<IntegerType>(resultType);
4269 case arith::AtomicRMWKind::maximumf:
4270 case arith::AtomicRMWKind::maxnumf:
4271 case arith::AtomicRMWKind::minimumf:
4272 case arith::AtomicRMWKind::minnumf:
4273 return isa<FloatType>(resultType);
4274 case arith::AtomicRMWKind::maxs: {
4275 auto intType = dyn_cast<IntegerType>(resultType);
4276 return intType && intType.isSigned();
4278 case arith::AtomicRMWKind::mins: {
4279 auto intType = dyn_cast<IntegerType>(resultType);
4280 return intType && intType.isSigned();
4282 case arith::AtomicRMWKind::maxu: {
4283 auto intType = dyn_cast<IntegerType>(resultType);
4284 return intType && intType.isUnsigned();
4286 case arith::AtomicRMWKind::minu: {
4287 auto intType = dyn_cast<IntegerType>(resultType);
4288 return intType && intType.isUnsigned();
4290 case arith::AtomicRMWKind::ori:
4291 case arith::AtomicRMWKind::andi:
4292 case arith::AtomicRMWKind::xori:
4293 return isa<IntegerType>(resultType);
4295 llvm_unreachable(
"Unhandled atomic rmw kind");
4298LogicalResult AffineParallelOp::verify() {
4299 auto numDims = getNumDims();
4302 getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
4303 return emitOpError() <<
"the number of region arguments ("
4304 << getBody()->getNumArguments()
4305 <<
") and the number of map groups for lower ("
4306 << getLowerBoundsGroups().getNumElements()
4307 <<
") and upper bound ("
4308 << getUpperBoundsGroups().getNumElements()
4309 <<
"), and the number of steps (" << getSteps().size()
4310 <<
") must all match";
4313 unsigned expectedNumLBResults = 0;
4314 for (APInt v : getLowerBoundsGroups()) {
4315 unsigned results = v.getZExtValue();
4318 <<
"expected lower bound map to have at least one result";
4319 expectedNumLBResults += results;
4321 if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
4322 return emitOpError() <<
"expected lower bounds map to have "
4323 << expectedNumLBResults <<
" results";
4324 unsigned expectedNumUBResults = 0;
4325 for (APInt v : getUpperBoundsGroups()) {
4326 unsigned results = v.getZExtValue();
4329 <<
"expected upper bound map to have at least one result";
4330 expectedNumUBResults += results;
4332 if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
4333 return emitOpError() <<
"expected upper bounds map to have "
4334 << expectedNumUBResults <<
" results";
4336 if (getReductions().size() != getNumResults())
4337 return emitOpError(
"a reduction must be specified for each output");
4341 for (
auto it : llvm::enumerate((getReductions()))) {
4343 auto intAttr = dyn_cast<IntegerAttr>(attr);
4344 if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
4345 return emitOpError(
"invalid reduction attribute");
4346 auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
4348 return emitOpError(
"result type cannot match reduction attribute");
4354 getLowerBoundsMap().getNumDims())))
4358 getUpperBoundsMap().getNumDims())))
4367 if (newMap ==
getAffineMap() && newOperands == operands)
4369 reset(newMap, newOperands);
4379 bool ubCanonicalized = succeeded(
ub.canonicalize());
4382 if (!lbCanonicalized && !ubCanonicalized)
4385 if (lbCanonicalized)
4387 if (ubCanonicalized)
4388 op.setUpperBounds(
ub.getOperands(),
ub.getAffineMap());
4393LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
4394 SmallVectorImpl<OpFoldResult> &results) {
4405 StringRef keyword) {
4408 ValueRange dimOperands = operands.take_front(numDims);
4409 ValueRange symOperands = operands.drop_front(numDims);
4411 for (llvm::APInt groupSize : group) {
4415 unsigned size = groupSize.getZExtValue();
4420 p << keyword <<
'(';
4429void AffineParallelOp::print(OpAsmPrinter &p) {
4430 p <<
" (" << getBody()->getArguments() <<
") = (";
4432 getLowerBoundsOperands(),
"max");
4435 getUpperBoundsOperands(),
"min");
4437 SmallVector<int64_t, 8> steps = getSteps();
4438 bool elideSteps = llvm::all_of(steps, [](int64_t step) {
return step == 1; });
4441 llvm::interleaveComma(steps, p);
4444 if (getNumResults()) {
4446 llvm::interleaveComma(getReductions(), p, [&](
auto &attr) {
4447 arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4448 llvm::cast<IntegerAttr>(attr).getInt());
4449 p <<
"\"" << arith::stringifyAtomicRMWKind(sym) <<
"\"";
4451 p <<
") -> (" << getResultTypes() <<
")";
4458 (*this)->getAttrs(),
4459 {AffineParallelOp::getReductionsAttrStrName(),
4460 AffineParallelOp::getLowerBoundsMapAttrStrName(),
4461 AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4462 AffineParallelOp::getUpperBoundsMapAttrStrName(),
4463 AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4464 AffineParallelOp::getStepsAttrStrName()});
4471static ParseResult deduplicateAndResolveOperands(
4472 OpAsmParser &parser,
4473 ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> operands,
4474 SmallVectorImpl<Value> &uniqueOperands,
4475 SmallVectorImpl<AffineExpr> &replacements,
AffineExprKind kind) {
4477 "expected operands to be dim or symbol expression");
4480 for (
const auto &list : operands) {
4481 SmallVector<Value> valueOperands;
4484 for (Value operand : valueOperands) {
4485 unsigned pos = std::distance(uniqueOperands.begin(),
4486 llvm::find(uniqueOperands, operand));
4487 if (pos == uniqueOperands.size())
4488 uniqueOperands.push_back(operand);
4489 replacements.push_back(
4499enum class MinMaxKind { Min, Max };
4518static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
4523 const llvm::StringLiteral tmpAttrStrName =
"__pseudo_bound_map";
4525 StringRef mapName = kind == MinMaxKind::Min
4526 ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4527 : AffineParallelOp::getLowerBoundsMapAttrStrName();
4528 StringRef groupsName =
4529 kind == MinMaxKind::Min
4530 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4531 : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4537 result.addAttribute(
4538 mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
4539 result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
4543 SmallVector<AffineExpr> flatExprs;
4544 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatDimOperands;
4545 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatSymOperands;
4546 SmallVector<int32_t> numMapsPerGroup;
4547 SmallVector<OpAsmParser::UnresolvedOperand> mapOperands;
4548 auto parseOperands = [&]() {
4550 kind == MinMaxKind::Min ?
"min" :
"max"))) {
4551 mapOperands.clear();
4557 result.attributes.erase(tmpAttrStrName);
4558 llvm::append_range(flatExprs, map.getValue().getResults());
4559 auto operandsRef = llvm::ArrayRef(mapOperands);
4560 auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4561 SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef);
4562 auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4563 SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef);
4564 flatDimOperands.append(map.getValue().getNumResults(), dims);
4565 flatSymOperands.append(map.getValue().getNumResults(), syms);
4566 numMapsPerGroup.push_back(map.getValue().getNumResults());
4569 flatSymOperands.emplace_back(),
4570 flatExprs.emplace_back())))
4572 numMapsPerGroup.push_back(1);
4579 unsigned totalNumDims = 0;
4580 unsigned totalNumSyms = 0;
4581 for (
unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4582 unsigned numDims = flatDimOperands[i].size();
4583 unsigned numSyms = flatSymOperands[i].size();
4584 flatExprs[i] = flatExprs[i]
4585 .shiftDims(numDims, totalNumDims)
4586 .shiftSymbols(numSyms, totalNumSyms);
4587 totalNumDims += numDims;
4588 totalNumSyms += numSyms;
4592 SmallVector<Value> dimOperands, symOperands;
4593 SmallVector<AffineExpr> dimRplacements, symRepacements;
4594 if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
4596 deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
4600 result.operands.append(dimOperands.begin(), dimOperands.end());
4601 result.operands.append(symOperands.begin(), symOperands.end());
4604 auto flatMap =
AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4606 flatMap = flatMap.replaceDimsAndSymbols(
4607 dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4609 result.addAttribute(mapName, AffineMapAttr::get(flatMap));
4619ParseResult AffineParallelOp::parse(OpAsmParser &parser,
4620 OperationState &
result) {
4623 SmallVector<OpAsmParser::Argument, 4> ivs;
4626 parseAffineMapWithMinMax(parser,
result, MinMaxKind::Max) ||
4628 parseAffineMapWithMinMax(parser,
result, MinMaxKind::Min))
4631 AffineMapAttr stepsMapAttr;
4632 NamedAttrList stepsAttrs;
4633 SmallVector<OpAsmParser::UnresolvedOperand, 4> stepsMapOperands;
4635 SmallVector<int64_t, 4> steps(ivs.size(), 1);
4636 result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4640 AffineParallelOp::getStepsAttrStrName(),
4646 SmallVector<int64_t, 4> steps;
4647 auto stepsMap = stepsMapAttr.getValue();
4648 for (
const auto &
result : stepsMap.getResults()) {
4649 auto constExpr = dyn_cast<AffineConstantExpr>(
result);
4652 "steps must be constant integers");
4653 steps.push_back(constExpr.getValue());
4655 result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4661 SmallVector<Attribute, 4> reductions;
4665 auto parseAttributes = [&]() -> ParseResult {
4670 NamedAttrList attrStorage;
4675 std::optional<arith::AtomicRMWKind> reduction =
4676 arith::symbolizeAtomicRMWKind(attrVal.getValue());
4678 return parser.
emitError(loc,
"invalid reduction value: ") << attrVal;
4679 reductions.push_back(
4687 result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4695 Region *body =
result.addRegion();
4696 for (
auto &iv : ivs)
4697 iv.type = indexType;
4703 AffineParallelOp::ensureTerminator(*body, builder,
result.location);
4711LogicalResult AffineYieldOp::verify() {
4712 auto *parentOp = (*this)->getParentOp();
4713 auto results = parentOp->getResults();
4714 auto operands = getOperands();
4716 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4717 return emitOpError() <<
"only terminates affine.if/for/parallel regions";
4718 if (parentOp->getNumResults() != getNumOperands())
4719 return emitOpError() <<
"parent of yield must have same number of "
4720 "results as the yield operands";
4721 for (
auto it : llvm::zip(results, operands)) {
4723 return emitOpError() <<
"types mismatch between yield op and its parent";
4733void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &
result,
4734 VectorType resultType, AffineMap map,
4736 assert(operands.size() == 1 + map.
getNumInputs() &&
"inconsistent operands");
4737 result.addOperands(operands);
4739 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4740 result.types.push_back(resultType);
4743void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &
result,
4744 VectorType resultType, Value memref,
4746 assert(map.
getNumInputs() == mapOperands.size() &&
"inconsistent index info");
4747 result.addOperands(memref);
4748 result.addOperands(mapOperands);
4749 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4750 result.types.push_back(resultType);
4753void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &
result,
4754 VectorType resultType, Value memref,
4756 auto memrefType = llvm::cast<MemRefType>(memref.
getType());
4757 int64_t rank = memrefType.getRank();
4765void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4766 MLIRContext *context) {
4767 results.
add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4770ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
4771 OperationState &
result) {
4775 MemRefType memrefType;
4776 VectorType resultType;
4777 OpAsmParser::UnresolvedOperand memrefInfo;
4778 AffineMapAttr mapAttr;
4779 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
4783 AffineVectorLoadOp::getMapAttrStrName(),
4793void AffineVectorLoadOp::print(OpAsmPrinter &p) {
4795 if (AffineMapAttr mapAttr =
4796 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4800 {getMapAttrStrName()});
4805static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4806 VectorType vectorType) {
4808 if (memrefType.getElementType() != vectorType.getElementType())
4810 "requires memref and vector types of the same elemental type");
4814LogicalResult AffineVectorLoadOp::verify() {
4817 *
this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4818 getMapOperands(), memrefType,
4819 getNumOperands() - 1)))
4832void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &
result,
4833 Value valueToStore, Value memref, AffineMap map,
4835 assert(map.
getNumInputs() == mapOperands.size() &&
"inconsistent index info");
4836 result.addOperands(valueToStore);
4837 result.addOperands(memref);
4838 result.addOperands(mapOperands);
4839 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4843void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &
result,
4844 Value valueToStore, Value memref,
4846 auto memrefType = llvm::cast<MemRefType>(memref.
getType());
4847 int64_t rank = memrefType.getRank();
4854void AffineVectorStoreOp::getCanonicalizationPatterns(
4855 RewritePatternSet &results, MLIRContext *context) {
4856 results.
add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4859ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
4860 OperationState &
result) {
4863 MemRefType memrefType;
4864 VectorType resultType;
4865 OpAsmParser::UnresolvedOperand storeValueInfo;
4866 OpAsmParser::UnresolvedOperand memrefInfo;
4867 AffineMapAttr mapAttr;
4868 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
4873 AffineVectorStoreOp::getMapAttrStrName(),
4883void AffineVectorStoreOp::print(OpAsmPrinter &p) {
4884 p <<
" " << getValueToStore();
4886 if (AffineMapAttr mapAttr =
4887 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4891 {getMapAttrStrName()});
4892 p <<
" : " <<
getMemRefType() <<
", " << getValueToStore().getType();
4895LogicalResult AffineVectorStoreOp::verify() {
4898 *
this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4899 getMapOperands(), memrefType,
4900 getNumOperands() - 2)))
4913void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4914 OperationState &odsState,
4916 ArrayRef<int64_t> staticBasis,
4917 bool hasOuterBound) {
4918 SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
4919 : staticBasis.size() + 1,
4921 build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
4925void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4926 OperationState &odsState,
4928 bool hasOuterBound) {
4929 if (hasOuterBound && !basis.empty() && basis.front() ==
nullptr) {
4930 hasOuterBound =
false;
4931 basis = basis.drop_front();
4933 SmallVector<Value> dynamicBasis;
4934 SmallVector<int64_t> staticBasis;
4937 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4941void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4942 OperationState &odsState,
4944 ArrayRef<OpFoldResult> basis,
4945 bool hasOuterBound) {
4946 if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
4947 hasOuterBound =
false;
4948 basis = basis.drop_front();
4950 SmallVector<Value> dynamicBasis;
4951 SmallVector<int64_t> staticBasis;
4953 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4957void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4958 OperationState &odsState,
4959 Value linearIndex, ArrayRef<int64_t> basis,
4960 bool hasOuterBound) {
4961 build(odsBuilder, odsState, linearIndex,
ValueRange{}, basis, hasOuterBound);
4964LogicalResult AffineDelinearizeIndexOp::verify() {
4965 ArrayRef<int64_t> staticBasis = getStaticBasis();
4966 if (getNumResults() != staticBasis.size() &&
4967 getNumResults() != staticBasis.size() + 1)
4968 return emitOpError(
"should return an index for each basis element and up "
4969 "to one extra index");
4971 auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
4972 if (
static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4974 "mismatch between dynamic and static basis (kDynamic marker but no "
4975 "corresponding dynamic basis entry) -- this can only happen due to an "
4976 "incorrect fold/rewrite");
4978 if (!llvm::all_of(staticBasis, [](int64_t v) {
4979 return v > 0 || ShapedType::isDynamic(v);
4981 return emitOpError(
"no basis element may be statically non-positive");
4990static std::optional<SmallVector<int64_t>>
4994 uint64_t dynamicBasisIndex = 0;
5000 if (basis && isa<IntegerAttr>(basis)) {
5001 mutableDynamicBasis.
erase(dynamicBasisIndex);
5003 ++dynamicBasisIndex;
5008 if (dynamicBasisIndex == dynamicBasis.size())
5009 return std::nullopt;
5015 staticBasis.push_back(ShapedType::kDynamic);
5017 staticBasis.push_back(*basisVal);
5024AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
5025 SmallVectorImpl<OpFoldResult> &
result) {
5026 std::optional<SmallVector<int64_t>> maybeStaticBasis =
5028 adaptor.getDynamicBasis());
5029 if (maybeStaticBasis) {
5030 setStaticBasis(*maybeStaticBasis);
5035 if (getNumResults() == 1) {
5036 result.push_back(getLinearIndex());
5040 if (adaptor.getLinearIndex() ==
nullptr)
5043 if (!adaptor.getDynamicBasis().empty())
5046 int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
5047 Type attrType = getLinearIndex().getType();
5049 ArrayRef<int64_t> staticBasis = getStaticBasis();
5050 if (hasOuterBound())
5051 staticBasis = staticBasis.drop_front();
5052 for (int64_t modulus : llvm::reverse(staticBasis)) {
5053 result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
5054 highPart = llvm::divideFloorSigned(highPart, modulus);
5056 result.push_back(IntegerAttr::get(attrType, highPart));
5061SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
5063 if (hasOuterBound()) {
5064 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5066 getDynamicBasis().drop_front(), builder);
5068 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
5072 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
5075SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
5076 SmallVector<OpFoldResult> ret = getMixedBasis();
5077 if (!hasOuterBound())
5078 ret.insert(ret.begin(), OpFoldResult());
5085struct DropUnitExtentBasis
5086 :
public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
5089 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
5090 PatternRewriter &rewriter)
const override {
5091 SmallVector<Value> replacements(delinearizeOp->getNumResults(),
nullptr);
5092 std::optional<Value> zero = std::nullopt;
5093 Location loc = delinearizeOp->getLoc();
5094 Type indexType = delinearizeOp.getLinearIndex().getType();
5095 auto getZero = [&]() -> Value {
5097 zero = arith::ConstantOp::create(rewriter, loc,
5099 return zero.value();
5104 SmallVector<OpFoldResult> newBasis;
5105 for (
auto [index, basis] :
5106 llvm::enumerate(delinearizeOp.getPaddedBasis())) {
5107 std::optional<int64_t> basisVal =
5110 replacements[index] =
getZero();
5112 newBasis.push_back(basis);
5115 if (newBasis.size() == delinearizeOp.getNumResults())
5117 "no unit basis elements");
5119 if (!newBasis.empty()) {
5121 auto newDelinearizeOp = affine::AffineDelinearizeIndexOp::create(
5122 rewriter, loc, delinearizeOp.getLinearIndex(), newBasis);
5128 replacement = newDelinearizeOp->getResult(newIndex++);
5132 rewriter.
replaceOp(delinearizeOp, replacements);
5147struct CancelDelinearizeOfLinearizeDisjointExactTail
5148 :
public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
5151 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
5152 PatternRewriter &rewriter)
const override {
5153 auto linearizeOp = delinearizeOp.getLinearIndex()
5154 .getDefiningOp<affine::AffineLinearizeIndexOp>();
5157 "index doesn't come from linearize");
5159 if (!linearizeOp.getDisjoint())
5162 ValueRange linearizeIns = linearizeOp.getMultiIndex();
5164 SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
5165 SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
5166 size_t numMatches = 0;
5167 for (
auto [linSize, delinSize] : llvm::zip(
5168 llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
5169 if (linSize != delinSize)
5174 if (numMatches == 0)
5176 delinearizeOp,
"final basis element doesn't match linearize");
5179 if (numMatches == linearizeBasis.size() &&
5180 numMatches == delinearizeBasis.size() &&
5181 linearizeIns.size() == delinearizeOp.getNumResults()) {
5182 rewriter.
replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
5186 Value newLinearize = affine::AffineLinearizeIndexOp::create(
5187 rewriter, linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
5188 ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
5189 linearizeOp.getDisjoint());
5190 auto newDelinearize = affine::AffineDelinearizeIndexOp::create(
5191 rewriter, delinearizeOp.getLoc(), newLinearize,
5192 ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
5193 delinearizeOp.hasOuterBound());
5194 SmallVector<Value> mergedResults(newDelinearize.getResults());
5195 mergedResults.append(linearizeIns.take_back(numMatches).begin(),
5196 linearizeIns.take_back(numMatches).end());
5197 rewriter.
replaceOp(delinearizeOp, mergedResults);
5215struct SplitDelinearizeSpanningLastLinearizeArg final
5216 : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
5219 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
5220 PatternRewriter &rewriter)
const override {
5221 auto linearizeOp = delinearizeOp.getLinearIndex()
5222 .getDefiningOp<affine::AffineLinearizeIndexOp>();
5225 "index doesn't come from linearize");
5227 if (!linearizeOp.getDisjoint())
5229 "linearize isn't disjoint");
5234 if (linearizeOp.getStaticBasis().empty())
5236 linearizeOp,
"linearize has no basis elements (no inputs)");
5238 int64_t
target = linearizeOp.getStaticBasis().back();
5239 if (ShapedType::isDynamic(
target))
5241 linearizeOp,
"linearize ends with dynamic basis value");
5243 int64_t sizeToSplit = 1;
5244 size_t elemsToSplit = 0;
5245 ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
5246 for (int64_t basisElem : llvm::reverse(basis)) {
5247 if (ShapedType::isDynamic(basisElem))
5249 delinearizeOp,
"dynamic basis element while scanning for split");
5250 sizeToSplit *= basisElem;
5253 if (sizeToSplit >
target)
5255 "overshot last argument size");
5256 if (sizeToSplit ==
target)
5260 if (sizeToSplit <
target)
5262 delinearizeOp,
"product of known basis elements doesn't exceed last "
5263 "linearize argument");
5265 if (elemsToSplit < 2)
5268 "need at least two elements to form the basis product");
5270 Value linearizeWithoutBack = affine::AffineLinearizeIndexOp::create(
5271 rewriter, linearizeOp.getLoc(), linearizeOp.getLinearIndex().getType(),
5272 linearizeOp.getMultiIndex().drop_back(), linearizeOp.getDynamicBasis(),
5273 linearizeOp.getStaticBasis().drop_back(), linearizeOp.getDisjoint());
5274 auto delinearizeWithoutSplitPart = affine::AffineDelinearizeIndexOp::create(
5275 rewriter, delinearizeOp.getLoc(), linearizeWithoutBack,
5276 delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
5277 delinearizeOp.hasOuterBound());
5278 auto delinearizeBack = affine::AffineDelinearizeIndexOp::create(
5279 rewriter, delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
5280 basis.take_back(elemsToSplit),
true);
5281 SmallVector<Value> results = llvm::to_vector(
5282 llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
5283 delinearizeBack.getResults()));
5284 rewriter.
replaceOp(delinearizeOp, results);
5291void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
5292 RewritePatternSet &patterns, MLIRContext *context) {
5294 .
insert<CancelDelinearizeOfLinearizeDisjointExactTail,
5295 DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
5306 if (multiIndex.empty())
5307 return IndexType::get(ctx);
5308 return multiIndex.front().
getType();
5311void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5312 OperationState &odsState,
5315 if (!basis.empty() && basis.front() == Value())
5316 basis = basis.drop_front();
5317 SmallVector<Value> dynamicBasis;
5318 SmallVector<int64_t> staticBasis;
5322 build(odsBuilder, odsState, resultType, multiIndex, dynamicBasis, staticBasis,
5326void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5327 OperationState &odsState,
5329 ArrayRef<OpFoldResult> basis,
5331 if (!basis.empty() && basis.front() == OpFoldResult())
5332 basis = basis.drop_front();
5333 SmallVector<Value> dynamicBasis;
5334 SmallVector<int64_t> staticBasis;
5337 build(odsBuilder, odsState, resultType, multiIndex, dynamicBasis, staticBasis,
5341void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5342 OperationState &odsState,
5344 ArrayRef<int64_t> basis,
bool disjoint) {
5346 build(odsBuilder, odsState, resultType, multiIndex,
ValueRange{}, basis,
5350LogicalResult AffineLinearizeIndexOp::verify() {
5351 size_t numIndexes = getMultiIndex().size();
5352 size_t numBasisElems = getStaticBasis().size();
5353 if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
5354 return emitOpError(
"should be passed a basis element for each index except "
5355 "possibly the first");
5357 auto dynamicMarkersCount =
5358 llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
5359 if (
static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
5361 "mismatch between dynamic and static basis (kDynamic marker but no "
5362 "corresponding dynamic basis entry) -- this can only happen due to an "
5363 "incorrect fold/rewrite");
5368OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
5369 std::optional<SmallVector<int64_t>> maybeStaticBasis =
5371 adaptor.getDynamicBasis());
5372 if (maybeStaticBasis) {
5373 setStaticBasis(*maybeStaticBasis);
5377 if (getMultiIndex().empty())
5378 return IntegerAttr::get(getResult().
getType(), 0);
5381 if (getMultiIndex().size() == 1)
5382 return getMultiIndex().front();
5387 if (llvm::any_of(adaptor.getMultiIndex(), [](Attribute a) {
5388 return !isa_and_nonnull<IntegerAttr>(a);
5392 if (!adaptor.getDynamicBasis().empty())
5397 for (
auto [length, indexAttr] :
5398 llvm::zip_first(llvm::reverse(getStaticBasis()),
5399 llvm::reverse(adaptor.getMultiIndex()))) {
5400 result =
result + cast<IntegerAttr>(indexAttr).getInt() * stride;
5401 stride = stride * length;
5404 if (!hasOuterBound())
5407 cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;
5412SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
5414 if (hasOuterBound()) {
5415 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5417 getDynamicBasis().drop_front(), builder);
5419 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
5423 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
5426SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
5427 SmallVector<OpFoldResult> ret = getMixedBasis();
5428 if (!hasOuterBound())
5429 ret.insert(ret.begin(), OpFoldResult());
5444struct DropLinearizeUnitComponentsIfDisjointOrZero final
5445 : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5448 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5449 PatternRewriter &rewriter)
const override {
5451 size_t numIndices = multiIndex.size();
5452 SmallVector<Value> newIndices;
5453 newIndices.reserve(numIndices);
5454 SmallVector<OpFoldResult> newBasis;
5455 newBasis.reserve(numIndices);
5457 if (!op.hasOuterBound()) {
5458 newIndices.push_back(multiIndex.front());
5459 multiIndex = multiIndex.drop_front();
5462 SmallVector<OpFoldResult> basis = op.getMixedBasis();
5463 for (
auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
5465 if (!basisEntry || *basisEntry != 1) {
5466 newIndices.push_back(index);
5467 newBasis.push_back(basisElem);
5472 if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
5473 newIndices.push_back(index);
5474 newBasis.push_back(basisElem);
5478 if (newIndices.size() == numIndices)
5480 "no unit basis entries to replace");
5482 if (newIndices.empty()) {
5484 op, rewriter.
getZeroAttr(op.getLinearIndex().getType()));
5488 op, newIndices, newBasis, op.getDisjoint());
5494 ArrayRef<OpFoldResult> terms) {
5495 int64_t nDynamic = 0;
5496 SmallVector<Value> dynamicPart;
5498 for (OpFoldResult term : terms) {
5505 dynamicPart.push_back(cast<Value>(term));
5509 if (
auto constant = dyn_cast<AffineConstantExpr>(
result))
5511 return AffineApplyOp::create(builder, loc,
result, dynamicPart).getResult();
5541struct CancelLinearizeOfDelinearizePortion final
5542 : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5552 unsigned linStart = 0;
5553 unsigned delinStart = 0;
5554 unsigned length = 0;
5558 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
5559 PatternRewriter &rewriter)
const override {
5560 SmallVector<Match> matches;
5562 const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
5563 ArrayRef<OpFoldResult> linBasisRef = linBasis;
5565 ValueRange multiIndex = linearizeOp.getMultiIndex();
5566 unsigned numLinArgs = multiIndex.size();
5567 unsigned linArgIdx = 0;
5570 llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize;
5571 while (linArgIdx < numLinArgs) {
5572 auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
5578 auto delinearizeOp =
5579 dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
5580 if (!delinearizeOp) {
5597 unsigned delinArgIdx = asResult.getResultNumber();
5598 SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
5599 OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
5600 OpFoldResult firstLinBound = linBasis[linArgIdx];
5601 bool boundsMatch = firstDelinBound == firstLinBound;
5602 bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
5603 bool knownByDisjoint =
5604 linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
5605 if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
5611 unsigned numDelinOuts = delinearizeOp.getNumResults();
5612 for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
5614 if (multiIndex[linArgIdx + j] !=
5615 delinearizeOp.getResult(delinArgIdx + j))
5617 if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
5623 if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
5627 matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
5631 if (matches.empty())
5633 linearizeOp,
"no run of delinearize outputs to deal with");
5638 SmallVector<SmallVector<Value>> delinearizeReplacements;
5640 SmallVector<Value> newIndex;
5641 newIndex.reserve(numLinArgs);
5642 SmallVector<OpFoldResult> newBasis;
5643 newBasis.reserve(numLinArgs);
5644 unsigned prevMatchEnd = 0;
5645 for (Match m : matches) {
5646 unsigned gap = m.linStart - prevMatchEnd;
5647 llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
5648 llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
5650 prevMatchEnd = m.linStart + m.length;
5652 PatternRewriter::InsertionGuard g(rewriter);
5655 ArrayRef<OpFoldResult> basisToMerge =
5656 linBasisRef.slice(m.linStart, m.length);
5659 OpFoldResult newSize =
5664 newIndex.push_back(m.delinearize.getLinearIndex());
5665 newBasis.push_back(newSize);
5667 delinearizeReplacements.push_back(SmallVector<Value>());
5671 SmallVector<Value> newDelinResults;
5672 SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
5673 newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
5674 newDelinBasis.begin() + m.delinStart + m.length);
5675 newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
5676 auto newDelinearize = AffineDelinearizeIndexOp::create(
5677 rewriter, m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
5683 Value combinedElem = newDelinearize.getResult(m.delinStart);
5684 auto residualDelinearize = AffineDelinearizeIndexOp::create(
5685 rewriter, m.delinearize.getLoc(), combinedElem, basisToMerge);
5690 llvm::append_range(newDelinResults,
5691 newDelinearize.getResults().take_front(m.delinStart));
5692 llvm::append_range(newDelinResults, residualDelinearize.getResults());
5695 newDelinearize.getResults().drop_front(m.delinStart + 1));
5697 delinearizeReplacements.push_back(newDelinResults);
5698 newIndex.push_back(combinedElem);
5699 newBasis.push_back(newSize);
5701 llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
5702 llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
5704 linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
5706 for (
auto [m, newResults] :
5707 llvm::zip_equal(matches, delinearizeReplacements)) {
5708 if (newResults.empty())
5710 rewriter.
replaceOp(m.delinearize, newResults);
5721struct DropLinearizeLeadingZero final
5722 : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5725 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5726 PatternRewriter &rewriter)
const override {
5727 Value leadingIdx = op.getMultiIndex().front();
5731 if (op.getMultiIndex().size() == 1) {
5736 SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
5737 ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
5738 if (op.hasOuterBound())
5739 newMixedBasis = newMixedBasis.drop_front();
5742 op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
5748void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
5749 RewritePatternSet &patterns, MLIRContext *context) {
5750 patterns.
add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
5751 DropLinearizeUnitComponentsIfDisjointOrZero>(context);
5758#define GET_OP_CLASSES
5759#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds known to be constants.
static bool hasTrivialZeroTripCount(AffineForOp op)
Returns true if the affine.for has zero iterations in trivial cases.
static Type inferIndexType(MLIRContext *ctx, ValueRange multiIndex)
Infer the index type from a set of multi-index values. Returns the common type (index or vector<....
static LogicalResult verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)
Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine....
static void printAffineMinMaxOp(OpAsmPrinter &p, T op)
static bool isResultTypeMatchAtomicRMWKind(Type resultType, arith::AtomicRMWKind op)
static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)
Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)
Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.
static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)
Fold an affine min or max operation with the given operands.
static bool isTopLevelValueOrAbove(Value value, Region *region)
A utility function to check if a value is defined at the top level of region or is an argument of reg...
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)
Canonicalize the bounds of the given loop.
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)
Simplify expr while exploiting information from the values in operands.
static bool isValidAffineIndexOperand(Value value, Region *region)
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)
Parse a for operation loop bounds.
static std::optional< SmallVector< int64_t > > foldCstValueToCstAttrBasis(ArrayRef< OpFoldResult > mixedBasis, MutableOperandRange mutableDynamicBasis, ArrayRef< Attribute > dynamicBasis)
Given mixed basis of affine.delinearize_index/linearize_index replace constant SSA values with the co...
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)
Simplify the expressions in map while making use of lower or upper bounds of its operands.
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)
static LogicalResult replaceAffineDelinearizeIndexInverseExpression(AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms)
If this map contains the expression x_1 + x_1 * C_1 + ... x_n * C_N + ... (not necessarily in or...
static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands, bool composeAffineMin=false)
Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)
Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)
Check if e is known to be: 0 <= e < k.
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds that may or may not be constants.
static void simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)
Simplify the map while exploiting information on the values in operands.
static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)
Prints dimension and symbol list.
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)
Returns the largest known divisor of e.
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.
static void legalizeDemotedDims(MapOrSet &mapOrSet, SmallVectorImpl< Value > &operands)
A valid affine dimension may appear as a symbol in affine.apply operations.
static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)
Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.
static LogicalResult foldLoopBounds(AffineForOp forOp)
Fold the constant bounds of a loop.
static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp, AffineExpr dimOrSym, AffineMap *map, ValueRange dims, ValueRange syms)
Assuming dimOrSym is a quantity in the apply op map map and defined by minOp = affine_min(x_1,...
static SmallVector< OpFoldResult > AffineForEmptyLoopFolder(AffineForOp forOp)
Fold the empty loop.
static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)
Utility function to verify that a set of operands are valid dimension and symbol identifiers.
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)
Returns true if the result of the dim op is a valid symbol for region.
static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr &quotientTimesDiv, AffineExpr &rem)
Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.
static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms, bool replaceAffineMin)
Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression...
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static std::optional< uint64_t > getTrivialConstantTripCount(AffineForOp forOp)
Returns constant trip count in trivial cases.
static LogicalResult verifyAffineMinMaxOp(T op)
static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)
static void shortenAddChainsContainingAll(AffineExpr e, const llvm::SmallDenseSet< AffineExpr, 4 > &exprsToRemove, AffineExpr newVal, DenseMap< AffineExpr, AffineExpr > &replacementsMap)
Recursively traverse e.
static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands, bool composeAffineMin=false)
Composes the given affine map with the given list of operands, pulling in the maps from any affine....
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)
Canonicalize the result expression order of an affine map and return success if the order changed.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value getMemRef(Operation *memOp)
Returns the memref being read/written by a memref/affine load/store op.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
AffineExpr shiftDims(unsigned numDims, unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift, unsigned offset=0) const
Replace symbols[offset ... numSymbols) by symbols[offset + shift ... shift + numSymbols).
AffineExpr floorDiv(uint64_t v) const
AffineExprKind getKind() const
Return the classification for this type.
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
MLIRContext * getContext() const
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
MLIRContext * getContext() const
bool isFunctionOfDim(unsigned position) const
Return true if any affine expression involves AffineDimExpr position.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isFunctionOfSymbol(unsigned position) const
Return true if any affine expression involves AffineSymbolExpr position.
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
unsigned getNumInputs() const
AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const
Replace symbols[offset ... numSymbols) by symbols[offset + shift ... shift + numSymbols).
AffineExpr getResult(unsigned idx) const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
AffineMap getDimIdentityMap()
AffineMap getMultiDimIdentityMap(unsigned rank)
AffineExpr getAffineSymbolExpr(unsigned position)
AffineExpr getAffineConstantExpr(int64_t constant)
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
AffineMap getEmptyAffineMap()
Returns a zero result affine map with no dimensions or symbols: () -> ().
TypedAttr getZeroAttr(Type type)
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
AffineMap getSymbolIdentityMap()
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense integer vector or tensor object.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
An integer set representing a conjunction of one or more affine equalities and inequalities.
unsigned getNumDims() const
static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)
MLIRContext * getContext() const
unsigned getNumInputs() const
ArrayRef< AffineExpr > getConstraints() const
ArrayRef< bool > getEqFlags() const
Returns the equality bits, which specify whether each of the constraints is an equality or inequality...
unsigned getNumSymbols() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0
Parses an affine map attribute where dims and symbols are SSA operands.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0
Parses an affine expression where dims and symbols are SSA operands.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0
Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents a single result from folding an operation.
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This class provides the API for ops that are known to be isolated from above.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
operand_range::iterator operand_iterator
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
Region * getSuccessor() const
Return the given region successor.
bool isOperation() const
Return true if the successor is an operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static DefaultResource * get()
std::vector< SmallVector< int64_t, 8 > > operandExprStack
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
A variable that can be added to the constraint set as a "column".
static bool compare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
Return "true" if "lhs cmp rhs" was proven to hold.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
AffineBound represents a lower or upper bound in the for operation.
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
LogicalResult canonicalize()
Attempts to canonicalize the map and operands.
ArrayRef< Value > getOperands() const
AffineExpr getResult(unsigned i)
AffineMap getAffineMap() const
void reset(AffineMap map, ValueRange operands, ValueRange results={})
unsigned getNumResults() const
static void difference(const AffineValueMap &a, const AffineValueMap &b, AffineValueMap *res)
Return the value map that is the difference of value maps 'a' and 'b', represented as an affine map a...
Operation * getOwner() const
Return the owner of this operand.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)
Extracts the induction variables from a list of AffineForOps and places them in the output argument i...
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
OpFoldResult computeProduct(Location loc, OpBuilder &builder, ArrayRef< OpFoldResult > terms)
Return the product of terms, creating an affine.apply if any of them are non-constant values.
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
bool isTopLevelValue(Value value)
A utility function to check if a value is defined at the top level of an op with trait AffineScope or...
Region * getAffineAnalysisScope(Operation *op)
Returns the closest region enclosing op that is held by a non-affine operation; nullptr if there is n...
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)
Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.
void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)
Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns the AffineParallelOp that has the provided value among its induction variables.
Region * getAffineScope(Operation *op)
Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list.
bool isAffineParallelInductionVar(Value val)
Returns true if val is the induction variable of an AffineParallelOp.
AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t > > constLowerBounds, ArrayRef< std::optional< int64_t > > constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
llvm::function_ref< Fn > function_ref
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Canonicalize the affine map result expression order of an affine min/max operation.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Remove duplicated expressions in affine min/max ops.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Merge an affine min/max op to its consumers if its consumer is also an affine min/max op.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.