MLIR 23.0.0git
DropRedundantRankExpansionPatterns.cpp
Go to the documentation of this file.
1//===- DropRedundantRankExpansionPatterns.cpp -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16
17using namespace mlir;
18using namespace mlir::tensor;
19
20namespace {
21/// Drop redundant rank expansion of insert_slice that are directly followed
22/// by extract_slice. E.g.:
23/// %0 = tensor.insert_slice ... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
24/// %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1]
25/// : tensor<1x1x5x10xf32> to tensor<2x2xf32>
26struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
27 : public OpRewritePattern<ExtractSliceOp> {
29
30 LogicalResult matchAndRewrite(ExtractSliceOp extractSliceOp,
31 PatternRewriter &rewriter) const override {
32 // Nothing to do if no dims are dropped.
33 llvm::SmallBitVector droppedDims = extractSliceOp.getDroppedDims();
34 if (droppedDims.none())
35 return failure();
36
37 // Look for tensor.insert_slice op that has an inverse rank expansion.
38 auto insertSliceOp =
39 extractSliceOp.getSource().getDefiningOp<InsertSliceOp>();
40 if (!insertSliceOp)
41 return failure();
42 llvm::SmallBitVector expandedDims = insertSliceOp.getDroppedDims();
43
44 // TODO: This could be extended to support cases where the dropped dims are
45 // a subset of the expanded dims.
46 if (expandedDims != droppedDims)
47 return failure();
48
49 // The tensor.insert_slice may not be redundant if it has multiple users.
50 if (!insertSliceOp->hasOneUse())
51 return failure();
52
53 // Only consider tensor.insert_slice ops that are pure rank-reductions.
54 // I.e., no elements are taken from the destination.
55 if (!isCastLikeInsertSliceOp(insertSliceOp))
56 return failure();
57
58 // Extract directly from the source.
59 OpBuilder::InsertionGuard g(rewriter);
60 rewriter.setInsertionPoint(extractSliceOp);
61 SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
62 for (int64_t i = 0, e = extractSliceOp.getSourceType().getRank(); i < e;
63 ++i) {
64 if (droppedDims.test(i))
65 continue;
66 newOffsets.push_back(extractSliceOp.getMixedOffsets()[i]);
67 newSizes.push_back(extractSliceOp.getMixedSizes()[i]);
68 newStrides.push_back(extractSliceOp.getMixedStrides()[i]);
69 }
70 rewriter.replaceOpWithNewOp<ExtractSliceOp>(
71 extractSliceOp, /*source=*/insertSliceOp.getSource(), newOffsets,
72 newSizes, newStrides);
73 rewriter.eraseOp(insertSliceOp);
74 return success();
75 }
76};
77
78/// Drop redundant rank expansion of insert_slice that direclty follows
79/// extract_slice.
80///
81/// This can be done when the insert_slice op purely expands ranks (adds unit
82/// dims) and the extrace_slice drops corresponding unit dims. For example:
83///
84/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
85/// : tensor<2x8xf32> to tensor<8xf32>
86/// %inserted_slice = tensor.insert_slice %extracted_slice
87/// into %dest[0, 0] [1, 8] [1, 1]
88/// : tensor<8xf32> into tensor<1x8xf32>
89///
90/// can be folded into:
91///
92/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
93/// : tensor<2x8xf32> to tensor<1x8xf32>
94struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
95 : public OpRewritePattern<tensor::InsertSliceOp> {
96 using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
97
98 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
99 PatternRewriter &rewriter) const override {
100 auto extractSliceOp =
101 insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
102 if (!extractSliceOp) {
103 return rewriter.notifyMatchFailure(insertSliceOp,
104 "source is not extract_slice");
105 }
106
107 // Can't fold if the extract_slice op has other users.
108 if (!extractSliceOp->hasOneUse()) {
109 return rewriter.notifyMatchFailure(insertSliceOp,
110 "source has multi-uses");
111 }
112
113 // Check if the insert_slice op purely expands ranks (add unit dims).
114 if (!isCastLikeInsertSliceOp(insertSliceOp)) {
115 return rewriter.notifyMatchFailure(insertSliceOp,
116 "insert_slice is not cast-like");
117 }
118
119 llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
120 llvm::SmallBitVector insertDroppedDims = insertSliceOp.getDroppedDims();
121 // Can't fold if the insert_slice op expands to more dims.
122 if (extractDroppedDims.size() < insertDroppedDims.size()) {
123 return rewriter.notifyMatchFailure(insertSliceOp,
124 "insert_slice expands more dims");
125 }
126
127 // Try to match the extract dropped dims to the insert dropped dims. This is
128 // done by scanning the dims of extract_slice and find the left-most one can
129 // match the dim of insert_slice. If a match is found, advance the dim of
130 // insert_slice to match the next one.
131 unsigned insertDimPos = 0;
132 for (unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
133 ++extractDimPos) {
134 // Matched all dims.
135 if (insertDimPos == insertDroppedDims.size())
136 break;
137
138 bool isExtractDropped = extractDroppedDims[extractDimPos];
139 bool isInsertDropped = insertDroppedDims[insertDimPos];
140 // Match if both sides drop/keep the dim. Advance and match the next dim
141 // of insert_slice.
142 if (isExtractDropped == isInsertDropped) {
143 insertDimPos += 1;
144 } else if (!isExtractDropped && isInsertDropped) {
145 // Not enough extract dropped dims to match the insert dropped dims.
146 return rewriter.notifyMatchFailure(insertSliceOp,
147 "insert_slice drops more unit dims");
148 }
149 // If the dim is dropped by extract_slice and not by insert_slice, look
150 // the next dim of extract_slice to see if it can match the current dim of
151 // insert_slice.
152 }
153 // Can't match some insert dims.
154 if (insertDimPos != insertDroppedDims.size()) {
155 return rewriter.notifyMatchFailure(insertSliceOp,
156 "insert_slice has unmatched dims");
157 }
158
159 rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
160 insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
161 extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
162 extractSliceOp.getMixedStrides());
163 rewriter.eraseOp(extractSliceOp);
164
165 return success();
166 }
167};
168} // namespace
169
171 RewritePatternSet &patterns) {
172 patterns.add<DropRedundantRankExpansionOnExtractSliceOfInsertSlice,
173 DropRedundantRankExpansionOnInsertSliceOfExtractSlice>(
174 patterns.getContext());
175}
return success()
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
bool isCastLikeInsertSliceOp(InsertSliceOp op)
A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor or inserts...
Definition Utils.cpp:125
void populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that drop redundant tensor.insert_slice rank expansions.
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...