/// Rewrite pattern that bubbles a `tensor.extract_slice` up above the Linalg
/// op producing its source: the Linalg op is cloned onto sliced ("tiled")
/// operands, and the clone's results replace the `extract_slice`.
///
/// NOTE(review): this is an excerpt of a rendered source listing. The leading
/// numbers are the original file's line numbers, and the gaps in that
/// numbering show that several original lines are not present here.
46 struct BubbleUpExtractSliceOpPattern
/// Succeeds only if every precondition below holds. Each failure is reported
/// through `rewriter.notifyMatchFailure` with a reason string.
50 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
52 Value source = sliceOp.getSource();
// The slice source must be produced by a Linalg op.
55 return rewriter.notifyMatchFailure(sliceOp,
56 "expected source to be linalg op");
// The Linalg op must have exactly one use. Presumably this ensures that no
// other user still needs the full, untiled result — confirm.
61 if (!linalgOp->hasOneUse()) {
62 return rewriter.notifyMatchFailure(sliceOp,
63 "expected single use of linalg op");
// Only ops with a single DPS init are handled. The output indexing map
// below is taken from init operand 0.
66 if (linalgOp.getNumDpsInits() != 1) {
67 return rewriter.notifyMatchFailure(sliceOp,
68 "expected single output of linalg op");
// Reject Linalg ops that do not have pure tensor semantics.
71 if (!linalgOp.hasPureTensorSemantics()) {
72 return rewriter.notifyMatchFailure(sliceOp,
73 "expected tensor of linalg op");
// Only unit-stride slices are accepted. The tiling below is expressed
// purely in terms of offsets and sizes.
76 if (!sliceOp.hasUnitStride())
77 return rewriter.notifyMatchFailure(sliceOp,
"expected unit stride");
// Rank-reducing slices are rejected. The loop below indexes the slice's
// offsets and sizes by the position of each output indexing-map result.
79 if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
80 return rewriter.notifyMatchFailure(sliceOp,
"expected no rank reduction");
// The output's indexing map must be a projected permutation, so every
// result is a plain loop dimension. The cast<AffineDimExpr> below relies
// on this.
83 OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
84 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(outOperand);
86 return rewriter.notifyMatchFailure(
87 sliceOp,
"expected a projected permutation for output");
// Compute the iteration-space bounds of all loops by mapping the flat list
// of operand dimension sizes through the shapes-to-loops map. The call line
// is not shown; presumably it is makeComposedFoldedMultiResultAffineApply,
// producing `sizeBounds`.
90 auto linalgLoc = linalgOp.getLoc();
92 linalgOp.createFlatListOfOperandDims(rewriter, linalgLoc);
93 AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
94 if (!shapeSizesToLoopsMap) {
95 return rewriter.notifyMatchFailure(
96 linalgOp,
"failed to get loops map from shape sizes");
100 rewriter, linalgLoc, shapeSizesToLoopsMap, allShapeSizes);
// Tile offsets start at index 0 for every loop. Each loop that indexes the
// output then takes the slice's offset and size for that output dimension.
// The initialization of the other loops' tile sizes is on a line not shown
// here.
107 rewriter.getIndexAttr(0));
110 unsigned position = cast<AffineDimExpr>(result.value()).getPosition();
111 tileOffsets[position] = sliceOp.getMixedOffsets()[result.index()];
112 tileSizes[position] = sliceOp.getMixedSizes()[result.index()];
// Slice the op's operands to the computed tiles. The callee is not shown on
// this line; presumably it is makeTiledShapes — confirm.
118 tileOffsets, tileSizes, sizeBounds,
// The clone's result types are the types of the tiled init operands.
122 for (
OpOperand &opOperand : linalgOp.getDpsInitsMutable())
123 resultTensorTypes.push_back(
124 tiledOperands[opOperand.getOperandNumber()].getType());
// Clone the Linalg op onto the tiled operands, then replace the slice with
// the clone's results.
127 clone(rewriter, linalgOp, resultTensorTypes, tiledOperands);
128 rewriter.replaceOp(sliceOp, newOp->
getResults());
/// Body of populateBubbleUpExtractSliceOpPatterns. The signature line is not
/// part of this excerpt. It registers BubbleUpExtractSliceOpPattern in
/// `patterns`, using the pattern set's own MLIR context.
136 auto *context =
patterns.getContext();
137 patterns.add<BubbleUpExtractSliceOpPattern>(context);
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
ArrayRef< AffineExpr > getResults() const
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Patterns that are used to bubble up extract slice op above linalg op.
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
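Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder, assuming linalgOp is being fused into a loop nest for tiling with the given induction variables ivs and tile sizes tileSizes.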
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.