#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-data-layout-propagation"
static bool hasGatherSemantics(linalg::GenericOp genericOp) {
  for (Operation &op : genericOp.getBody()->getOperations())
    if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
      return true;
  return false;
}
struct PackInfo {
  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); }
  // Tiled dimensions on the iteration domain, in pack-op order.
  SmallVector<int64_t> tiledDimsPos;
  // Mapping from a domain dimension to its inner tile size.
  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
  // Mapping from a domain dimension to its appended point-loop dimension.
  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
  // The permutation of outer dims, expressed on the iteration domain.
  SmallVector<int64_t> outerDimsOnDomainPerm;
};
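// Illustrative sketch (added; not in the original source): for a generic op
// over domain (d0, d1) whose packed operand uses inner_dims_pos = [1] and
// inner_tiles = [32], one would expect roughly:
//   tiledDimsPos            = {1}
//   domainDimAndTileMapping = {1 -> 32}
//   tileToPointMapping      = {1 -> 2}   // point loop d2 appended
// so getNumTiledLoops() == 1 and the packed generic iterates over
// (d0, d1, d2).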
template <typename OpTy>
static FailureOr<PackInfo>
getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                          OpTy packOrUnPackOp) {
  static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
                "applies to only pack or unpack operations");
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });

  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();

  PackInfo packInfo;
  int64_t origNumDims = indexingMap.getNumDims();
  SmallVector<AffineExpr> exprs(indexingMap.getResults());
  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
  for (auto [index, innerDimPos, tileSize] :
       llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
    auto expr = exprs[innerDimPos];
    if (!isa<AffineDimExpr>(expr))
      return failure();
    int64_t domainDimPos =
        cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
    packInfo.tiledDimsPos.push_back(domainDimPos);
    packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
    packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
    LLVM_DEBUG({
      llvm::dbgs() << "map innerDimPos=" << innerDimPos
                   << " to iteration dimension (d" << domainDimPos << ", d"
                   << packInfo.tileToPointMapping[domainDimPos]
                   << "), which has size=("
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
    });
  }
  // Bail out if a tiled dimension appears in any indexing map as part of a
  // non-trivial affine expression.
  auto areAllAffineDimExpr = [&](int dim) {
    for (AffineMap map : indexingMaps) {
      if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
            return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
          })) {
        return false;
      }
    }
    return true;
  };
  for (int64_t i : packInfo.tiledDimsPos)
    if (!areAllAffineDimExpr(i))
      return failure();
  // Handle outer_dims_perm: map the permutation on the operand's data
  // dimensions back onto the iteration domain.
  SmallVector<int64_t> permutedOuterDims;
  for (auto [index, dim] :
       llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
    if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
      permutedOuterDims.push_back(dimExpr.getPosition());
      continue;
    }
    // Non-affine dim expressions may only stay in place.
    if (static_cast<int64_t>(index) != dim)
      return failure();
  }
  if (!permutedOuterDims.empty()) {
    int64_t outerDimIndex = 0;
    llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
                                               permutedOuterDims.end());
    for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
      packInfo.outerDimsOnDomainPerm.push_back(
          permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
                                         : i);
    LLVM_DEBUG({
      llvm::dbgs() << "map outerDimsPerm to ";
      for (auto dim : packInfo.outerDimsOnDomainPerm)
        llvm::dbgs() << dim << " ";
      llvm::dbgs() << "\n";
    });
  }
  return packInfo;
}
static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
                                             ArrayRef<AffineExpr> exprs) {
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
  if (exprs.size() == 1)
    return {};
  SmallVector<int64_t> outerDimsPerm;
  DenseMap<int64_t, int64_t> currentPositionTileLoops;
  for (auto [pos, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      currentPositionTileLoops[dimExpr.getPosition()] = pos;
    else
      currentPositionTileLoops[pos] = pos;
  }
  for (int64_t loopIdx : perm) {
    if (currentPositionTileLoops.count(loopIdx))
      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
  }
  return outerDimsPerm;
}
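// Illustrative sketch (added; not in the original source): given a domain
// permutation perm = [1, 0] and operand exprs = [d0, d1], loop d1 maps to
// operand position 1 and d0 to position 0, so the computed outer_dims_perm
// for the operand is [1, 0]. Domain loops that do not appear in `exprs` are
// simply dropped from the result.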
struct PackedOperandDetails {
  SmallVector<OpFoldResult> innerTileSizes;
  SmallVector<int64_t> innerDimsPos;
  SmallVector<int64_t> outerDimsPerm;
  AffineMap indexingMap;
};

/// Fills `packedOperandMap` with the packing details for `opOperand` and
/// returns true if packing the operand would require a padding value.
static bool getPackedOperandDetails(
    OpBuilder &b, const PackInfo &packInfo, GenericOp genericOp,
    OpOperand *opOperand,
    DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
  PackedOperandDetails currOperandDetails;
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());

  // Scalar and zero-rank operands do not need packing.
  if (genericOp.isScalar(opOperand) || exprs.empty()) {
    currOperandDetails.indexingMap =
        AffineMap::get(numLoops, 0, exprs, genericOp.getContext());
    packedOperandMap[opOperand] = currOperandDetails;
    return false;
  }

  // Step 1. Map domain dimensions to operand dimensions and append the point
  // loops to the operand's indexing map.
  for (auto [index, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dimPos = dimExpr.getPosition();
      domainDimToOperandDim[dimPos] = index;
    }
  }
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTileSizes;
  for (auto dimPos : packInfo.tiledDimsPos) {
    if (!domainDimToOperandDim.count(dimPos))
      continue;
    int64_t index = domainDimToOperandDim[dimPos];
    innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
    innerDimsPos.push_back(index);
    exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
  }

  // Step 2. Fold the domain transposition into the indexing map and carry it
  // on the pack op via outer_dims_perm instead.
  SmallVector<int64_t> outerDimsPerm;
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
    SmallVector<int64_t> inversedOuterPerm =
        invertPermutationVector(packInfo.outerDimsOnDomainPerm);
    for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
      if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
        int64_t dimPos = dimExpr.getPosition();
        exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
        continue;
      }
      assert(isa<AffineConstantExpr>(exprs[i]) &&
             "Attempted to permute non-constant and non-affine dim expression");
    }
    if (!outerDimsPerm.empty()) {
      SmallVector<AffineExpr> auxVec = exprs;
      for (const auto &en : llvm::enumerate(outerDimsPerm))
        auxVec[en.index()] = exprs[en.value()];
      exprs = auxVec;
    }
  }
  currOperandDetails.indexingMap =
      AffineMap::get(numLoops, 0, exprs, genericOp.getContext());

  // An unpacked operand needs neither a pack op nor a padding value.
  if (innerDimsPos.empty() && outerDimsPerm.empty()) {
    packedOperandMap[opOperand] = currOperandDetails;
    return false;
  }

  auto inputType = cast<RankedTensorType>(opOperand->get().getType());
  auto maybeIntInnerTileSizes =
      llvm::map_to_vector(innerTileSizes, [](OpFoldResult ofr) -> int64_t {
        std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
        return maybeCst.value_or(ShapedType::kDynamic);
      });
  bool requirePadding = linalg::PackOp::requirePaddingValueStrict(
      inputType.getShape(), innerDimsPos,
      linalg::PackOp::inferPackedType(inputType, maybeIntInnerTileSizes,
                                      innerDimsPos, outerDimsPerm)
          .getShape(),
      outerDimsPerm, innerTileSizes);
  currOperandDetails.innerTileSizes = innerTileSizes;
  currOperandDetails.innerDimsPos = innerDimsPos;
  currOperandDetails.outerDimsPerm = outerDimsPerm;
  packedOperandMap[opOperand] = currOperandDetails;
  return requirePadding;
}
static std::tuple<Value, AffineMap> getOrCreatePackedViewOfOperand(
    OpBuilder &b, Location loc, OpOperand *opOperand,
    const DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
  assert(packedOperandMap.contains(opOperand) &&
         "packed operand details expected to be populated");
  auto currOperandDetails = packedOperandMap.at(opOperand);
  auto innerDimsPos = currOperandDetails.innerDimsPos;
  auto outerDimsPerm = currOperandDetails.outerDimsPerm;
  auto innerTileSizes = currOperandDetails.innerTileSizes;
  if (innerDimsPos.empty() && outerDimsPerm.empty())
    return std::make_tuple(opOperand->get(), currOperandDetails.indexingMap);

  auto empty = linalg::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto poison = ub::PoisonOp::create(
      b, loc, getElementTypeOrSelf(opOperand->get().getType()));
  Value packedOperand =
      linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
                             innerTileSizes, poison, outerDimsPerm);
  return std::make_tuple(packedOperand, currOperandDetails.indexingMap);
}
static FailureOr<GenericOp>
packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
              AffineMap packedOutIndexingMap, const PackInfo &packInfo,
              bool isFoldableUnpackPack, bool poisonPaddingOk) {
  Location loc = genericOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<Value> inputOperandsFromUnpackedSource;
  SmallVector<AffineMap> indexingMaps;
  auto hasEquivalentTiles = [](PackOp packOp, UnPackOp unPackOp) {
    return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() &&
           packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
           llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
  };

  DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
  bool requiresPadding = false;
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    requiresPadding |= getPackedOperandDetails(rewriter, packInfo, genericOp,
                                               inputOperand, packedOperandMap);
  }
  if (requiresPadding && !poisonPaddingOk)
    return failure();

  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, inputOperand, packedOperandMap);
    auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
    auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
    if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
      inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
    } else {
      inputOperandsFromUnpackedSource.push_back(packedOperand);
    }
    inputOperands.push_back(packedOperand);
    indexingMaps.push_back(packedIndexingMap);
  }

  // If the unpack->pack sequences are foldable, use the sources of the
  // unpack ops directly.
  if (isFoldableUnpackPack) {
    inputOperands = inputOperandsFromUnpackedSource;
    if (auto destPack = dest.getDefiningOp<linalg::PackOp>()) {
      auto destUnPack = destPack.getSource().getDefiningOp<linalg::UnPackOp>();
      if (destUnPack && hasEquivalentTiles(destPack, destUnPack)) {
        dest = destUnPack.getSource();
      }
    }
  }

  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();
  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
  indexingMaps.push_back(packedOutIndexingMap);

  auto newGenericOp = linalg::GenericOp::create(
      rewriter, loc, dest.getType(), inputOperands, dest, indexingMaps,
      iterTypes, /*bodyBuild=*/nullptr,
      linalg::getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());
  return newGenericOp;
}
static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
  return llvm::all_of(genericOp.getDpsInitsMutable(), [&](OpOperand &operand) {
    return genericOp.getMatchingBlockArgument(&operand).use_empty();
  });
}
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
                               const ControlPropagationFn &controlFn,
                               bool poisonPaddingOk) {
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();
  // User controlled propagation function.
  if (!controlFn(&packOp.getSourceMutable()))
    return failure();
  // TODO: Enable propagation for ops with gather semantics.
  if (hasGatherSemantics(genericOp))
    return failure();
  // TODO: Relax the single-result restriction.
  if (genericOp.getNumResults() != 1)
    return failure();
  // Bail out if the result has more than one use, to avoid recomputation.
  if (!genericOp->getResult(0).hasOneUse())
    return failure();
  // TODO: Support pack ops with a padding value.
  if (packOp.getPaddingValue())
    return failure();

  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
  if (failed(packInfo))
    return failure();

  // Either rebuild the pack destination next to the generic, or require that
  // it already dominates the generic.
  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();
  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
    packOpDest = tensor::EmptyOp::create(rewriter, genericOp->getLoc(),
                                         emptyOp.getMixedSizes(),
                                         emptyOp.getType().getElementType());
  } else {
    DominanceInfo dom(genericOp);
    if (!dom.properlyDominates(packOpDest, genericOp))
      return failure();
  }

  DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
  bool requiresPadding = getPackedOperandDetails(rewriter, *packInfo, genericOp,
                                                 opOperand, packedOperandMap);
  if (requiresPadding && !poisonPaddingOk)
    return failure();

  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), opOperand,
                                     packedOperandMap);

  // Forward the new destination if the init is a tensor.empty or is unused in
  // the body of the generic.
  Value dest = packedOutOperand;
  auto initTensor =
      genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
  if (initTensor || isGenericOutsNotUsed(genericOp)) {
    dest = packOpDest;
  }
  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
                       *packInfo, /*isFoldableUnpackPack=*/false,
                       poisonPaddingOk);
}
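// Illustrative sketch (added; not in the original source), using hypothetical
// shapes and a pack with inner_dims_pos = [1], inner_tiles = [8]:
//
//   %0 = linalg.generic {iterator_types = ["parallel", "parallel"]}
//          ins(%in : tensor<?x?xf32>) outs(%out : tensor<?x?xf32>) {...}
//   %1 = linalg.pack %0 inner_dims_pos = [1] inner_tiles = [8]
//          into %dest : tensor<?x?xf32> -> tensor<?x?x8xf32>
//
// is rewritten to, roughly:
//
//   %in_pack = linalg.pack %in inner_dims_pos = [1] inner_tiles = [8] ...
//   %1 = linalg.generic
//          {iterator_types = ["parallel", "parallel", "parallel"]}
//          ins(%in_pack : tensor<?x?x8xf32>) outs(%dest : tensor<?x?x8xf32>)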
struct BubbleUpPackOpThroughGenericOpPattern
    : public OpRewritePattern<linalg::PackOp> {
public:
  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
                                        ControlPropagationFn fun,
                                        bool poisonPaddingOk)
      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)),
        poisonPaddingOk(poisonPaddingOk) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp = bubbleUpPackOpThroughGenericOp(rewriter, packOp,
                                                    controlFn, poisonPaddingOk);
    if (failed(genericOp))
      return failure();
    rewriter.replaceOp(packOp, genericOp->getResults());
    return success();
  }

private:
  ControlPropagationFn controlFn;
  bool poisonPaddingOk;
};
class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
public:
  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!padOp)
      return failure();
    // User controlled propagation function.
    if (!controlFn(&packOp.getSourceMutable()))
      return failure();
    // TODO: Enable propagation when the pack op has its own padding value.
    if (packOp.getPaddingValue())
      return failure();
    // Fail for non-constant padding values.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();
    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
      return failure();

    // Bail out if one of the padded dimensions is a tiled one.
    ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Location loc = padOp->getLoc();
    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
    auto empty = linalg::PackOp::createDestinationTensor(
        rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
        outerDimsPerm);
    auto sourcePack = linalg::PackOp::create(
        rewriter, loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
        /*padding=*/std::nullopt, outerDimsPerm);

    // If there is an outer_dims_perm, permute the padded dimensions; the
    // tiled dimensions are unpadded, so append zero pads for them.
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp =
        tensor::PadOp::create(rewriter, loc, /*result=*/Type(), sourcePack,
                              lowPad, highPad, paddingVal, padOp.getNofold());

    // If the pad has other users, unpack the new pad to replace them.
    if (!padOp->hasOneUse()) {
      auto unpackEmpty = linalg::UnPackOp::createDestinationTensor(
          rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
      Value unpackedPad =
          linalg::UnPackOp::create(rewriter, loc, newPadOp, unpackEmpty,
                                   innerDimsPos, mixedTiles, outerDimsPerm);
      rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
    }

    rewriter.replaceOp(packOp, newPadOp.getResult());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
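// Illustrative sketch (added; not in the original source):
//
//   %padded = tensor.pad %src low[0, 0] high[0, %h1] { ... }
//   %packed = linalg.pack %padded inner_dims_pos = [0] inner_tiles = [8] ...
//
// is rewritten to, roughly:
//
//   %packed = linalg.pack %src inner_dims_pos = [0] inner_tiles = [8] ...
//   %padded = tensor.pad %packed low[0, 0, 0] high[0, %h1, 0] { ... }
//
// which is only legal because no tiled dimension is padded.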
static SmallVector<int64_t> projectToInnerMostNonUnitDimsPos(
    ArrayRef<int64_t> dimsPos, ArrayRef<ReassociationIndices> reassocIndices,
    ArrayRef<int64_t> targetShape) {
  SmallVector<int64_t> projectedDimsPos;
  for (auto pos : dimsPos) {
    // In the case all dims are unit, this will return the inner-most one.
    int64_t projectedPos = reassocIndices[pos].back();
    for (auto i : llvm::reverse(reassocIndices[pos])) {
      int64_t dim = targetShape[i];
      if (dim > 1 || ShapedType::isDynamic(dim)) {
        projectedPos = i;
        break;
      }
    }
    projectedDimsPos.push_back(projectedPos);
  }
  return projectedDimsPos;
}
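// Illustrative sketch (added; not in the original source): with reassociation
// [[0, 1], [2]] relating tensor<1x16x4xf32> to a collapsed tensor, a dims_pos
// entry of 0 projects to source dim 1 (the inner-most non-unit dim of the
// [0, 1] group), while an entry of 1 projects to source dim 2.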
static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
                                       ArrayRef<int64_t> shape,
                                       ArrayRef<int64_t> tileSizes) {
  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
    int64_t dim = shape[pos];
    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
      return false;
  }
  return true;
}
/// Permutes the reassociation indices and renumbers them to be sequential,
/// returning the next available dimension position.
static int64_t applyPermutationAndReindexReassoc(
    SmallVector<ReassociationIndices> &reassocIndices,
    ArrayRef<int64_t> permutation) {
  if (!permutation.empty())
    applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
  int64_t nextPos = 0;
  for (ReassociationIndices &indices : reassocIndices) {
    for (auto &index : indices) {
      index = nextPos;
      nextPos += 1;
    }
  }
  return nextPos;
}
static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                   linalg::PackOp packOp,
                                   PatternRewriter &rewriter) {
  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();

  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      collapseOp.getReassociationIndices();
  // Project the tiled dims onto the source of the collapse: the inner-most
  // non-unit dim of each collapsed group.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
                                  innerTileSizes))
    return failure();

  // Expand the outer dims permutation with the associated source dims.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm)
    llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);

  auto emptyOp = linalg::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newPackOp = linalg::PackOp::create(
      rewriter, packOp.getLoc(), collapseOp.getSrc(), emptyOp,
      projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
      newOuterDimsPerm);

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  // Then add direct mappings for the inner tile dims.
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  auto newCollapseOp = tensor::CollapseShapeOp::create(
      rewriter, collapseOp.getLoc(), packOp.getType(), newPackOp,
      newReassocIndices);
  rewriter.replaceOp(packOp, newCollapseOp);
  return success();
}
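// Illustrative sketch (added; not in the original source):
//
//   %col  = tensor.collapse_shape %src [[0, 1], [2]]
//             : tensor<4x16x4xf32> into tensor<64x4xf32>
//   %pack = linalg.pack %col inner_dims_pos = [0] inner_tiles = [8] ...
//
// is rewritten to, roughly:
//
//   %pack = linalg.pack %src inner_dims_pos = [1] inner_tiles = [8]
//             ... : tensor<4x16x4xf32> -> tensor<4x2x4x8xf32>
//   %col  = tensor.collapse_shape %pack [[0, 1], [2], [3]]
//             : tensor<4x2x4x8xf32> into tensor<8x4x8xf32>
//
// with the tiled dim projected onto the inner-most non-unit source dim.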
static SmallVector<int64_t>
projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
                             ArrayRef<ReassociationIndices> reassocIndices) {
  SmallVector<int64_t> projectedPos;
  for (auto pos : dimsPos) {
    for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
      // If the dimension is in this reassociation group, the group's position
      // in the reassociation map is the projected position.
      if (llvm::is_contained(indices, pos)) {
        projectedPos.push_back(idx);
        break;
      }
    }
  }
  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
  return projectedPos;
}
static LogicalResult
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
                                 linalg::PackOp packOp,
                                 PatternRewriter &rewriter) {
  // TODO: Support non-identity outer dims permutations.
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
    return rewriter.notifyMatchFailure(packOp,
                                       "non-identity outer dims perm NYI");
  }

  ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
  SmallVector<ReassociationIndices> reassoc =
      expandOp.getReassociationIndices();

  // Intersect the packed dims with the dims that are actually expanded
  // (groups of size > 1); packing may not cross that boundary.
  llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(),
                                       packInnerDims.end());
  llvm::SetVector<int64_t> expandDimPos;
  for (const ReassociationIndices &indices : reassoc) {
    if (indices.size() <= 1)
      continue;
    expandDimPos.insert(indices.begin(), indices.end());
  }
  llvm::SetVector<int64_t> packedDims =
      llvm::set_intersection(packDimsPos, expandDimPos);

  if (packedDims.empty())
    return rewriter.notifyMatchFailure(
        packOp, "no packed dimension is expanded by the reshape");
  if (packedDims.size() != 1)
    return rewriter.notifyMatchFailure(
        packOp, "only one of the expanded dimensions can be packed");
  for (const ReassociationIndices &indices : reassoc) {
    if (llvm::is_contained(indices, packedDims.front()) &&
        packedDims.front() != indices.back())
      return rewriter.notifyMatchFailure(
          packOp, "can only pack the inner-most expanded dimension");
  }

  SmallVector<int64_t> projectedInnerDimsPos =
      projectDimsPosIntoReassocPos(packInnerDims, reassoc);

  RankedTensorType newPackType = linalg::PackOp::inferPackedType(
      expandOp.getSrcType(), packOp.getStaticInnerTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  auto reassocExpand =
      getReassociationIndicesForReshape(newPackType, packOp.getDestType());
  if (!reassocExpand)
    return rewriter.notifyMatchFailure(
        packOp, "could not reassociate dims after bubbling up");

  Value destTensor = linalg::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  Value packedVal = linalg::PackOp::create(
      rewriter, packOp.getLoc(), expandOp.getSrc(), destTensor,
      projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
      /*outerDimsPerm=*/SmallVector<int64_t>{});

  Value newExpandOp = tensor::ExpandShapeOp::create(rewriter, packOp.getLoc(),
                                                    packOp.getDestType(),
                                                    packedVal, *reassocExpand);
  rewriter.replaceOp(packOp, newExpandOp);
  return success();
}
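// Illustrative sketch (added; not in the original source):
//
//   %exp  = tensor.expand_shape %src [[0], [1, 2]]
//             : tensor<8x32xf32> into tensor<8x4x8xf32>
//   %pack = linalg.pack %exp inner_dims_pos = [2] inner_tiles = [8] ...
//
// is rewritten to, roughly:
//
//   %pack = linalg.pack %src inner_dims_pos = [1] inner_tiles = [8]
//             ... : tensor<8x32xf32> -> tensor<8x4x8xf32>
//   %exp  = tensor.expand_shape %pack [[0], [1, 2], [3]]
//             : tensor<8x4x8xf32> into tensor<8x4x1x8xf32>
//
// since d2 is the inner-most dimension of its reassociation group.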
class BubbleUpPackOpThroughReshapeOp final
    : public OpRewritePattern<linalg::PackOp> {
public:
  BubbleUpPackOpThroughReshapeOp(MLIRContext *context,
                                 ControlPropagationFn fun)
      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    Operation *srcOp = packOp.getSource().getDefiningOp();
    // Currently only support the case where the pack op is the single user.
    if (!srcOp || !(srcOp->getNumResults() == 1) ||
        !srcOp->getResult(0).hasOneUse())
      return failure();
    // Currently only support static inner tile sizes.
    if (llvm::any_of(packOp.getStaticTiles(), ShapedType::isDynamic))
      return failure();
    // User controlled propagation function.
    if (!controlFn(&packOp.getSourceMutable()))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(srcOp)
        .Case([&](tensor::CollapseShapeOp op) {
          return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
        })
        .Case([&](tensor::ExpandShapeOp op) {
          return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};
static LogicalResult pushDownUnPackOpThroughExpandShape(
    linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
    PatternRewriter &rewriter, ControlPropagationFn controlFn) {
  // User controlled propagation function.
  if (!controlFn(&expandOp.getSrcMutable()))
    return failure();

  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();

  auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
  if (!expandTy)
    return failure();
  ArrayRef<int64_t> dstShape = expandTy.getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      expandOp.getReassociationIndices();
  // Project the tiled dims onto the result of the expand: the inner-most
  // non-unit dim of each expanded group.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
                                  innerTileSizes))
    return failure();

  // Expand the outer dims permutation with the associated expanded dims.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm)
    llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  RankedTensorType newExpandType = linalg::PackOp::inferPackedType(
      expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
  auto newExpandOp =
      tensor::ExpandShapeOp::create(rewriter, expandOp.getLoc(), newExpandType,
                                    unPackOp.getSource(), newReassocIndices);

  auto emptyOp = linalg::UnPackOp::createDestinationTensor(
      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newUnPackOp = linalg::UnPackOp::create(
      rewriter, unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
      projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
  rewriter.replaceOp(expandOp, newUnPackOp);
  return success();
}
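// Illustrative sketch (added; not in the original source):
//
//   %unpack = linalg.unpack %src inner_dims_pos = [0] inner_tiles = [8]
//               : tensor<4x8xf32> -> tensor<32xf32>
//   %exp    = tensor.expand_shape %unpack [[0, 1]]
//               : tensor<32xf32> into tensor<2x16xf32>
//
// is rewritten to, roughly:
//
//   %exp    = tensor.expand_shape %src [[0, 1], [2]]
//               : tensor<4x8xf32> into tensor<2x2x8xf32>
//   %unpack = linalg.unpack %exp inner_dims_pos = [1] inner_tiles = [8]
//               : tensor<2x2x8xf32> -> tensor<2x16xf32>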
class PushDownUnPackOpThroughReshapeOp final
    : public OpRewritePattern<linalg::UnPackOp> {
public:
  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<linalg::UnPackOp>(context),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    Value result = unPackOp.getResult();
    // Currently only support the case where the unpack has a single user.
    if (!result.hasOneUse())
      return failure();
    // Currently only support static inner tile sizes.
    if (llvm::any_of(unPackOp.getStaticTiles(), ShapedType::isDynamic))
      return failure();

    Operation *userOp = *result.user_begin();
    return TypeSwitch<Operation *, LogicalResult>(userOp)
        .Case([&](tensor::ExpandShapeOp op) {
          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
                                                    controlFn);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};
static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
  OpOperand *unPackedOperand = nullptr;
  for (OpOperand &operand : genericOp->getOpOperands()) {
    auto unPackOp = operand.get().getDefiningOp<linalg::UnPackOp>();
    if (!unPackOp)
      continue;
    // Fail if there is more than one unpacked operand.
    if (unPackedOperand)
      return failure();
    unPackedOperand = &operand;
  }
  if (!unPackedOperand)
    return failure();
  return unPackedOperand;
}
/// Push down a linalg.unpack op through a generic op. The new generic op works
/// on the packed domain; pack ops are created for the operands as needed.
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                 const ControlPropagationFn &controlFn,
                                 bool poisonPaddingOk) {
  if (genericOp.getNumResults() != 1)
    return failure();
  if (hasGatherSemantics(genericOp))
    return failure();

  // Collect the unpacked operand, if present.
  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
  if (failed(maybeUnPackedOperand))
    return failure();
  OpOperand *unPackedOperand = *maybeUnPackedOperand;

  // Extract the packing information from the producing unpack.
  linalg::UnPackOp producerUnPackOp =
      unPackedOperand->get().getDefiningOp<linalg::UnPackOp>();
  assert(producerUnPackOp && "expect a valid UnPackOp");

  if (!controlFn(unPackedOperand))
    return failure();

  auto packInfo =
      getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
  if (failed(packInfo))
    return failure();

  DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
  bool requiresPadding =
      getPackedOperandDetails(rewriter, *packInfo, genericOp,
                              genericOp.getDpsInitOperand(0), packedOperandMap);
  if (requiresPadding && !poisonPaddingOk)
    return failure();

  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(),
                                     genericOp.getDpsInitOperand(0),
                                     packedOperandMap);
  auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();

  // Forward the pack's destination if the init is a tensor.empty or is unused
  // in the body of the generic.
  Value dest = packedOutOperand;
  auto initTensor =
      genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
  if (initTensor || isGenericOutsNotUsed(genericOp)) {
    if (destPack)
      dest = destPack.getDest();
  }

  // Pack the generic; unpack->pack chains on the operands are foldable here.
  auto maybeGenericOp =
      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
                    /*isFoldableUnpackPack=*/true, poisonPaddingOk);
  if (failed(maybeGenericOp))
    return failure();
  GenericOp newGenericOp = *maybeGenericOp;
  Value newResult =
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

  // If the output is unchanged, no unpack is needed.
  if (!destPack)
    return std::make_tuple(newGenericOp, newResult);

  auto mixedTiles = destPack.getMixedTiles();
  auto innerDimsPos = destPack.getInnerDimsPos();
  auto outerDimsPerm = destPack.getOuterDimsPerm();

  // Insert an unpack right after the packed generic.
  Value unPackOpRes =
      linalg::UnPackOp::create(rewriter, genericOp.getLoc(), newResult,
                               destPack.getSource(), innerDimsPos, mixedTiles,
                               outerDimsPerm)
          .getResult();
  return std::make_tuple(newGenericOp, unPackOpRes);
}
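// Illustrative sketch (added; not in the original source):
//
//   %unpack = linalg.unpack %src inner_dims_pos = [1] inner_tiles = [8] ...
//   %0 = linalg.generic ins(%unpack : tensor<?x?xf32>)
//          outs(%out : tensor<?x?xf32>) {...}
//
// is rewritten to, roughly:
//
//   %0 = linalg.generic ins(%src : tensor<?x?x8xf32>)
//          outs(%out_packed : tensor<?x?x8xf32>) {...}
//   %unpack = linalg.unpack %0 inner_dims_pos = [1] inner_tiles = [8] ...
//
// i.e. the generic now works on the packed domain and the unpack moves below.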
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
public:
  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
                                   ControlPropagationFn fun,
                                   bool poisonPaddingOk)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)),
        poisonPaddingOk(poisonPaddingOk) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    auto genericAndRepl = pushDownUnPackOpThroughGenericOp(
        rewriter, genericOp, controlFn, poisonPaddingOk);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
  bool poisonPaddingOk;
};
/// Pushes a linalg.unpack op down through a tensor.pad op, provided that none
/// of the tiled dimensions is padded.
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    linalg::UnPackOp unpackOp =
        padOp.getSource().getDefiningOp<linalg::UnPackOp>();
    if (!unpackOp)
      return failure();

    // User controlled propagation function.
    if (!controlFn(&padOp.getSourceMutable()))
      return failure();

    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    Location loc = padOp.getLoc();
    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();

    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = tensor::PadOp::create(rewriter, loc, /*result=*/Type(),
                                          unpackOp.getSource(), lowPad, highPad,
                                          paddingVal, padOp.getNofold());

    // Inject the new unpack after the new pad.
    Value outputUnPack =
        tensor::EmptyOp::create(rewriter, loc, padOp.getResultType().getShape(),
                                padOp.getResultType().getElementType());

    Value replacement = linalg::UnPackOp::create(
        rewriter, loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
        unpackOp.getMixedTiles(), outerDimsPerm);
    rewriter.replaceOp(padOp, replacement);
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
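// Illustrative sketch (added; not in the original source):
//
//   %unpack = linalg.unpack %src inner_dims_pos = [1] inner_tiles = [8] ...
//   %pad    = tensor.pad %unpack low[1, 0] high[1, 0] { ... }
//
// is rewritten to, roughly:
//
//   %pad    = tensor.pad %src low[1, 0, 0] high[1, 0, 0] { ... }
//   %unpack = linalg.unpack %pad inner_dims_pos = [1] inner_tiles = [8] ...
//
// which is legal because only the untiled dimension d0 is padded.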
/// Slice information for a single dimension: the offset and size of the slice
/// and the full size of the original (output) dimension.
struct SliceDimInfo {
  OpFoldResult offset;
  OpFoldResult sliceSize;
  OpFoldResult outputSize;
};

/// Returns the input operands of the generic op that are produced by
/// tensor.extract_slice ops, or failure if there are none.
static FailureOr<SmallVector<OpOperand *>>
getSliceOperands(GenericOp genericOp) {
  SmallVector<OpOperand *> sliceOperands;
  for (auto operand : genericOp.getDpsInputOperands()) {
    auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
    if (!extractOp)
      continue;
    sliceOperands.push_back(operand);
  }
  if (sliceOperands.empty()) {
    return failure();
  }
  return sliceOperands;
}
/// Computes, per iteration-domain dimension, the slice info implied by the
/// tensor.extract_slice defining `sliceOperand`. Fails when a sliced dimension
/// is used in a non-trivial affine expression elsewhere.
static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>>
getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) {
  tensor::ExtractSliceOp producerSliceOp =
      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  assert(producerSliceOp && "expect a valid ExtractSliceOp");

  llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap;
  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
  SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult(
      genericOp.getContext(), producerSliceOp.getSourceType().getShape());

  for (auto [idx, expr] : llvm::enumerate(
           genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
    // Dimensions that are not actually sliced need no bookkeeping.
    if (isConstantIntValue(offsets[idx], 0) &&
        isEqualConstantIntOrValue(sizes[idx], shape[idx]))
      continue;
    // Sliced dimensions must appear as plain dim expressions.
    if (!isa<AffineDimExpr>(expr)) {
      return failure();
    }
    SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
    int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
    partialSliceDimMap[dimPos] = sliceDimInfo;
  }

  // The sliced dims may not feed non-trivial expressions in other operands.
  for (OpOperand &operand : genericOp->getOpOperands()) {
    if (operand == *sliceOperand) {
      continue;
    }
    AffineMap indexingMap = genericOp.getMatchingIndexingMap(&operand);
    WalkResult status = WalkResult::advance();
    for (AffineExpr expr : indexingMap.getResults()) {
      if (isa<AffineDimExpr>(expr)) {
        continue;
      }
      status = expr.walk([&](AffineExpr e) {
        if (auto dimExpr = dyn_cast<AffineDimExpr>(e)) {
          if (partialSliceDimMap.contains(dimExpr.getPosition())) {
            return WalkResult::interrupt();
          }
        }
        return WalkResult::advance();
      });
      if (status.wasInterrupted())
        break;
    }
    if (status.wasInterrupted()) {
      return failure();
    }
  }
  return partialSliceDimMap;
}
static FailureOr<std::tuple<GenericOp, Value>>
pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
                                       GenericOp genericOp,
                                       const ControlPropagationFn &controlFn) {
  if (genericOp.getNumResults() != 1)
    return rewriter.notifyMatchFailure(
        genericOp, "propagation through multi-result generic is unsupported.");
  if (hasGatherSemantics(genericOp))
    return rewriter.notifyMatchFailure(
        genericOp,
        "propagation through generic with gather semantics is unsupported.");

  // Collect the sliced operands, if present.
  auto maybeSliceOperands = getSliceOperands(genericOp);
  if (failed(maybeSliceOperands))
    return failure();
  SmallVector<OpOperand *> sliceOperands = *maybeSliceOperands;

  // Pick the first slice operand accepted by the control function.
  OpOperand *sliceOperand = nullptr;
  bool foundValidOperand = false;
  for (auto currSliceOperand : sliceOperands) {
    if (controlFn(currSliceOperand)) {
      sliceOperand = currSliceOperand;
      foundValidOperand = true;
      break;
    }
  }
  if (!foundValidOperand) {
    return failure();
  }
  unsigned OperandIndex = sliceOperand->getOperandNumber();

  tensor::ExtractSliceOp producerSliceOp =
      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  assert(producerSliceOp && "expect a valid ExtractSliceOp");

  if (producerSliceOp.getSource().getType().getRank() !=
      producerSliceOp.getResult().getType().getRank()) {
    return rewriter.notifyMatchFailure(
        genericOp,
        "propagation of rank-reducing extract slice is unsupported.");
  }
  if (!areAllConstantIntValue(producerSliceOp.getMixedStrides(), 1))
    return rewriter.notifyMatchFailure(
        genericOp, "propagation of strided extract slice is unsupported.");

  auto maybePartialSliceDimMap =
      getPartialSliceDimInfo(genericOp, sliceOperand);
  if (failed(maybePartialSliceDimMap)) {
    return failure();
  }
  auto partialSliceDimMap = *maybePartialSliceDimMap;

  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();
  bool hasPartialReductionDimSlice =
      llvm::any_of(partialSliceDimMap, [&](const auto &slice) {
        int64_t sliceDim = slice.first;
        return iterators[sliceDim] == utils::IteratorType::reduction;
      });

  Location loc = genericOp->getLoc();
  auto sub = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);
    return affine::makeComposedFoldedAffineApply(
        rewriter, loc, AffineMap::get(2, 0, d0 - d1), {v1, v2});
  };

  // Pad each input with poison so the generic runs on the un-sliced domain.
  SmallVector<Value> paddedInputs;
  for (auto [idx, operand] :
       llvm::enumerate(genericOp.getDpsInputOperands())) {
    if (idx == OperandIndex && !hasPartialReductionDimSlice) {
      paddedInputs.push_back(producerSliceOp.getSource());
      continue;
    }
    AffineMap indexingMap = genericOp.getMatchingIndexingMap(operand);
    if (indexingMap.getNumResults() == 0) {
      paddedInputs.push_back(operand->get());
      continue;
    }
    SmallVector<OpFoldResult> operandLowPads(indexingMap.getNumResults(),
                                             rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> operandHighPads(indexingMap.getNumResults(),
                                              rewriter.getIndexAttr(0));
    for (auto [idx, expr] : llvm::enumerate(indexingMap.getResults())) {
      if (!isa<AffineDimExpr>(expr)) {
        continue;
      }
      auto dimExpr = cast<AffineDimExpr>(expr);
      if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
        continue;
      }
      SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
      operandLowPads[idx] = sliceDimInfo.offset;
      operandHighPads[idx] =
          sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
              sliceDimInfo.sliceSize);
    }
    auto paddingValue = ub::PoisonOp::create(
        rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
    auto paddedOperand = tensor::PadOp::create(
        rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
        paddingValue, /*nofold=*/false);
    paddedInputs.push_back(paddedOperand);
  }

  AffineMap outputIndexingMap =
      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
  auto outputShapeType =
      llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
  SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
      outputShapeType.getShape(),
      [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
  int64_t numResults = outputIndexingMap.getNumResults();
  SmallVector<OpFoldResult> outputLowPads(numResults, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> outputHighPads(numResults,
                                           rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> newSizes = OutputShape;
  SmallVector<OpFoldResult> newStrides(numResults, rewriter.getIndexAttr(1));
  for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
    if (!isa<AffineDimExpr>(expr)) {
      continue;
    }
    auto dimExpr = cast<AffineDimExpr>(expr);
    if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
      continue;
    }
    SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
    outputLowPads[idx] = sliceDimInfo.offset;
    outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
                              sliceDimInfo.sliceSize);
    OutputShape[idx] = sliceDimInfo.outputSize;
    newSizes[idx] = sliceDimInfo.sliceSize;
  }

  Value newPadOutput;
  auto outputElType =
      getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
  if (isGenericOutsNotUsed(genericOp)) {
    newPadOutput =
        tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
  } else {
    auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
    newPadOutput = tensor::PadOp::create(
        rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
        outputHighPads, paddingValue, /*nofold=*/false);
  }

  auto newGenericOp = linalg::GenericOp::create(
      rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
      genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());

  auto extractOp = tensor::ExtractSliceOp::create(
      rewriter, loc,
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
      outputLowPads, newSizes, newStrides);
  Value extractRes = extractOp.getResult();

  return std::make_tuple(newGenericOp, extractRes);
}
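// Illustrative sketch (added; not in the original source):
//
//   %slice = tensor.extract_slice %src[2, 0] [10, 16] [1, 1]
//              : tensor<16x16xf32> to tensor<10x16xf32>
//   %0 = linalg.generic ins(%slice ...) outs(%out : tensor<10x16xf32>) {...}
//
// is rewritten to, roughly:
//
//   %pad = tensor.pad %out low[2, 0] high[4, 0] { ub.poison } ...
//   %0 = linalg.generic ins(%src ...) outs(%pad : tensor<16x16xf32>) {...}
//   %res = tensor.extract_slice %0[2, 0] [10, 16] [1, 1]
//
// i.e. the generic runs on the un-sliced domain (padded with poison) and the
// slice is re-applied to its result; here high pad = 16 - 2 - 10 = 4.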
class PushDownExtractSliceOpThroughGenericOp final
    : public OpRewritePattern<GenericOp> {
public:
  PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
                                         ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    auto genericAndRepl =
        pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
};
void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation,
    bool PoisonPaddingOk) {
  patterns.insert<BubbleUpPackThroughPadOp, BubbleUpPackOpThroughReshapeOp,
                  PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
      patterns.getContext(), controlPackUnPackPropagation);
  patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
                  PushDownUnPackOpThroughGenericOp>(
      patterns.getContext(), controlPackUnPackPropagation, PoisonPaddingOk);
}

void mlir::linalg::populateExtractSliceSinkingPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
      patterns.getContext(), controlPackUnPackPropagation);
}
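// Usage sketch (added; hypothetical driver, not from the original file):
// registering the propagation patterns inside a pass and applying them
// greedily. `allowAll` is an assumed control function accepting every
// candidate operand.
//
//   void runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     linalg::ControlPropagationFn allowAll = [](OpOperand *) {
//       return true;
//     };
//     linalg::populateDataLayoutPropagationPatterns(
//         patterns, allowAll, /*PoisonPaddingOk=*/false);
//     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//       signalPassFailure();
//   }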