// NOTE(review): this excerpt is a Doxygen source-browser scrape. Every line
// carries its original line number, and many lines are absent (braces, blank
// lines, comments, and statements such as those between source lines 52 and
// 55, at 85, and at 97-99 / 101-109 / 113-117), so it does not compile as-is.
//
/// Rewrite pattern that "bubbles up" a tensor.extract_slice above the linalg
/// op that produces the sliced value. Rather than computing the full linalg
/// result and slicing it afterwards, the pattern slices the linalg op's
/// operands to the tile selected by the extract_slice. It then clones the
/// linalg op onto those tiled operands and replaces the extract_slice with the
/// clone's results.
///
/// Per the match-failure messages below, the rewrite only applies when:
///  - the slice source is produced by a linalg op whose result has one use,
///  - that linalg op has exactly one DPS init and pure tensor semantics,
///  - the slice has unit strides and does not reduce rank,
///  - the output indexing map is a projected permutation.
46 struct BubbleUpExtractSliceOpPattern
50 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
// The slice's source must be the result of a linalg op. The lookup of
// `linalgOp` and its null check are among the lines missing from this excerpt.
52 Value source = sliceOp.getSource();
55 return rewriter.notifyMatchFailure(sliceOp,
56 "expected source to be linalg op");
// The slice must be the linalg op's only use. That way, replacing the slice
// below leaves the original op without users.
61 if (!linalgOp->hasOneUse()) {
62 return rewriter.notifyMatchFailure(sliceOp,
63 "expected single use of linalg op");
// Only a single DPS init (output) operand is supported.
66 if (linalgOp.getNumDpsInits() != 1) {
67 return rewriter.notifyMatchFailure(sliceOp,
68 "expected single output of linalg op");
// Require pure tensor semantics.
71 if (!linalgOp.hasPureTensorSemantics()) {
72 return rewriter.notifyMatchFailure(sliceOp,
73 "expected tensor of linalg op");
// Only unit-stride slices are handled.
76 if (!sliceOp.hasUnitStride())
77 return rewriter.notifyMatchFailure(sliceOp,
"expected unit stride");
// Rank-reducing slices are rejected. The clone's result types come from the
// tiled init operands (see below), so they must have the same rank as the
// slice result they replace.
79 if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
80 return rewriter.notifyMatchFailure(sliceOp,
"expected no rank reduction");
// The output's indexing map must be a projected permutation, so every result
// is a plain AffineDimExpr; the cast<AffineDimExpr> in the tile loop below
// relies on this. NOTE(review): the condition itself (source line 85,
// presumably indexingMap.isProjectedPermutation()) is missing from this
// excerpt.
83 OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
84 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(outOperand);
86 return rewriter.notifyMatchFailure(
87 sliceOp,
"expected a projected permutation for output");
// Derive the linalg op's loop bounds from its operand dimension sizes through
// the shapes-to-loops map.
90 auto linalgLoc = linalgOp.getLoc();
92 linalgOp.createFlatListOfOperandDims(rewriter, linalgLoc);
93 AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
94 if (!shapeSizesToLoopsMap) {
95 return rewriter.notifyMatchFailure(
96 linalgOp,
"failed to get loops map from shape sizes");
// Tail of the call that applies the map. The head (source lines 97-99) is
// missing; presumably it is affine::makeComposedFoldedMultiResultAffineApply,
// with its result stored in `sizeBounds`.
100 rewriter, linalgLoc, shapeSizesToLoopsMap, allShapeSizes);
// Presumably (declarations at source lines 101-109 are missing), tileOffsets
// starts as all-zero index attributes and tileSizes as the full loop bounds.
// For each result of the output indexing map, the loop dimension it names
// then takes the slice's offset and size at that result position.
107 rewriter.getIndexAttr(0));
110 unsigned position = cast<AffineDimExpr>(result.value()).getPosition();
111 tileOffsets[position] = sliceOp.getMixedOffsets()[result.index()];
112 tileSizes[position] = sliceOp.getMixedSizes()[result.index()];
// Slice every operand of the linalg op to the computed tile. The call head is
// on a missing line; presumably it is makeTiledShapes, producing
// `tiledOperands`.
118 tileOffsets, tileSizes, sizeBounds,
// The clone's result types are the types of the tiled init operands.
122 for (
OpOperand &opOperand : linalgOp.getDpsInitsMutable())
123 resultTensorTypes.push_back(
124 tiledOperands[opOperand.getOperandNumber()].getType());
// Clone the linalg op onto the tiled operands, then replace the extract_slice
// with the clone's results.
127 clone(rewriter, linalgOp, resultTensorTypes, tiledOperands);
128 rewriter.replaceOp(sliceOp, newOp->
getResults());
// Register BubbleUpExtractSliceOpPattern with the caller's pattern set.
// NOTE(review): this is the body of what is presumably
// populateBubbleUpExtractSliceOpPatterns. Its signature and the declaration of
// `context` (presumably patterns.getContext()) are on lines missing from this
// excerpt.
137 patterns.
add<BubbleUpExtractSliceOpPattern>(context);
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
ArrayRef< AffineExpr > getResults() const
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Patterns that bubble a tensor.extract_slice op up above the linalg op that produces its source.
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
Include the generated interface declarations.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...