#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-data-layout-propagation"

namespace {
/// Returns true if the generic op body reads data through gather-like ops
/// (tensor.extract or linalg.index), which block layout propagation.
static bool hasGatherSemantics(linalg::GenericOp genericOp) {
  for (Operation &op : genericOp.getBody()->getOperations())
    if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
      return true;
  return false;
}
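/// Bookkeeping for how a pack/unpack op relates to the iteration domain of a
/// generic op. For a generic with N loops and a pack with K inner tiles, the
/// packed generic has N + K loops; tileToPointMapping records, for each tiled
/// domain dimension, the position of its new inner (point) loop.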
struct PackInfo {
  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); }
  // Tiled dimension positions on the iteration domain, in pack-op order.
  SmallVector<int64_t> tiledDimsPos;
  // The tile size of each tiled iteration-domain dimension.
  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
  // Maps a tiled iteration-domain dimension to its new inner (point) loop.
  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
  // The permutation of outer dims, expressed on the iteration domain.
  SmallVector<int64_t> outerDimsOnDomainPerm;
};
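/// Derives PackInfo for `opOperand` of `genericOp` from the given pack or
/// unpack op. Fails if a tiled dimension is accessed through a compound
/// affine expression or is not a parallel loop.
///
/// Illustrative example (shapes chosen for exposition): for a generic with
/// indexing map (d0, d1) -> (d0, d1) and a pack with inner_dims_pos = [1],
/// inner_tiles = [8], the tiled domain dim d1 is mapped to the new point
/// loop d2, i.e. tileToPointMapping = {1 -> 2}.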
template <typename OpTy>
static FailureOr<PackInfo>
getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                          OpTy packOrUnPackOp) {
  static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
                "applies to only pack or unpack operations");
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });

  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();

  PackInfo packInfo;
  int64_t origNumDims = indexingMap.getNumDims();
  SmallVector<AffineExpr> exprs(indexingMap.getResults());
  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
  for (auto [index, innerDimPos, tileSize] :
       llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
    auto expr = exprs[innerDimPos];
    if (!isa<AffineDimExpr>(expr))
      return failure();
    int64_t domainDimPos =
        cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
    if (!isParallelIterator(iterators[domainDimPos]))
      return failure();
    packInfo.tiledDimsPos.push_back(domainDimPos);
    packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
    packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
    LLVM_DEBUG({
      llvm::dbgs() << "map innerDimPos=" << innerDimPos
                   << " to iteration dimension (d" << domainDimPos << ", d"
                   << packInfo.tileToPointMapping[domainDimPos]
                   << "), which has size=("
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
    });
  }
  // Bail out if a tiled dimension appears in any indexing map inside a
  // compound (non-dim) affine expression.
  auto areAllAffineDimExpr = [&](int dim) {
    for (AffineMap map : indexingMaps) {
      if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
            return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
          })) {
        return false;
      }
    }
    return true;
  };
  for (int64_t i : packInfo.tiledDimsPos)
    if (!areAllAffineDimExpr(i))
      return failure();
  // Handle the outer dims permutation: map each permuted data dimension back
  // to its iteration-domain dimension when possible.
  SmallVector<int64_t> permutedOuterDims;
  for (auto [index, dim] :
       llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
    if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
      permutedOuterDims.push_back(dimExpr.getPosition());
      continue;
    }
    // Non-dim expressions are only supported when they stay in place.
    if (static_cast<int64_t>(index) != dim)
      return failure();
  }
  if (!permutedOuterDims.empty()) {
    int64_t outerDimIndex = 0;
    llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
                                               permutedOuterDims.end());
    for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
      packInfo.outerDimsOnDomainPerm.push_back(
          permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
                                         : i);
    LLVM_DEBUG({
      llvm::dbgs() << "map outerDimsPerm to ";
      for (auto dim : packInfo.outerDimsOnDomainPerm)
        llvm::dbgs() << dim << " ";
      llvm::dbgs() << "\n";
    });
  }
  return packInfo;
}
static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
                                             ArrayRef<AffineExpr> exprs) {
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
  if (exprs.size() == 1)
    return {};
  SmallVector<int64_t> outerDimsPerm;
  DenseMap<int64_t, int64_t> currentPositionTileLoops;
  for (auto [pos, expr] : llvm::enumerate(exprs)) {
    // Non-dim expressions (e.g. constants for broadcasts) keep their position.
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      currentPositionTileLoops[dimExpr.getPosition()] = pos;
    else
      currentPositionTileLoops[pos] = pos;
  }
  for (int64_t loopIdx : perm) {
    if (currentPositionTileLoops.count(loopIdx))
      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
  }
  return outerDimsPerm;
}
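/// Returns the packed operand and the updated indexing map for `opOperand`,
/// creating a linalg.pack whenever the operand participates in a tiled or
/// permuted dimension.
///
/// Illustrative example (shapes chosen for exposition): packing an operand of
/// a 2D elementwise generic with inner_dims_pos = [0, 1] and
/// inner_tiles = [32, 16] rewrites the operand map
///   (d0, d1) -> (d0, d1)
/// into
///   (d0, d1, d2, d3) -> (d0, d1, d2, d3)
/// over a tensor<?x?x32x16> packed view, where d2 and d3 are the new point
/// loops for the 32x16 tiles.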
static std::tuple<Value, AffineMap>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
                               GenericOp genericOp, OpOperand *opOperand) {
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());

  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
  if (genericOp.isScalar(opOperand) || exprs.empty())
    return std::make_tuple(opOperand->get(),
                           AffineMap::get(numLoops, 0, exprs, b.getContext()));

  // Step 1. Collect which operand dims correspond to tiled domain dims, and
  // append the point-loop dims to the indexing map.
  for (auto [index, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dimPos = dimExpr.getPosition();
      domainDimToOperandDim[dimPos] = index;
    }
  }
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTileSizes;
  for (auto dimPos : packInfo.tiledDimsPos) {
    if (!domainDimToOperandDim.count(dimPos))
      continue;
    int64_t index = domainDimToOperandDim[dimPos];
    innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
    innerDimsPos.push_back(index);
    exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
  }

  // Step 2. Handle outer dim permutations.
  SmallVector<int64_t> outerDimsPerm;
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);

    // Step 2.1: Fold the transposition into the indexing map.
    SmallVector<int64_t> inversedOuterPerm =
        invertPermutationVector(packInfo.outerDimsOnDomainPerm);
    for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
      if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
        int64_t dimPos = dimExpr.getPosition();
        exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
        continue;
      }
      assert(isa<AffineConstantExpr>(exprs[i]) &&
             "Attempted to permute non-constant and non-affine dim expression");
    }
    // Step 2.2: Permute the results to match the packed data layout.
    if (!outerDimsPerm.empty()) {
      SmallVector<AffineExpr> auxVec = exprs;
      for (const auto &en : enumerate(outerDimsPerm))
        auxVec[en.index()] = exprs[en.value()];
      exprs = auxVec;
    }
  }
  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());

  // The operand does not have dimensions that relate to the pack op.
  if (innerDimsPos.empty() && outerDimsPerm.empty())
    return std::make_tuple(opOperand->get(), indexingMap);

  auto empty = linalg::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto packedOperand = linalg::PackOp::create(
      b, loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
      /*padding=*/std::nullopt, outerDimsPerm);
  return std::make_tuple(packedOperand, indexingMap);
}
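/// Rebuilds `genericOp` with every operand replaced by its packed view and
/// the iterator list extended with one parallel loop per inner tile. When
/// `isFoldableUnpackPack` is true, pack(unpack(x)) pairs with identical
/// tiling metadata are elided and `x` is used directly.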
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                               Value dest, AffineMap packedOutIndexingMap,
                               const PackInfo &packInfo,
                               bool isFoldableUnpackPack) {
  Location loc = genericOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<Value> inputOperandsFromUnpackedSource;
  SmallVector<AffineMap> indexingMaps;
  auto hasEquivalentTiles = [](PackOp packOp, UnPackOp unPackOp) {
    return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() &&
           packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
           llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
  };
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand);
    auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
    auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
    if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
      inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
    } else {
      inputOperandsFromUnpackedSource.push_back(packedOperand);
    }
    inputOperands.push_back(packedOperand);
    indexingMaps.push_back(packedIndexingMap);
  }

  // If pack(unpack) sequences are foldable, use the unpack sources as inputs
  // and bypass the destination pack when its source is an equivalent unpack.
  if (isFoldableUnpackPack) {
    inputOperands = inputOperandsFromUnpackedSource;
    if (auto destPack = dest.getDefiningOp<linalg::PackOp>()) {
      auto destUnPack = destPack.getSource().getDefiningOp<linalg::UnPackOp>();
      if (destUnPack && hasEquivalentTiles(destPack, destUnPack)) {
        dest = destUnPack.getSource();
      }
    }
  }

  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();
  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);

  indexingMaps.push_back(packedOutIndexingMap);

  auto newGenericOp = linalg::GenericOp::create(
      rewriter, loc, dest.getType(), inputOperands, dest, indexingMaps,
      iterTypes, /*bodyBuild=*/nullptr, getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());
  return newGenericOp;
}
/// Returns true if no init value of the generic is read inside its body.
static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
  return llvm::all_of(genericOp.getDpsInitsMutable(), [&](OpOperand &operand) {
    return genericOp.getMatchingBlockArgument(&operand).use_empty();
  });
}
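/// Bubbles up linalg.pack through a single-result element-wise generic so
/// the computation runs on the packed layout.
///
/// Illustrative example (shapes chosen for exposition):
///
///   %0 = linalg.generic ins(%arg0) outs(%init) ...
///   %1 = linalg.pack %0 inner_dims_pos = [1] inner_tiles = [8]
///        into %dest : tensor<?x?xf32> -> tensor<?x?x8xf32>
///
/// becomes
///
///   %p = linalg.pack %arg0 inner_dims_pos = [1] inner_tiles = [8]
///        into %e : tensor<?x?xf32> -> tensor<?x?x8xf32>
///   %1 = linalg.generic ins(%p) outs(%dest) ...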
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
                               const ControlPropagationFn &controlFn) {
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();

  // User-controlled propagation function.
  if (!controlFn(&packOp.getSourceMutable()))
    return failure();

  // TODO: Enable propagation in the presence of gather semantics.
  if (hasGatherSemantics(genericOp))
    return failure();

  // TODO: Relax the single-result restriction.
  if (genericOp.getNumResults() != 1)
    return failure();

  // Bail out on multiple uses, as bubbling up would introduce recomputation.
  if (!genericOp->getResult(0).hasOneUse())
    return failure();

  // TODO: Support padding values; propagating them blindly could introduce
  // undefined behavior (e.g. 0/0 = NaN when the body contains a division).
  if (packOp.getPaddingValue())
    return failure();

  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
  if (failed(packInfo))
    return failure();

  // We want to move the pack, not the generic.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(genericOp);

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();
  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
    packOpDest = tensor::EmptyOp::create(rewriter, genericOp->getLoc(),
                                         emptyOp.getMixedSizes(),
                                         emptyOp.getType().getElementType());
  } else {
    DominanceInfo dom(genericOp);
    if (!dom.properlyDominates(packOpDest, genericOp))
      return failure();
  }

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, opOperand);

  // If the init is a tensor.empty or is unused, forward the pack destination.
  Value dest = packedOutOperand;
  auto initTensor =
      genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
  if (initTensor || isGenericOutsNotUsed(genericOp)) {
    dest = packOpDest;
  }
  // pack(unpack) is not naively foldable here because the unpack may come
  // from an arbitrary domain, so both ops are kept.
  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
                       *packInfo, /*isFoldableUnpackPack=*/false);
}
struct BubbleUpPackOpThroughGenericOpPattern
    : public OpRewritePattern<linalg::PackOp> {
  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
                                        ControlPropagationFn fun)
      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp =
        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
    if (failed(genericOp))
      return failure();
    rewriter.replaceOp(packOp, genericOp->getResults());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
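/// Swaps tensor.pad and linalg.pack so the pad happens on the packed layout:
/// pack(pad(x)) -> pad(pack(x)). Only legal when no padded dimension is
/// tiled, the pad value is a constant, and the pack itself adds no padding.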
class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
public:
  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!padOp)
      return failure();

    // User-controlled propagation function.
    if (!controlFn(&packOp.getSourceMutable()))
      return failure();

    // TODO: Enable propagation when the pack op also pads.
    if (packOp.getPaddingValue())
      return failure();

    // Fail for non-constant padding values; the pad body could depend on the
    // padding indices or properties of the padded tensor.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
      return failure();

    // Bail out if one of the padded dimensions is a tiled one.
    ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Location loc = padOp->getLoc();
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(padOp);

    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
    auto empty = linalg::PackOp::createDestinationTensor(
        rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
        outerDimsPerm);
    auto sourcePack = linalg::PackOp::create(
        rewriter, loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
        /*padding=*/std::nullopt, outerDimsPerm);

    // With outer_dims_perm, the padded dimensions move as well.
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // The tiled dimensions are not padded (checked above), so the new inner
    // (point) dimensions get zero padding.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp =
        tensor::PadOp::create(rewriter, loc, /*result=*/Type(), sourcePack,
                              lowPad, highPad, paddingVal, padOp.getNofold());

    // If the pad has other users, unpack the new pad to replace them.
    if (!padOp->hasOneUse()) {
      auto unpackEmpty = linalg::UnPackOp::createDestinationTensor(
          rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
      Value unpackedPad =
          linalg::UnPackOp::create(rewriter, loc, newPadOp, unpackEmpty,
                                   innerDimsPos, mixedTiles, outerDimsPerm);
      rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
    }

    rewriter.replaceOp(packOp, newPadOp.getResult());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
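/// Projects each dim position through the reassociation onto the inner-most
/// non-unit dim of its group. For example (over illustrative shapes), with
/// dimsPos = [0, 2], reassocIndices = [[0, 1], [2, 3]], and
/// targetShape = [16, 16, 32, 1], the result is [1, 2]: dim 1 is the
/// inner-most non-unit dim of group [0, 1], and dim 2 of group [2, 3].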
static SmallVector<int64_t>
projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
                                 ArrayRef<ReassociationIndices> reassocIndices,
                                 ArrayRef<int64_t> targetShape) {
  SmallVector<int64_t> projectedDimsPos;
  for (auto pos : dimsPos) {
    // If all dims of the group are unit, this returns the inner-most one.
    int64_t projectedPos = reassocIndices[pos].back();
    for (auto i : llvm::reverse(reassocIndices[pos])) {
      int64_t dim = targetShape[i];
      if (dim > 1 || ShapedType::isDynamic(dim)) {
        projectedPos = i;
        break;
      }
    }
    projectedDimsPos.push_back(projectedPos);
  }
  return projectedDimsPos;
}
/// Returns true if all dims in dimsPos are divisible by the corresponding
/// tile sizes.
static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
                                       ArrayRef<int64_t> shape,
                                       ArrayRef<int64_t> tileSizes) {
  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
    int64_t dim = shape[pos];
    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
      return false;
  }
  return true;
}
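/// Permutes the reassociation groups and renumbers their indices in sequence,
/// returning the next free index. For example, reassocIndices = [[0, 1], [2]]
/// with permutation = [1, 0] becomes [[2], [0, 1]] after permuting and
/// [[0], [1, 2]] after reindexing; the returned next position is 3.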
static int64_t applyPermutationAndReindexReassoc(
    SmallVector<ReassociationIndices> &reassocIndices,
    ArrayRef<int64_t> permutation) {
  if (!permutation.empty())
    applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
  int64_t nextPos = 0;
  for (ReassociationIndices &indices : reassocIndices) {
    for (auto &index : indices) {
      index = nextPos;
      nextPos += 1;
    }
  }
  return nextPos;
}
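/// Bubbles up linalg.pack through tensor.collapse_shape:
/// pack(collapse_shape(x)) -> collapse_shape(pack(x)), by projecting the
/// packed dims onto the pre-collapse shape.
///
/// Illustrative example (shapes chosen for exposition):
///
///   %collapsed = tensor.collapse_shape %in [[0, 1], [2]]
///       : tensor<?x16x4xf32> into tensor<?x4xf32>
///   %pack = linalg.pack %collapsed inner_dims_pos = [0, 1]
///       inner_tiles = [8, 1] into %empty
///       : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
///
/// becomes a pack of %in on the projected dims followed by a
/// tensor.collapse_shape on the packed tensor.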
static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                   linalg::PackOp packOp,
                                   PatternRewriter &rewriter) {
  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();

  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      collapseOp.getReassociationIndices();
  // Project the inner tile positions to dims before collapsing. Projecting to
  // the inner-most non-unit dims increases the chance of divisibility by the
  // tile sizes; this is correct because for [..., x, 1], packing on the unit
  // dim is equivalent to packing on x.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
                                  innerTileSizes)) {
    return failure();
  }
  // Expand the outer dims permutation with the associated source dims, since
  // moving a collapsed dim moves all its source dims together.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm)
    llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);

  auto emptyOp = linalg::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newPackOp = linalg::PackOp::create(
      rewriter, packOp.getLoc(), collapseOp.getSrc(), emptyOp,
      projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
      newOuterDimsPerm);

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First permute and reindex the outer-dim reassociations, then add a direct
  // mapping for each inner tile dim.
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  auto newCollapseOp = tensor::CollapseShapeOp::create(
      rewriter, collapseOp.getLoc(), packOp.getType(), newPackOp,
      newReassocIndices);
  rewriter.replaceOp(packOp, newCollapseOp);

  return success();
}
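/// Maps each dim position to the index of the reassociation group that
/// contains it. For example, with dimsPos = [0, 2] and
/// reassocIndices = [[0, 1], [2, 3]], the result is [0, 1].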
static SmallVector<int64_t>
projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
                             ArrayRef<ReassociationIndices> reassocIndices) {
  SmallVector<int64_t> projectedPos;
  for (auto pos : dimsPos) {
    for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
      if (llvm::is_contained(indices, pos)) {
        projectedPos.push_back(idx);
        break;
      }
    }
  }
  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
  return projectedPos;
}
static LogicalResult
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
                                 linalg::PackOp packOp,
                                 PatternRewriter &rewriter) {
  // Outer dimension permutations are not supported currently.
  // TODO: Handle outer_dims_perm.
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
    return rewriter.notifyMatchFailure(packOp,
                                       "non-identity outer dims perm NYI");
  }

  ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
  llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(),
                                       packInnerDims.end());

  SmallVector<ReassociationIndices> reassoc =
      expandOp.getReassociationIndices();
  for (auto [idx, indices] : llvm::enumerate(reassoc)) {
    // For each reassociation group, determine which dims get packed, if any.
    llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end());
    llvm::SetVector<int64_t> packedDims =
        llvm::set_intersection(packDimsPos, expandDimPos);

    // An unpacked group does not affect the reordering.
    if (packedDims.empty())
      continue;
    // Packing multiple dims of one group would change the element layout of
    // the result after reordering.
    if (packedDims.size() != 1)
      return rewriter.notifyMatchFailure(
          packOp, "only one of the expanded dimensions can be packed");
    // Only the inner-most expanded dim may be packed; otherwise the element
    // order would change after reordering.
    if (packedDims.front() != indices.back())
      return rewriter.notifyMatchFailure(
          packOp, "can only pack the inner-most expanded dimension");
  }

  // Project pack.inner_dims_pos to positions before shape expansion.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectDimsPosIntoReassocPos(packInnerDims, reassoc);

  // Project the shape expansion onto the packed shape; outer_dims_perm is
  // restricted to identity, so it can be omitted here.
  RankedTensorType newPackType = linalg::PackOp::inferPackedType(
      expandOp.getSrcType(), packOp.getStaticInnerTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  auto reassocExpand =
      getReassociationIndicesForReshape(newPackType, packOp.getDestType());
  if (!reassocExpand)
    return rewriter.notifyMatchFailure(
        packOp, "could not reassociate dims after bubbling up");

  Value destTensor = linalg::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  Value packedVal = linalg::PackOp::create(
      rewriter, packOp.getLoc(), expandOp.getSrc(), destTensor,
      projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
      /*outerDimsPerm=*/SmallVector<int64_t>{});

  Value newExpandOp = tensor::ExpandShapeOp::create(rewriter, packOp.getLoc(),
                                                    packOp.getDestType(),
                                                    packedVal, *reassocExpand);
  rewriter.replaceOp(packOp, newExpandOp);

  return success();
}
class BubbleUpPackOpThroughReshapeOp final
    : public OpRewritePattern<linalg::PackOp> {
public:
  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    Operation *srcOp = packOp.getSource().getDefiningOp();
    // Currently only supported when the pack op is the single user.
    if (!srcOp || !(srcOp->getNumResults() == 1) ||
        !srcOp->getResult(0).hasOneUse())
      return failure();
    // Currently only supports static inner tile sizes.
    if (llvm::any_of(packOp.getStaticTiles(), ShapedType::isDynamic))
      return failure();

    // User-controlled propagation function.
    if (!controlFn(&packOp.getSourceMutable()))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(srcOp)
        .Case([&](tensor::CollapseShapeOp op) {
          return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
        })
        .Case([&](tensor::ExpandShapeOp op) {
          return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};
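/// Pushes linalg.unpack down through tensor.expand_shape:
/// expand_shape(unpack(x)) -> unpack(expand_shape(x)), by projecting the
/// unpacked dims onto the post-expansion shape.
///
/// Illustrative example (shapes chosen for exposition):
///
///   %unpack = linalg.unpack %in inner_dims_pos = [0] inner_tiles = [8]
///       into %empty : tensor<4x?x8xf32> -> tensor<32x?xf32>
///   %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
///       : tensor<32x?xf32> into tensor<4x8x?xf32>
///
/// becomes an expand_shape of %in followed by an unpack on the projected
/// inner-most non-unit dim.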
static LogicalResult pushDownUnPackOpThroughExpandShape(
    linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
    PatternRewriter &rewriter, ControlPropagationFn controlFn) {
  // User-controlled propagation function.
  if (!controlFn(&expandOp.getSrcMutable()))
    return failure();

  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();

  auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
  if (!expandTy)
    return failure();
  ArrayRef<int64_t> dstShape = expandTy.getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      expandOp.getReassociationIndices();
  // Project the inner tile positions to dims after expanding, choosing the
  // inner-most non-unit dims to increase the chance of divisibility by the
  // tile sizes.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
                                  innerTileSizes)) {
    return failure();
  }
  // Expand the outer dims permutation with the associated expanded dims,
  // since moving a source dim moves all its expanded dims together.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm)
    llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First permute and reindex the outer-dim reassociations, then add a direct
  // mapping for each inner tile dim.
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  RankedTensorType newExpandType = linalg::PackOp::inferPackedType(
      expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
  auto newExpandOp =
      tensor::ExpandShapeOp::create(rewriter, expandOp.getLoc(), newExpandType,
                                    unPackOp.getSource(), newReassocIndices);

  auto emptyOp = linalg::UnPackOp::createDestinationTensor(
      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newUnPackOp = linalg::UnPackOp::create(
      rewriter, unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
      projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
  rewriter.replaceOp(expandOp, newUnPackOp);

  return success();
}
class PushDownUnPackOpThroughReshapeOp final
    : public OpRewritePattern<linalg::UnPackOp> {
public:
  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<linalg::UnPackOp>(context), controlFn(std::move(fun)) {
  }

  LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    Value result = unPackOp.getResult();
    // Currently only supported when the unpack op has a single user.
    if (!result.hasOneUse())
      return failure();
    // Currently only supports static inner tile sizes.
    if (llvm::any_of(unPackOp.getStaticTiles(), ShapedType::isDynamic))
      return failure();

    Operation *consumerOp = *result.user_begin();
    return TypeSwitch<Operation *, LogicalResult>(consumerOp)
        .Case([&](tensor::ExpandShapeOp op) {
          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
                                                    controlFn);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};
/// Returns the operand of the generic op that is produced by a linalg.unpack,
/// failing unless exactly one such operand exists.
static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
  OpOperand *unPackedOperand = nullptr;
  for (OpOperand &operand : genericOp->getOpOperands()) {
    auto unPackOp = operand.get().getDefiningOp<linalg::UnPackOp>();
    if (!unPackOp)
      continue;
    if (unPackedOperand)
      return failure();
    unPackedOperand = &operand;
  }
  if (!unPackedOperand)
    return failure();
  return unPackedOperand;
}
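/// Pushes linalg.unpack down through an element-wise generic so the body
/// computes on the packed layout and a single unpack is emitted afterwards.
///
/// Illustrative example (shapes chosen for exposition):
///
///   %0 = linalg.unpack %arg0 ... : tensor<?x?x8xf32> -> tensor<?x?xf32>
///   %1 = linalg.generic ins(%0) outs(%init) ...
///
/// becomes
///
///   %1 = linalg.generic ins(%arg0) outs(%packed_init) ...
///   %2 = linalg.unpack %1 ... : tensor<?x?x8xf32> -> tensor<?x?xf32>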
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                 ControlPropagationFn controlFn) {
  if (genericOp.getNumResults() != 1)
    return failure();

  if (hasGatherSemantics(genericOp))
    return failure();

  // Collect the unpacked operand, if present.
  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
  if (failed(maybeUnPackedOperand))
    return failure();
  OpOperand *unPackedOperand = *maybeUnPackedOperand;

  // Extract packing information.
  linalg::UnPackOp producerUnPackOp =
      unPackedOperand->get().getDefiningOp<linalg::UnPackOp>();
  assert(producerUnPackOp && "expect a valid UnPackOp");

  if (!controlFn(unPackedOperand))
    return failure();

  auto packInfo =
      getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, genericOp.getDpsInitOperand(0));
  auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();

  // If the init is a tensor.empty or is unused, forward the pack destination.
  Value dest = packedOutOperand;
  auto initTensor =
      genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
  if (initTensor || isGenericOutsNotUsed(genericOp)) {
    if (destPack)
      dest = destPack.getDest();
  }

  // Pack the generic; pack(unpack) pairs are foldable in this direction.
  GenericOp newGenericOp =
      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
                    /*isFoldableUnpackPack=*/true);
  Value newResult =
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

  // If the output is unaffected, no need to unpack.
  if (!destPack)
    return std::make_tuple(newGenericOp, newResult);

  auto mixedTiles = destPack.getMixedTiles();
  auto innerDimsPos = destPack.getInnerDimsPos();
  auto outerDimsPerm = destPack.getOuterDimsPerm();

  // Insert an unpack right after the packed generic.
  Value unPackOpRes =
      linalg::UnPackOp::create(rewriter, genericOp.getLoc(), newResult,
                               destPack.getSource(), innerDimsPos, mixedTiles,
                               outerDimsPerm)
          .getResult();

  return std::make_tuple(newGenericOp, unPackOpRes);
}
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    auto genericAndRepl =
        pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
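/// Swaps tensor.pad and linalg.unpack: pad(unpack(x)) -> unpack(pad(x)),
/// padding the still-packed tensor instead. Only legal when no padded
/// dimension is tiled and the pad value is a constant.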
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    linalg::UnPackOp unpackOp =
        padOp.getSource().getDefiningOp<linalg::UnPackOp>();
    if (!unpackOp)
      return failure();

    // User-controlled propagation function.
    if (!controlFn(&padOp.getSourceMutable()))
      return failure();

    Location loc = padOp.getLoc();
    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    // With outer_dims_perm, the padded dimensions move as well.
    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // Add zero padding for the inner (point) dimensions.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = tensor::PadOp::create(rewriter, loc, /*result=*/Type(),
                                          unpackOp.getSource(), lowPad, highPad,
                                          paddingVal, padOp.getNofold());

    // Inject the unpack right after the packed pad.
    Value outputUnPack =
        tensor::EmptyOp::create(rewriter, loc, padOp.getResultType().getShape(),
                                padOp.getResultType().getElementType());

    Value replacement = linalg::UnPackOp::create(
        rewriter, loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
        unpackOp.getMixedTiles(), outerDimsPerm);
    rewriter.replaceOp(padOp, replacement);
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
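/// Bookkeeping for pushing tensor.extract_slice through a generic op: for one
/// sliced iteration-domain dimension, records the slice offset, the slice
/// size, and the full size of the un-sliced dimension.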
struct SliceDimInfo {
  OpFoldResult offset;
  OpFoldResult sliceSize;
  OpFoldResult outputSize;
};

/// Returns the input operand of the generic op that is produced by a
/// tensor.extract_slice, failing unless exactly one such operand exists.
static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
  OpOperand *sliceOperand = nullptr;
  for (auto operand : genericOp.getDpsInputOperands()) {
    if (!operand->get().getDefiningOp<tensor::ExtractSliceOp>())
      continue;
    if (sliceOperand)
      return failure();
    sliceOperand = operand;
  }
  if (!sliceOperand) {
    return failure();
  }
  return sliceOperand;
}
static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>>
getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) {
  tensor::ExtractSliceOp producerSliceOp =
      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  assert(producerSliceOp && "expect a valid ExtractSliceOp");

  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
  SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult(
      genericOp.getContext(), producerSliceOp.getSourceType().getShape());

  llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap;
  for (auto [idx, expr] : llvm::enumerate(
           genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
    // Skip dimensions that are not actually sliced.
    if (isConstantIntValue(offsets[idx], 0) &&
        isEqualConstantIntOrValue(sizes[idx], shape[idx]))
      continue;
    // Sliced dimensions must be accessed through plain dim expressions.
    if (!isa<AffineDimExpr>(expr)) {
      return failure();
    }
    SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
    int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
    partialSliceDimMap[dimPos] = sliceDimInfo;
  }
  // Reject the propagation if any other operand accesses a sliced dimension
  // through a compound affine expression, as the slice offsets cannot be
  // rewritten through such maps.
  for (OpOperand &operand : genericOp->getOpOperands()) {
    if (&operand == sliceOperand) {
      continue;
    }
    AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
    for (AffineExpr resultExpr : IndexingMap.getResults()) {
      if (isa<AffineDimExpr>(resultExpr)) {
        continue;
      }
      WalkResult status = resultExpr.walk([&](AffineExpr expr) {
        if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
          if (partialSliceDimMap.contains(dimExpr.getPosition())) {
            return WalkResult::interrupt();
          }
        }
        return WalkResult::advance();
      });
      if (status.wasInterrupted()) {
        return failure();
      }
    }
  }
  return partialSliceDimMap;
}
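/// Pushes tensor.extract_slice down through a generic op by padding the
/// operands back to the full iteration domain with ub.poison, running the
/// generic at full size, and slicing the result.
///
/// Illustrative example (shapes chosen for exposition):
///
///   %0 = tensor.extract_slice %arg0[2, 0] [10, 16] [1, 1]
///       : tensor<16x16xf32> to tensor<10x16xf32>
///   %1 = linalg.generic ins(%0) outs(%init) ...
///
/// becomes a generic on the full tensor<16x16xf32> (with poison-padded
/// operands where needed) followed by
///
///   %2 = tensor.extract_slice %1[2, 0] [10, 16] [1, 1]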
static FailureOr<std::tuple<GenericOp, Value>>
pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
                                       GenericOp genericOp,
                                       ControlPropagationFn controlFn) {
  if (genericOp.getNumResults() != 1)
    return rewriter.notifyMatchFailure(
        genericOp, "propagation through multi-result generic is unsupported.");
  if (hasGatherSemantics(genericOp))
    return rewriter.notifyMatchFailure(
        genericOp,
        "propagation through generic with gather semantics is unsupported.");

  // Collect the sliced operand, if present.
  auto maybeSliceOperand = getSliceOperand(genericOp);
  if (failed(maybeSliceOperand))
    return failure();
  OpOperand *sliceOperand = *maybeSliceOperand;
  unsigned OperandIndex = sliceOperand->getOperandNumber();

  if (!controlFn(sliceOperand))
    return failure();

  tensor::ExtractSliceOp producerSliceOp =
      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  assert(producerSliceOp && "expect a valid ExtractSliceOp");

  if (producerSliceOp.getSource().getType().getRank() !=
      producerSliceOp.getResult().getType().getRank()) {
    return rewriter.notifyMatchFailure(
        genericOp,
        "propagation of rank-reducing extract slice is unsupported.");
  }

  SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
  if (!areAllConstantIntValue(strides, 1))
    return rewriter.notifyMatchFailure(
        genericOp, "propagation of strided extract slice is unsupported.");

  // Collect the partial slice information for the iteration domain.
  auto maybePartialSliceDimMap =
      getPartialSliceDimInfo(genericOp, sliceOperand);
  if (failed(maybePartialSliceDimMap)) {
    return failure();
  }
  auto partialSliceDimMap = *maybePartialSliceDimMap;

  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();
  bool hasPartialReductionDimSlice =
      llvm::any_of(partialSliceDimMap, [&](const auto &slice) {
        int64_t sliceDim = slice.first;
        return iterators[sliceDim] == utils::IteratorType::reduction;
      });

  Location loc = genericOp->getLoc();
  auto sub = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);
    auto subMap = AffineMap::get(2, 0, d0 - d1);
    return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
                                                 {v1, v2});
  };

  // Pad each input back to the full iteration domain with a poison value.
  SmallVector<Value> paddedInputs;
  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
    if (idx == OperandIndex && !hasPartialReductionDimSlice) {
      paddedInputs.push_back(producerSliceOp.getSource());
      continue;
    }
    AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
    SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
                                             rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
                                              rewriter.getIndexAttr(0));
    for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
      if (!isa<AffineDimExpr>(expr)) {
        continue;
      }
      auto dimExpr = cast<AffineDimExpr>(expr);
      if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
        continue;
      }
      SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
      operandLowPads[idx] = sliceDimInfo.offset;
      operandHighPads[idx] =
          sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
              sliceDimInfo.sliceSize);
    }
    auto paddingValue = ub::PoisonOp::create(
        rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
    auto paddedOperand = tensor::PadOp::create(
        rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
        paddingValue, /*nofold=*/false);
    paddedInputs.push_back(paddedOperand);
  }
  AffineMap outputIndexingMap =
      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));

  auto outputShapeType =
      llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
  SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
      outputShapeType.getShape(),
      [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
  SmallVector<OpFoldResult> outputLowPads(OutputShape.size(),
                                          rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> outputHighPads(OutputShape.size(),
                                           rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> newSizes = OutputShape;
  SmallVector<OpFoldResult> newStrides(OutputShape.size(),
                                       rewriter.getIndexAttr(1));
  for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
    if (!isa<AffineDimExpr>(expr)) {
      continue;
    }
    auto dimExpr = cast<AffineDimExpr>(expr);
    if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
      continue;
    }
    SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
    outputLowPads[idx] = sliceDimInfo.offset;
    outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
                              sliceDimInfo.sliceSize);
    OutputShape[idx] = sliceDimInfo.outputSize;
    newSizes[idx] = sliceDimInfo.sliceSize;
  }
  Value newPadOutput;
  auto outputElType =
      getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
  if (isGenericOutsNotUsed(genericOp)) {
    newPadOutput =
        tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
  } else {
    auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
    newPadOutput = tensor::PadOp::create(
        rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
        outputHighPads, paddingValue, /*nofold=*/false);
  }

  auto newGenericOp = linalg::GenericOp::create(
      rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
      genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
      /*bodyBuild=*/nullptr, getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());

  auto extractOp = tensor::ExtractSliceOp::create(
      rewriter, loc,
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
      outputLowPads, newSizes, newStrides);
  Value extractRes = extractOp.getResult();

  return std::make_tuple(newGenericOp, extractRes);
}
class PushDownExtractSliceOpThroughGenericOp final
    : public OpRewritePattern<GenericOp> {
public:
  PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
                                         ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    auto genericAndRepl =
        pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
} // namespace

void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns
      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
          patterns.getContext(), controlPackUnPackPropagation);
}

void mlir::linalg::populateExtractSliceSinkingPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
      patterns.getContext(), controlPackUnPackPropagation);
}
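// Minimal usage sketch (assumption: not part of this file). A pass body would
// typically populate these patterns and run a greedy driver over them; the
// control function below accepts every candidate operand.
//
//   RewritePatternSet patterns(&getContext());
//   linalg::populateDataLayoutPropagationPatterns(
//       patterns, [](OpOperand *) { return true; });
//   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();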