#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
namespace mlir {
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-data-layout-propagation"
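namespace {

// Returns true when the body of the generic op reads tensors via
// tensor.extract or uses linalg.index, i.e. the op has gather semantics and
// the packing decision cannot be derived from the indexing maps alone.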
static bool hasGatherSemantics(linalg::GenericOp genericOp) {
  for (Operation &op : genericOp.getBody()->getOperations())
    if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
      return true;
  return false;
}
/// Packing information computed for one operand of a generic op.
struct PackInfo {
  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); }
  // Tiled dimension positions on the iteration domain, in pack-op order.
  SmallVector<int64_t> tiledDimsPos;
  // Tile size for each tiled iteration-domain dimension.
  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
  // Mapping from a tiled iteration-domain dimension to its inner (point) loop.
  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
  // Permutation of the outer dims, expressed on the iteration domain.
  SmallVector<int64_t> outerDimsOnDomainPerm;
};
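/// Walks the operand's indexing map and the pack/unpack op's tiling metadata
/// to build a PackInfo describing which iteration-domain dimensions are
/// tiled, with which tile sizes, and how the outer dimensions are permuted.
/// Fails when a tiled dimension is not indexed by a plain AffineDimExpr.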
template <typename OpTy>
static FailureOr<PackInfo>
getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                          OpTy packOrUnPackOp) {
  static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
                "applies to only pack or unpack operations");
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });

  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();
  PackInfo packInfo;
  int64_t origNumDims = indexingMap.getNumDims();
  SmallVector<AffineExpr> exprs(indexingMap.getResults());
  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
  for (auto [index, innerDimPos, tileSize] :
       llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
    auto expr = exprs[innerDimPos];
    if (!isa<AffineDimExpr>(expr))
      return failure();
    int64_t domainDimPos =
        cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
    if (!isParallelIterator(iterators[domainDimPos]))
      return failure();
    packInfo.tiledDimsPos.push_back(domainDimPos);
    packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
    packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
    LLVM_DEBUG({
      llvm::dbgs() << "map innerDimPos=" << innerDimPos
                   << " to iteration dimension (d" << domainDimPos << ", d"
                   << packInfo.tileToPointMapping[domainDimPos]
                   << "), which has size=("
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
    });
  }
  // Bail out if a tiled dimension is consumed by any indexing map as part of
  // a composite expression; folding would then require inserting casts.
  auto areAllAffineDimExpr = [&](int dim) {
    for (AffineMap map : genericOp.getIndexingMapsArray()) {
      if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
            return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
          })) {
        return false;
      }
    }
    return true;
  };
  for (int64_t i : packInfo.tiledDimsPos)
    if (!areAllAffineDimExpr(i))
      return failure();
  // Handle outer_dims_perm, which directly expresses a permutation of the
  // iteration domain.
  SmallVector<int64_t> permutedOuterDims;
  for (auto [index, dim] :
       llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
    if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
      permutedOuterDims.push_back(dimExpr.getPosition());
      continue;
    }
    // Only plain dim expressions can be permuted on the domain.
    return failure();
  }
  if (!permutedOuterDims.empty()) {
    int64_t outerDimIndex = 0;
    llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
                                               permutedOuterDims.end());
    for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
      packInfo.outerDimsOnDomainPerm.push_back(
          permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
                                         : i);
    LLVM_DEBUG({
      llvm::dbgs() << "map outer dimsDimsPerm to ";
      for (auto dim : packInfo.outerDimsOnDomainPerm)
        llvm::dbgs() << dim << " ";
      llvm::dbgs() << "\n";
    });
  }
  return packInfo;
}
/// Derives the operand-level outer_dims_perm from the domain-level
/// permutation `perm`, using `exprs` (the operand's indexing-map results) to
/// translate domain positions into operand positions.
static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
                                             ArrayRef<AffineExpr> exprs) {
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
  if (exprs.size() == 1)
    return {};
  SmallVector<int64_t> outerDimsPerm;
  DenseMap<int64_t, int64_t> currentPositionTileLoops;
  for (auto [pos, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      currentPositionTileLoops[dimExpr.getPosition()] = pos;
    else
      currentPositionTileLoops[pos] = pos;
  }
  for (int64_t loopIdx : perm)
    if (currentPositionTileLoops.count(loopIdx))
      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
  return outerDimsPerm;
}
/// The packing metadata computed for a single operand.
struct PackedOperandDetails {
  SmallVector<OpFoldResult> innerTileSizes;
  SmallVector<int64_t> innerDimsPos;
  SmallVector<int64_t> outerDimsPerm;
  AffineMap indexingMap;
};
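/// Computes the PackedOperandDetails for `opOperand` and records them in
/// `packedOperandMap`. Returns true if packing the operand with the inferred
/// inner tiles would require a padding value.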
static bool getPackedOperandDetails(
    OpBuilder &b, PackInfo packInfo, GenericOp genericOp,
    OpOperand *opOperand,
    DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
  PackedOperandDetails currOperandDetails;
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());

  // A scalar or zero-rank operand never needs packing.
  if (genericOp.isScalar(opOperand) || exprs.empty()) {
    currOperandDetails.indexingMap =
        AffineMap::get(numLoops, 0, exprs, b.getContext());
    packedOperandMap[opOperand] = currOperandDetails;
    return false;
  }
  // Step 1: Fold each dim of the operand shape into the corresponding
  // iteration-domain dimension.
  for (auto [index, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dimPos = dimExpr.getPosition();
      domainDimToOperandDim[dimPos] = index;
    }
  }
  SmallVector<OpFoldResult> innerTileSizes;
  SmallVector<int64_t> innerDimsPos;
  SmallVector<int64_t> outerDimsPerm;
  for (auto dimPos : packInfo.tiledDimsPos) {
    if (!domainDimToOperandDim.count(dimPos))
      continue;
    int64_t index = domainDimToOperandDim[dimPos];
    innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
    innerDimsPos.push_back(index);
    exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
  }
  // Step 2: Handle the outer-dims permutation.
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);

    // Step 2.1: Rewrite the dim exprs according to the inverse permutation.
    auto inversedOuterPerm =
        invertPermutationVector(packInfo.outerDimsOnDomainPerm);
    for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
      if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
        int64_t dimPos = dimExpr.getPosition();
        exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
        continue;
      }
      assert(isa<AffineConstantExpr>(exprs[i]) &&
             "Attempted to permute non-constant and non-affine dim expression");
    }
    // Step 2.2: Undo the transposition on `exprs`; the transposition is
    // instead carried by the pack's outer_dims_perm.
    if (!outerDimsPerm.empty()) {
      SmallVector<AffineExpr> auxVec(exprs.size());
      for (const auto &en : enumerate(outerDimsPerm))
        auxVec[en.index()] = exprs[en.value()];
      exprs = auxVec;
    }
  }
  currOperandDetails.indexingMap =
      AffineMap::get(numLoops, 0, exprs, b.getContext());
  // If the operand is neither tiled nor permuted, no packing is needed.
  if (innerDimsPos.empty() && outerDimsPerm.empty()) {
    packedOperandMap[opOperand] = currOperandDetails;
    return false;
  }
  auto inputType = cast<RankedTensorType>(opOperand->get().getType());
  auto maybeIntInnerTileSizes =
      llvm::map_to_vector(innerTileSizes, [](OpFoldResult ofr) -> int64_t {
        std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
        return maybeCst.value_or(ShapedType::kDynamic);
      });
  bool requirePadding = linalg::PackOp::requirePaddingValueStrict(
      inputType.getShape(), innerDimsPos,
      linalg::PackOp::inferPackedTensorType(inputType, maybeIntInnerTileSizes,
                                            innerDimsPos, outerDimsPerm)
          .getShape(),
      outerDimsPerm, innerTileSizes);
  currOperandDetails.innerDimsPos = innerDimsPos;
  currOperandDetails.innerTileSizes = innerTileSizes;
  currOperandDetails.outerDimsPerm = outerDimsPerm;
  packedOperandMap[opOperand] = currOperandDetails;

  return requirePadding;
}
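/// Returns the packed value and updated indexing map for `opOperand`, reusing
/// the details recorded in `packedOperandMap`. When packing is required, a
/// linalg.pack is materialized with a ub.poison padding value.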
static std::tuple<Value, AffineMap> getOrCreatePackedViewOfOperand(
    OpBuilder &b, Location loc, OpOperand *opOperand,
    const DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
  assert(packedOperandMap.contains(opOperand) &&
         "packed operand details expected to be populated");
  auto currOperandDetails = packedOperandMap.at(opOperand);
  auto innerDimsPos = currOperandDetails.innerDimsPos;
  auto outerDimsPerm = currOperandDetails.outerDimsPerm;
  auto innerTileSizes = currOperandDetails.innerTileSizes;
  if (innerDimsPos.empty() && outerDimsPerm.empty())
    return std::make_tuple(opOperand->get(), currOperandDetails.indexingMap);

  auto empty = linalg::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto poison = ub::PoisonOp::create(
      b, loc, getElementTypeOrSelf(opOperand->get().getType()));
  PackOp packedOperand =
      linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
                             innerTileSizes, poison, outerDimsPerm);
  return std::make_tuple(packedOperand.getResult(),
                         currOperandDetails.indexingMap);
}
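/// Rewrites `genericOp` so that it operates on the packed operands, appending
/// the inner (point) loops to the iteration domain. An illustrative sketch
/// (types and attributes elided):
///
///   %0 = linalg.pack %arg0 ... into %dest0
///   %1 = linalg.pack %arg1 ... into %dest1
///   %2 = linalg.generic ins(%0, %1) outs(%packed_init) ...
///
/// When `isFoldableUnpackPack` is set, pack(unpack(x)) pairs with equivalent
/// tiling metadata are folded away and `x` is used directly.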
static FailureOr<GenericOp>
packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
              AffineMap packedOutIndexingMap, const PackInfo &packInfo,
              bool isFoldableUnpackPack, bool poisonPaddingOk) {
  Location loc = genericOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<Value> inputOperandsFromUnpackedSource;
  SmallVector<AffineMap> indexingMaps;
  auto hasEquivalentTiles = [](PackOp packOp, UnPackOp unPackOp) {
    return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() &&
           packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
           llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
  };
  DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
  bool requiresPadding = false;
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    requiresPadding |= getPackedOperandDetails(rewriter, packInfo, genericOp,
                                               inputOperand, packedOperandMap);
  }
  if (requiresPadding && !poisonPaddingOk)
    return failure();

  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, inputOperand, packedOperandMap);
    auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
    auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
    if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
      inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
    } else {
      inputOperandsFromUnpackedSource.push_back(packedOperand);
    }
    inputOperands.push_back(packedOperand);
    indexingMaps.push_back(packedIndexingMap);
  }
  // If the pack(unpack) pairs can be folded, use the unpack sources directly.
  if (isFoldableUnpackPack) {
    inputOperands = inputOperandsFromUnpackedSource;
    if (auto destPack = dest.getDefiningOp<linalg::PackOp>()) {
      auto destUnPack = destPack.getSource().getDefiningOp<linalg::UnPackOp>();
      if (destUnPack && hasEquivalentTiles(destPack, destUnPack)) {
        dest = destUnPack.getSource();
      }
    }
  }
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();
  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);

  indexingMaps.push_back(packedOutIndexingMap);

  auto newGenericOp = linalg::GenericOp::create(
      rewriter, loc, dest.getType(), inputOperands, dest, indexingMaps,
      iterTypes, /*bodyBuild=*/nullptr, getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());
  return newGenericOp;
}
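/// Returns true when none of the init block arguments is used in the body,
/// i.e. the outs operands only carry the output shape.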
static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
  return llvm::all_of(genericOp.getDpsInitsMutable(), [&](OpOperand &operand) {
    return genericOp.getMatchingBlockArgument(&operand).use_empty();
  });
}
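/// Bubbles up linalg.pack through an elementwise linalg.generic: the operands
/// get packed instead of the result, so the pack can keep propagating towards
/// the producers. An illustrative sketch with hypothetical shapes:
///
///   %0 = linalg.generic ins(%arg0 : tensor<?x?xf32>)
///                       outs(%init : tensor<?x?xf32>)
///   %1 = linalg.pack %0 inner_dims_pos = [0, 1] inner_tiles = [8, 2]
///        into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// becomes
///
///   %packed_arg = linalg.pack %arg0 inner_dims_pos = [0, 1]
///                 inner_tiles = [8, 2] into %empty
///   %1 = linalg.generic ins(%packed_arg : tensor<?x?x8x2xf32>)
///                       outs(%dest : tensor<?x?x8x2xf32>)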
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
                               const ControlPropagationFn &controlFn,
                               bool poisonPaddingOk) {
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();

  // User-controlled propagation function.
  if (!controlFn(&packOp.getSourceMutable()))
    return failure();

  // TODO: Enable propagation in the presence of linalg.index and
  // tensor.extract.
  if (hasGatherSemantics(genericOp))
    return failure();

  // TODO: Enable propagation through multi-result generics.
  if (genericOp.getNumResults() != 1)
    return failure();

  // Bail out if the result has more than one use; the generic would have to
  // be duplicated.
  if (!genericOp->getResult(0).hasOneUse())
    return failure();

  // TODO: Support pack ops that carry a padding value.
  if (packOp.getPaddingValue())
    return failure();

  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
  if (failed(packInfo))
    return failure();
  // We want to move the pack, not the generic.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(genericOp);

  // Forward the pack's destination: rebuild a tensor.empty at the generic, or
  // make sure the existing destination dominates it.
  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();
  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
    packOpDest = tensor::EmptyOp::create(rewriter, genericOp->getLoc(),
                                         emptyOp.getMixedSizes(),
                                         emptyOp.getType().getElementType());
  } else {
    DominanceInfo dom(genericOp);
    if (!dom.properlyDominates(packOpDest, genericOp))
      return failure();
  }
  DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
  bool requiresPadding = getPackedOperandDetails(rewriter, *packInfo, genericOp,
                                                 opOperand, packedOperandMap);
  if (requiresPadding && !poisonPaddingOk)
    return failure();

  // Rebuild the indexing map for the init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), opOperand,
                                     packedOperandMap);

  // If the dps init operand of the generic is a tensor.empty (or unused),
  // forward the pack op's destination.
  Value dest = packedOutOperand;
  auto initTensor =
      genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
  if (initTensor || isGenericOutsNotUsed(genericOp)) {
    dest = packOpDest;
  }
  // pack(unpack) isn't naively foldable here because the unpack op can come
  // from an arbitrary domain, so keep it as is.
  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
                       *packInfo, /*isFoldableUnpackPack=*/false,
                       poisonPaddingOk);
}
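/// Wraps bubbleUpPackOpThroughGenericOp in a rewrite pattern.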
struct BubbleUpPackOpThroughGenericOpPattern
    : public OpRewritePattern<linalg::PackOp> {
public:
  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
                                        ControlPropagationFn fun,
                                        bool poisonPaddingOk)
      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)),
        poisonPaddingOk(poisonPaddingOk) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    if (!packOp.hasPureTensorSemantics())
      return failure();

    auto genericOp = bubbleUpPackOpThroughGenericOp(rewriter, packOp,
                                                    controlFn, poisonPaddingOk);
    if (failed(genericOp))
      return failure();
    rewriter.replaceOp(packOp, genericOp->getResults());
    return success();
  }

private:
  ControlPropagationFn controlFn;
  bool poisonPaddingOk;
};
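/// Propagates linalg.pack up through tensor.pad when only non-tiled
/// dimensions are padded. Illustrative sketch (hypothetical shapes):
///
///   %padded = tensor.pad %src low[0, 0] high[%h, 0]
///   %packed = linalg.pack %padded inner_dims_pos = [1] inner_tiles = [8]
///             into %dest
///
/// becomes
///
///   %packed = linalg.pack %src inner_dims_pos = [1] inner_tiles = [8]
///             into %empty
///   %padded = tensor.pad %packed low[0, 0, 0] high[%h, 0, 0]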
class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
public:
  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    if (!packOp.hasPureTensorSemantics())
      return failure();

    auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!padOp)
      return failure();

    // User-controlled propagation function.
    if (!controlFn(&packOp.getSourceMutable()))
      return failure();

    // TODO: Enable propagation when the pack op carries a padding value.
    if (packOp.getPaddingValue())
      return failure();

    // Fail for non-constant padding values; the pad body could depend on the
    // padding indices.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
      return failure();

    ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();

    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Location loc = padOp->getLoc();
    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();

    auto empty = linalg::PackOp::createDestinationTensor(
        rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
        outerDimsPerm);
    auto sourcePack = linalg::PackOp::create(
        rewriter, loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
        /*padding=*/std::nullopt, outerDimsPerm);

    // If there is an outer_dims_perm, adjust the padded dimensions.
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // The tiled dimensions are not padded, so append zero padding for the
    // point loops.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = tensor::PadOp::create(
        rewriter, loc, /*result=*/Type(), sourcePack.getResult(), lowPad,
        highPad, paddingVal, padOp.getNofold());

    // If the pad has other users, unpack the new pad to replace them.
    if (!padOp->hasOneUse()) {
      auto unpackEmpty = linalg::UnPackOp::createDestinationTensor(
          rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
      UnPackOp unpackedPad =
          linalg::UnPackOp::create(rewriter, loc, newPadOp, unpackEmpty,
                                   innerDimsPos, mixedTiles, outerDimsPerm);
      rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
    }

    // Replace the pack with the new pad.
    rewriter.replaceOp(packOp, newPadOp.getResult());

    return success();
  }

private:
  ControlPropagationFn controlFn;
};
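/// Projects dims positions through a reassociation map onto the inner-most
/// non-unit dimension of each group. E.g. (illustrative), with reassociation
/// [[0, 1]] and source shape [1, 16], position 0 projects to 1, since the
/// outer collapsed dimension is a unit dim.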
static SmallVector<int64_t>
projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
                                 ArrayRef<ReassociationIndices> reassocIndices,
                                 ArrayRef<int64_t> targetShape) {
  SmallVector<int64_t> projectedDimsPos;
  for (auto pos : dimsPos) {
    // In the case all dims are unit, this returns the inner-most one.
    int64_t projectedPos = reassocIndices[pos].back();
    for (auto i : llvm::reverse(reassocIndices[pos])) {
      int64_t dim = targetShape[i];
      if (dim > 1 || ShapedType::isDynamic(dim)) {
        projectedPos = i;
        break;
      }
    }
    projectedDimsPos.push_back(projectedPos);
  }
  return projectedDimsPos;
}
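/// Returns true when every dimension at the given positions is static and
/// divisible by the corresponding tile size.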
static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
                                       ArrayRef<int64_t> shape,
                                       ArrayRef<int64_t> tileSizes) {
  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
    int64_t dim = shape[pos];
    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
      return false;
  }
  return true;
}
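/// Permutes the reassociation indices, renumbers them to be consecutive, and
/// returns the next available dimension position.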
static int64_t applyPermutationAndReindexReassoc(
    SmallVector<ReassociationIndices> &reassocIndices,
    ArrayRef<int64_t> permutation) {
  if (!permutation.empty())
    applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
  int64_t nextPos = 0;
  for (ReassociationIndices &indices : reassocIndices) {
    for (auto &index : indices) {
      index = nextPos;
      nextPos += 1;
    }
  }
  return nextPos;
}
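/// Bubbles up linalg.pack through tensor.collapse_shape when the packed dims
/// can be projected to positions before collapsing and the projected dims are
/// divisible by the inner tiles. Illustrative sketch:
///
///   %collapsed = tensor.collapse_shape %in [[0, 1], 2]
///       : tensor<?x16x4xf32> into tensor<?x4xf32>
///   %pack = linalg.pack %collapsed inner_dims_pos = [0, 1]
///       inner_tiles = [8, 1] into %empty
///       : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
///
/// can become
///
///   %pack = linalg.pack %in inner_dims_pos = [1, 2] inner_tiles = [8, 1]
///       into %empty : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
///   %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
///       : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>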
static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                   linalg::PackOp packOp,
                                   PatternRewriter &rewriter) {
  if (!packOp.hasPureTensorSemantics())
    return failure();

  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();

  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      collapseOp.getReassociationIndices();
  // Project pack.inner_dims_pos to positions before shape collapsing.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
                                  innerTileSizes))
    return failure();

  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm)
    llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);

  auto emptyOp = linalg::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newPackOp = linalg::PackOp::create(
      rewriter, packOp.getLoc(), collapseOp.getSrc(), emptyOp,
      projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
      newOuterDimsPerm);

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First apply the permutation on the reassociations of the outer dims, then
  // append direct mappings for the inner tile dims.
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  auto newCollapseOp = tensor::CollapseShapeOp::create(
      rewriter, collapseOp.getLoc(), packOp.getResult().getType(),
      newPackOp.getResult(), newReassocIndices);
  rewriter.replaceOp(packOp, newCollapseOp);

  return success();
}
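/// Projects dims positions to the positions of their reassociation groups,
/// e.g. (illustrative) position 2 with reassociation [[0], [1, 2]] projects
/// to group 1.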
static SmallVector<int64_t>
projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
                             ArrayRef<ReassociationIndices> reassocIndices) {
  SmallVector<int64_t> projectedPos;
  // Map each dimension to the position of the reassociation group that
  // contains it.
  for (auto pos : dimsPos) {
    for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
      if (llvm::is_contained(indices, pos)) {
        projectedPos.push_back(idx);
        break;
      }
    }
  }
  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");

  return projectedPos;
}
/// Bubbles up linalg.pack through tensor.expand_shape when only the
/// inner-most expanded dimension of each reassociation group is packed.
/// Illustrative sketch:
///
///   %expand = tensor.expand_shape %in [[0], [1, 2]]
///       : tensor<?x64xf32> into tensor<?x4x16xf32>
///   %pack = linalg.pack %expand inner_dims_pos = [2] inner_tiles = [8]
///       into %empty : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
///
/// can become
///
///   %pack = linalg.pack %in inner_dims_pos = [1] inner_tiles = [8]
///       into %empty : tensor<?x64xf32> -> tensor<?x8x8xf32>
///   %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
///       : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
static LogicalResult
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
                                 linalg::PackOp packOp,
                                 PatternRewriter &rewriter) {
  if (!packOp.hasPureTensorSemantics())
    return failure();

  // Outer dimensions permutation is not supported currently.
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
    return rewriter.notifyMatchFailure(packOp,
                                       "non-identity outer dims perm NYI");
  }

  // Validate the dimension relations between shape expansion and packing.
  SmallVector<ReassociationIndices, 4> reassoc =
      expandOp.getReassociationIndices();
  ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
  llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(),
                                       packInnerDims.end());

  for (auto [idx, indices] : llvm::enumerate(reassoc)) {
    // For each expand_shape reassociation, figure out which dimensions get
    // packed, if any.
    llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end());
    llvm::SetVector<int64_t> packedDims =
        llvm::set_intersection(packDimsPos, expandDimPos);

    // The expanded dimension is not packed, so it does not affect moving the
    // pack before the shape expansion - simply continue.
    if (packedDims.empty())
      continue;
    // Shape expansion cannot be propagated when multiple expanded dimensions
    // are packed - the reordering would change the element layout.
    if (packedDims.size() != 1)
      return rewriter.notifyMatchFailure(
          packOp, "only one of the expanded dimensions can be packed");
    // Only the inner-most expanded dimension can be packed; otherwise the
    // element order would change after reordering.
    if (packedDims.front() != indices.back())
      return rewriter.notifyMatchFailure(
          packOp, "can only pack the inner-most expanded dimension");
  }

  // Project pack.inner_dims_pos to positions before shape expansion.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectDimsPosIntoReassocPos(packInnerDims, reassoc);

  // Project the shape expansion onto the new packed type.
  RankedTensorType newPackType = linalg::PackOp::inferPackedTensorType(
      expandOp.getSrcType(), packOp.getStaticInnerTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  auto reassocExpand =
      getReassociationIndicesForReshape(newPackType, packOp.getDestType());
  if (!reassocExpand)
    return rewriter.notifyMatchFailure(
        packOp, "could not reassociate dims after bubbling up");

  Value destTensor = linalg::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  PackOp packedVal = linalg::PackOp::create(
      rewriter, packOp.getLoc(), expandOp.getSrc(), destTensor,
      projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
      /*outerDimsPerm=*/SmallVector<int64_t>{});

  Value newExpandOp = tensor::ExpandShapeOp::create(
      rewriter, packOp.getLoc(), packOp.getDestType(), packedVal.getResult(),
      *reassocExpand);
  rewriter.replaceOp(packOp, newExpandOp);

  return success();
}
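/// Dispatches pack bubbling to the collapse_shape/expand_shape handlers when
/// the pack is the reshape's only user and the inner tiles are static.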
class BubbleUpPackOpThroughReshapeOp final
    : public OpRewritePattern<linalg::PackOp> {
public:
  BubbleUpPackOpThroughReshapeOp(MLIRContext *context,
                                 ControlPropagationFn fun)
      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    if (!packOp.hasPureTensorSemantics())
      return failure();

    Operation *srcOp = packOp.getSource().getDefiningOp();
    // Currently only supports the case where the pack op is the reshape's
    // single user.
    if (!srcOp || !(srcOp->getNumResults() == 1) ||
        !srcOp->getResult(0).hasOneUse())
      return failure();
    // Currently only supports static inner tile sizes.
    if (llvm::any_of(packOp.getStaticTiles(), ShapedType::isDynamic))
      return failure();

    // User-controlled propagation function.
    if (!controlFn(&packOp.getSourceMutable()))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(srcOp)
        .Case([&](tensor::CollapseShapeOp op) {
          return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
        })
        .Case([&](tensor::ExpandShapeOp op) {
          return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
        })
        .Default(failure());
  }

private:
  ControlPropagationFn controlFn;
};
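/// Pushes down linalg.unpack through tensor.expand_shape when the unpacked
/// dims can be projected onto the expanded shape. Illustrative sketch:
///
///   %unpack = linalg.unpack %in inner_dims_pos = [0] inner_tiles = [8]
///       into %empty : tensor<4x8x8xf32> -> tensor<32x8xf32>
///   %expand = tensor.expand_shape %unpack [[0, 1], [2]]
///       : tensor<32x8xf32> into tensor<2x16x8xf32>
///
/// can become
///
///   %expand = tensor.expand_shape %in [[0, 1], [2], [3]]
///       : tensor<4x8x8xf32> into tensor<2x2x8x8xf32>
///   %unpack = linalg.unpack %expand inner_dims_pos = [1] inner_tiles = [8]
///       into %empty : tensor<2x2x8x8xf32> -> tensor<2x16x8xf32>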
static LogicalResult pushDownUnPackOpThroughExpandShape(
    linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
    PatternRewriter &rewriter, ControlPropagationFn controlFn) {
  if (!unPackOp.hasPureTensorSemantics())
    return failure();

  // User-controlled propagation function.
  if (!controlFn(&expandOp.getSrcMutable()))
    return failure();

  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();

  auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
  if (!expandTy)
    return failure();
  ArrayRef<int64_t> dstShape = expandTy.getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      expandOp.getReassociationIndices();
  // Project unpack.inner_dims_pos to positions after shape expansion.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
                                  innerTileSizes))
    return failure();

  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm)
    llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First apply the permutation on the reassociations of the outer dims, then
  // append direct mappings for the inner tile dims.
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  RankedTensorType newExpandType = linalg::PackOp::inferPackedTensorType(
      expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
  auto newExpandOp =
      tensor::ExpandShapeOp::create(rewriter, expandOp.getLoc(), newExpandType,
                                    unPackOp.getSource(), newReassocIndices);

  auto emptyOp = linalg::UnPackOp::createDestinationTensor(
      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newUnPackOp = linalg::UnPackOp::create(
      rewriter, unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
      projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
  rewriter.replaceOp(expandOp, newUnPackOp);

  return success();
}
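/// Dispatches unpack push-down to the expand_shape handler when the unpack
/// result has a single user and the inner tiles are static.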
class PushDownUnPackOpThroughReshapeOp final
    : public OpRewritePattern<linalg::UnPackOp> {
public:
  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<linalg::UnPackOp>(context), controlFn(std::move(fun)) {
  }

  LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    if (!unPackOp.hasPureTensorSemantics())
      return failure();

    Value result = unPackOp.getResult();
    // Currently only supports the case where the unpack op has one user.
    if (!result.hasOneUse()) {
      return failure();
    }
    // Currently only supports static inner tile sizes.
    if (llvm::any_of(unPackOp.getStaticTiles(), ShapedType::isDynamic))
      return failure();

    Operation *consumerOp = *result.user_begin();
    return TypeSwitch<Operation *, LogicalResult>(consumerOp)
        .Case([&](tensor::ExpandShapeOp op) {
          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
                                                    controlFn);
        })
        .Default(failure());
  }

private:
  ControlPropagationFn controlFn;
};
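/// Returns the single operand of `genericOp` that is produced by a
/// linalg.unpack, failing when there is none or more than one.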
static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
  OpOperand *unPackedOperand = nullptr;
  for (OpOperand &operand : genericOp->getOpOperands()) {
    auto unPackOp = operand.get().getDefiningOp<linalg::UnPackOp>();
    if (!unPackOp)
      continue;
    if (unPackedOperand)
      return failure();
    unPackedOperand = &operand;
  }
  if (!unPackedOperand)
    return failure();
  return unPackedOperand;
}
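/// Pushes down linalg.unpack through an elementwise linalg.generic: the
/// generic is rewritten to run on the packed domain and a new unpack is
/// emitted on its result. Illustrative sketch (hypothetical shapes):
///
///   %0 = linalg.unpack %arg0 inner_dims_pos = [0] inner_tiles = [8]
///        into %empty : tensor<?x?x8xf32> -> tensor<?x?xf32>
///   %1 = linalg.generic ins(%0) outs(%init) ...
///
/// becomes
///
///   %0 = linalg.generic ins(%arg0) outs(%packed_init) ...
///   %1 = linalg.unpack %0 inner_dims_pos = [0] inner_tiles = [8]
///        into %init : tensor<?x?x8xf32> -> tensor<?x?xf32>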
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                 const ControlPropagationFn &controlFn,
                                 bool poisonPaddingOk) {
  if (genericOp.getNumResults() != 1)
    return failure();

  if (hasGatherSemantics(genericOp))
    return failure();

  // Collect the unpacked operand, if present.
  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
  if (failed(maybeUnPackedOperand))
    return failure();
  OpOperand *unPackedOperand = *(maybeUnPackedOperand);

  // Extract packing information.
  linalg::UnPackOp producerUnPackOp =
      unPackedOperand->get().getDefiningOp<linalg::UnPackOp>();
  assert(producerUnPackOp && "expect a valid UnPackOp");

  if (!controlFn(unPackedOperand))
    return failure();

  auto packInfo =
      getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
  if (failed(packInfo))
    return failure();
  DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
  bool requiresPadding =
      getPackedOperandDetails(rewriter, *packInfo, genericOp,
                              genericOp.getDpsInitOperand(0), packedOperandMap);
  if (requiresPadding && !poisonPaddingOk)
    return failure();

  // Rebuild the indexing map for the init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(),
                                     genericOp.getDpsInitOperand(0),
                                     packedOperandMap);
  auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();

  // If the dps init operand of the generic is a tensor.empty (or unused), do
  // not pack it and forward the pack's destination instead.
  Value dest = packedOutOperand;
  auto initTensor =
      genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
  if (initTensor || isGenericOutsNotUsed(genericOp)) {
    if (destPack)
      dest = destPack.getDest();
  }

  // Pack the generic; pack(unpack) pairs are foldable in this direction.
  auto maybeGenericOp =
      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
                    /*isFoldableUnpackPack=*/true, poisonPaddingOk);
  if (failed(maybeGenericOp))
    return failure();
  GenericOp newGenericOp = *maybeGenericOp;
  Value newResult =
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

  // If the output is not packed, there is nothing to unpack.
  if (!destPack)
    return std::make_tuple(newGenericOp, newResult);
  auto mixedTiles = destPack.getMixedTiles();
  auto innerDimsPos = destPack.getInnerDimsPos();
  auto outerDimsPerm = destPack.getOuterDimsPerm();

  // Insert an unpack op right after the packed generic.
  Value unPackOpRes =
      linalg::UnPackOp::create(rewriter, genericOp.getLoc(), newResult,
                               destPack.getSource(), innerDimsPos, mixedTiles,
                               outerDimsPerm)
          .getResult();

  return std::make_tuple(newGenericOp, unPackOpRes);
}
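/// Wraps pushDownUnPackOpThroughGenericOp in a rewrite pattern.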
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
public:
  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
                                   ControlPropagationFn fun,
                                   bool poisonPaddingOk)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)),
        poisonPaddingOk(poisonPaddingOk) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    auto genericAndRepl = pushDownUnPackOpThroughGenericOp(
        rewriter, genericOp, controlFn, poisonPaddingOk);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
  bool poisonPaddingOk;
};
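/// Swaps a tensor.pad that consumes a linalg.unpack so the pad happens on the
/// packed domain. Illustrative sketch:
///
///   %0 = linalg.unpack %in inner_dims_pos = [1] inner_tiles = [8] ...
///   %1 = tensor.pad %0 low[0, 0] high[%h, 0]
///
/// becomes
///
///   %0 = tensor.pad %in low[0, 0, 0] high[%h, 0, 0]
///   %1 = linalg.unpack %0 inner_dims_pos = [1] inner_tiles = [8] ...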
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    linalg::UnPackOp unpackOp =
        padOp.getSource().getDefiningOp<linalg::UnPackOp>();
    if (!unpackOp)
      return failure();

    if (!unpackOp.hasPureTensorSemantics())
      return failure();

    if (!controlFn(&padOp.getSourceMutable()))
      return failure();

    Location loc = padOp.getLoc();
    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    // If there is an outer_dims_perm, adjust the padded dimensions.
    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // Add zero padding for the point loops.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = tensor::PadOp::create(rewriter, loc, /*result=*/Type(),
                                          unpackOp.getSource(), lowPad, highPad,
                                          paddingVal, padOp.getNofold());

    // Inject the unpack right after the packed pad.
    Value outputUnPack =
        tensor::EmptyOp::create(rewriter, loc, padOp.getResultType().getShape(),
                                padOp.getResultType().getElementType());

    Value replacement = linalg::UnPackOp::create(
        rewriter, loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
        unpackOp.getMixedTiles(), outerDimsPerm);
    rewriter.replaceOp(padOp, replacement);
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
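/// Offset, slice size, and full size of one sliced iteration-domain
/// dimension.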
struct SliceDimInfo {
  OpFoldResult offset;
  OpFoldResult sliceSize;
  OpFoldResult outputSize;
};
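/// Returns the input operands that are produced by a tensor.extract_slice,
/// failing when there is none.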
static FailureOr<SmallVector<OpOperand *>>
getSliceOperands(GenericOp genericOp) {
  SmallVector<OpOperand *> sliceOperands;
  for (auto operand : genericOp.getDpsInputOperands()) {
    auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
    if (!extractOp)
      continue;
    sliceOperands.push_back(operand);
  }
  if (sliceOperands.empty()) {
    return failure();
  }
  return sliceOperands;
}
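/// Maps each sliced iteration-domain dimension to its offset, slice size, and
/// full size. Fails when a sliced dimension is consumed by a composite affine
/// expression in any operand's indexing map.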
static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>>
getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) {
  tensor::ExtractSliceOp producerSliceOp =
      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  assert(producerSliceOp && "expect a valid ExtractSliceOp");

  llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap;
  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
  SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult(
      genericOp.getContext(), producerSliceOp.getSourceType().getShape());

  for (auto [idx, expr] : llvm::enumerate(
           genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
    // Skip dimensions that are not actually sliced.
    if (isConstantIntValue(offsets[idx], 0) &&
        isEqualConstantIntOrValue(sizes[idx], shape[idx]))
      continue;
    // Only plain dim expressions can be handled.
    if (!isa<AffineDimExpr>(expr)) {
      return failure();
    }
    SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
    int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
    partialSliceDimMap[dimPos] = sliceDimInfo;
  }
  // Reject the slice if any other operand consumes a sliced dimension through
  // a composite affine expression.
  for (OpOperand &operand : genericOp->getOpOperands()) {
    if (operand == *sliceOperand) {
      continue;
    }
    AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
    WalkResult status = WalkResult::advance();
    for (AffineExpr expr : IndexingMap.getResults()) {
      if (isa<AffineDimExpr>(expr)) {
        continue;
      }
      status = expr.walk([&](AffineExpr e) {
        if (auto dimExpr = dyn_cast<AffineDimExpr>(e)) {
          if (partialSliceDimMap.contains(dimExpr.getPosition())) {
            return WalkResult::interrupt();
          }
        }
        return WalkResult::advance();
      });
      if (status.wasInterrupted())
        break;
    }
    if (status.wasInterrupted()) {
      return failure();
    }
  }
  return partialSliceDimMap;
}
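/// Pushes a tensor.extract_slice down through a linalg.generic: the generic
/// runs on the full domain and the slice is re-applied to its result. Sliced
/// operands are padded with ub.poison so the iteration domain stays
/// rectangular. Illustrative sketch:
///
///   %0 = tensor.extract_slice %in[%off] [%sz] [1]
///   %1 = linalg.generic ins(%0) outs(%init) ...
///
/// becomes
///
///   %0 = tensor.pad %init low[%off] high[...]   // poison padding
///   %1 = linalg.generic ins(%in) outs(%0) ...
///   %2 = tensor.extract_slice %1[%off] [%sz] [1]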
static FailureOr<std::tuple<GenericOp, Value>>
pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
                                       GenericOp genericOp,
                                       const ControlPropagationFn &controlFn) {
  if (genericOp.getNumResults() != 1)
    return rewriter.notifyMatchFailure(
        genericOp, "propagation through multi-result generic is unsupported.");
  if (hasGatherSemantics(genericOp))
    return rewriter.notifyMatchFailure(
        genericOp,
        "propagation through generic with gather semantics is unsupported.");

  // Collect the slice operands, if present.
  auto maybeSliceOperands = getSliceOperands(genericOp);
  if (failed(maybeSliceOperands))
    return failure();
  SmallVector<OpOperand *> sliceOperands = *maybeSliceOperands;

  // Pick the first slice operand accepted by the control function.
  OpOperand *sliceOperand = nullptr;
  bool foundValidOperand = false;
  for (auto currSliceOperand : sliceOperands) {
    if (controlFn(currSliceOperand)) {
      sliceOperand = currSliceOperand;
      foundValidOperand = true;
      break;
    }
  }
  if (!foundValidOperand) {
    return failure();
  }
  unsigned OperandIndex = sliceOperand->getOperandNumber();
  tensor::ExtractSliceOp producerSliceOp =
      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  assert(producerSliceOp && "expect a valid ExtractSliceOp");

  if (producerSliceOp.getSource().getType().getRank() !=
      producerSliceOp.getResult().getType().getRank()) {
    return rewriter.notifyMatchFailure(
        genericOp,
        "propagation of rank-reducing extract slice is unsupported.");
  }

  SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
  if (!areAllConstantIntValue(strides, 1))
    return rewriter.notifyMatchFailure(
        genericOp, "propagation of strided extract slice is unsupported.");

  // Collect the partial slice information for the iteration domain.
  auto maybePartialSliceDimMap =
      getPartialSliceDimInfo(genericOp, sliceOperand);

  if (failed(maybePartialSliceDimMap)) {
    return failure();
  }
  auto partialSliceDimMap = *maybePartialSliceDimMap;
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();
  bool hasPartialReductionDimSlice =
      llvm::any_of(partialSliceDimMap, [&](const auto &slice) {
        int64_t sliceDim = slice.first;
        return iterators[sliceDim] == utils::IteratorType::reduction;
      });

  Location loc = genericOp->getLoc();
  MLIRContext *ctx = genericOp.getContext();
  AffineExpr dim0, dim1;
  bindDims(ctx, dim0, dim1);
  auto sub = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    auto subMap = AffineMap::get(2, 0, {dim0 - dim1}, ctx);
    return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
                                                 {v1, v2});
  };
  // Pad each input operand so the generic covers the full source domain.
  SmallVector<Value> paddedInputs;
  for (auto [idx, operand] :
       llvm::enumerate(genericOp.getDpsInputOperands())) {
    if (idx == OperandIndex && !hasPartialReductionDimSlice) {
      paddedInputs.push_back(producerSliceOp.getSource());
      continue;
    }
    AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
    if (IndexingMap.getNumResults() == 0) {
      paddedInputs.push_back(operand->get());
      continue;
    }
    SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
                                             rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
                                              rewriter.getIndexAttr(0));
    for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
      if (!isa<AffineDimExpr>(expr)) {
        continue;
      }
      auto dimExpr = cast<AffineDimExpr>(expr);
      if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
        continue;
      }
      SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
      operandLowPads[idx] = sliceDimInfo.offset;
      operandHighPads[idx] =
          sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
              sliceDimInfo.sliceSize);
    }
    auto paddingValue = ub::PoisonOp::create(
        rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
    auto paddedOperand = tensor::PadOp::create(
        rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
        paddingValue, /*nofold=*/false);
    paddedInputs.push_back(paddedOperand);
  }
  AffineMap outputIndexingMap =
      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));

  auto outputShapeType =
      llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
  SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
      outputShapeType.getShape(),
      [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
  SmallVector<OpFoldResult> outputLowPads(OutputShape.size(),
                                          rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> outputHighPads(OutputShape.size(),
                                           rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> newSizes = OutputShape;
  SmallVector<OpFoldResult> newStrides(OutputShape.size(),
                                       rewriter.getIndexAttr(1));
  for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
    if (!isa<AffineDimExpr>(expr)) {
      continue;
    }
    auto dimExpr = cast<AffineDimExpr>(expr);
    if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
      continue;
    }
    SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
    outputLowPads[idx] = sliceDimInfo.offset;
    outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
                              sliceDimInfo.sliceSize);
    OutputShape[idx] = sliceDimInfo.outputSize;
    newSizes[idx] = sliceDimInfo.sliceSize;
  }
  Value newPadOutput;
  auto outputElType =
      getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
  if (isGenericOutsNotUsed(genericOp)) {
    newPadOutput =
        tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
  } else {
    auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
    newPadOutput = tensor::PadOp::create(
        rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
        outputHighPads, paddingValue, /*nofold=*/false);
  }

  auto newGenericOp = linalg::GenericOp::create(
      rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
      genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
      /*bodyBuild=*/nullptr, getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());

  auto extractOp = tensor::ExtractSliceOp::create(
      rewriter, loc,
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
      outputLowPads, newSizes, newStrides);
  Value extractRes = extractOp.getResult();

  return std::make_tuple(newGenericOp, extractRes);
}
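/// Wraps pushDownExtractSliceOpThroughGenericOp in a rewrite pattern.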
class PushDownExtractSliceOpThroughGenericOp final
    : public OpRewritePattern<GenericOp> {
public:
  PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
                                         ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    auto genericAndRepl =
        pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
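} // namespace

/// Registers the bubble-up/push-down propagation patterns. The generic-op
/// patterns additionally take the poison-padding flag.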
void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation,
    bool PoisonPaddingOk) {
  patterns.insert<BubbleUpPackThroughPadOp, BubbleUpPackOpThroughReshapeOp,
                  PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
      patterns.getContext(), controlPackUnPackPropagation);
  patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
                  PushDownUnPackOpThroughGenericOp>(
      patterns.getContext(), controlPackUnPackPropagation, PoisonPaddingOk);
}

void mlir::linalg::populateExtractSliceSinkingPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
      patterns.getContext(), controlPackUnPackPropagation);
}