MLIR 23.0.0git
DropRedundantRankExpansionPatterns.cpp
Go to the documentation of this file.
1//===- DropRedundantRankExpansionPatterns.cpp -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16
17using namespace mlir;
18using namespace mlir::tensor;
19
20namespace {
21/// Drop redundant rank expansion of insert_slice that are directly followed
22/// by extract_slice. E.g.:
23/// %0 = tensor.insert_slice %in... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
24/// %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1]
25/// : tensor<1x1x5x10xf32> to tensor<2x2xf32>
26///
27/// can be folded into:
28///
29/// %1 = tensor.extract_slice %in[2, 3] [2, 2] [1, 1]
30/// : tensor<5x10xf32> to tensor<2x2xf32>
31struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
32 : public OpRewritePattern<ExtractSliceOp> {
34
35 LogicalResult matchAndRewrite(ExtractSliceOp extractSliceOp,
36 PatternRewriter &rewriter) const override {
37 // Nothing to do if no dims are dropped.
38 llvm::SmallBitVector droppedDims = extractSliceOp.getDroppedDims();
39 if (droppedDims.none())
40 return failure();
41
42 // Look for tensor.insert_slice op that has an inverse rank expansion.
43 auto insertSliceOp =
44 extractSliceOp.getSource().getDefiningOp<InsertSliceOp>();
45 if (!insertSliceOp)
46 return failure();
47 llvm::SmallBitVector expandedDims = insertSliceOp.getDroppedDims();
48
49 // Support cases where the expanded dims are a subset of the droped dims.
50 if (!expandedDims.subsetOf(droppedDims))
51 return failure();
52
53 // The tensor.insert_slice may not be redundant if it has multiple users.
54 if (!insertSliceOp->hasOneUse())
55 return failure();
56
57 // Only consider tensor.insert_slice ops that are pure rank-reductions.
58 // I.e., no elements are taken from the destination.
59 if (!isCastLikeInsertSliceOp(insertSliceOp))
60 return failure();
61
62 // Extract directly from the source.
63 OpBuilder::InsertionGuard g(rewriter);
64 rewriter.setInsertionPoint(extractSliceOp);
65 SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
66 SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
67 SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
68 SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
69 for (int64_t i = 0, e = extractSliceOp.getSourceType().getRank(); i < e;
70 ++i) {
71 if (expandedDims.test(i))
72 continue;
73 newOffsets.push_back(mixedOffsets[i]);
74 newSizes.push_back(mixedSizes[i]);
75 newStrides.push_back(mixedStrides[i]);
76 }
77 rewriter.replaceOpWithNewOp<ExtractSliceOp>(
78 extractSliceOp, extractSliceOp.getResultType(),
79 /*source=*/insertSliceOp.getSource(), newOffsets, newSizes, newStrides);
80 rewriter.eraseOp(insertSliceOp);
81 return success();
82 }
83};
84
85/// Drop redundant rank expansion of insert_slice that direclty follows
86/// extract_slice.
87///
88/// This can be done when the insert_slice op purely expands ranks (adds unit
89/// dims) and the extrace_slice drops corresponding unit dims. For example:
90///
91/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
92/// : tensor<2x8xf32> to tensor<8xf32>
93/// %inserted_slice = tensor.insert_slice %extracted_slice
94/// into %dest[0, 0] [1, 8] [1, 1]
95/// : tensor<8xf32> into tensor<1x8xf32>
96///
97/// can be folded into:
98///
99/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
100/// : tensor<2x8xf32> to tensor<1x8xf32>
101struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
102 : public OpRewritePattern<tensor::InsertSliceOp> {
103 using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
104
105 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
106 PatternRewriter &rewriter) const override {
107 auto extractSliceOp =
108 insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
109 if (!extractSliceOp) {
110 return rewriter.notifyMatchFailure(insertSliceOp,
111 "source is not extract_slice");
112 }
113
114 // Can't fold if the extract_slice op has other users.
115 if (!extractSliceOp->hasOneUse()) {
116 return rewriter.notifyMatchFailure(insertSliceOp,
117 "source has multi-uses");
118 }
119
120 // Check if the insert_slice op purely expands ranks (add unit dims).
121 if (!isCastLikeInsertSliceOp(insertSliceOp)) {
122 return rewriter.notifyMatchFailure(insertSliceOp,
123 "insert_slice is not cast-like");
124 }
125
126 llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
127 llvm::SmallBitVector insertDroppedDims = insertSliceOp.getDroppedDims();
128 // Can't fold if the insert_slice op expands to more dims.
129 if (extractDroppedDims.size() < insertDroppedDims.size()) {
130 return rewriter.notifyMatchFailure(insertSliceOp,
131 "insert_slice expands more dims");
132 }
133
134 // Try to match the extract dropped dims to the insert dropped dims. This is
135 // done by scanning the dims of extract_slice and find the left-most one can
136 // match the dim of insert_slice. If a match is found, advance the dim of
137 // insert_slice to match the next one.
138 unsigned insertDimPos = 0;
139 for (unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
140 ++extractDimPos) {
141 // Matched all dims.
142 if (insertDimPos == insertDroppedDims.size())
143 break;
144
145 bool isExtractDropped = extractDroppedDims[extractDimPos];
146 bool isInsertDropped = insertDroppedDims[insertDimPos];
147 // Match if both sides drop/keep the dim. Advance and match the next dim
148 // of insert_slice.
149 if (isExtractDropped == isInsertDropped) {
150 insertDimPos += 1;
151 } else if (!isExtractDropped && isInsertDropped) {
152 // Not enough extract dropped dims to match the insert dropped dims.
153 return rewriter.notifyMatchFailure(insertSliceOp,
154 "insert_slice drops more unit dims");
155 }
156 // If the dim is dropped by extract_slice and not by insert_slice, look
157 // the next dim of extract_slice to see if it can match the current dim of
158 // insert_slice.
159 }
160 // Can't match some insert dims.
161 if (insertDimPos != insertDroppedDims.size()) {
162 return rewriter.notifyMatchFailure(insertSliceOp,
163 "insert_slice has unmatched dims");
164 }
165
166 rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
167 insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
168 extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
169 extractSliceOp.getMixedStrides());
170 rewriter.eraseOp(extractSliceOp);
171
172 return success();
173 }
174};
175} // namespace
176
178 RewritePatternSet &patterns) {
179 patterns.add<DropRedundantRankExpansionOnExtractSliceOfInsertSlice,
180 DropRedundantRankExpansionOnInsertSliceOfExtractSlice>(
181 patterns.getContext());
182}
return success()
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
bool isCastLikeInsertSliceOp(InsertSliceOp op)
A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor or inserts...
Definition Utils.cpp:125
void populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that drop redundant tensor.insert_slice rank expansions.
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...