MLIR  19.0.0git
MergeConsecutiveInsertExtractSlicePatterns.cpp
Go to the documentation of this file.
1 //===- MergeConsecutiveInsertExtractSlicePatterns.cpp ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/OpDefinition.h"
15 #include "mlir/IR/PatternMatch.h"
16 
17 using namespace mlir;
18 using namespace mlir::tensor;
19 
20 namespace {
21 /// Merges consecutive tensor.extract_slice ops into one.
   ///
   /// Matches an extract_slice (`nextOp`) whose source is the result of another
   /// extract_slice (`prevOp`). On success, `nextOp` is replaced by a single
   /// extract_slice that reads from `prevOp`'s source with the combined
   /// offsets/sizes/strides and keeps `nextOp`'s result type.
22 // TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
23 struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
   // NOTE(review): listing line 24 is absent from this extract; presumably the
   // inherited-constructor using-declaration -- confirm against upstream.
25 
26  LogicalResult matchAndRewrite(ExtractSliceOp nextOp,
27  PatternRewriter &rewriter) const override {
    // Only applies when the source of `nextOp` is itself an extract_slice.
28  auto prevOp = nextOp.getSource().getDefiningOp<ExtractSliceOp>();
29  if (!prevOp)
30  return failure();
31 
    // Outputs for the combined slice parameters of the merged op.
32  SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
    // NOTE(review): listing line 33 is absent from this extract. The dangling
    // argument list below presumably belongs to a call of the form
    // `if (failed(mergeOffsetsSizesAndStrides(` -- TODO confirm. The dims
    // dropped by `prevOp` are forwarded to that call.
34  rewriter, nextOp.getLoc(), prevOp, nextOp, prevOp.getDroppedDims(),
35  newOffsets, newSizes, newStrides)))
36  return failure();
37 
    // Extract directly from the original (outermost) source tensor.
38  rewriter.replaceOpWithNewOp<ExtractSliceOp>(nextOp, nextOp.getType(),
39  prevOp.getSource(), newOffsets,
40  newSizes, newStrides);
41  return success();
42  }
43 };
44 
45 /// Merges consecutive tensor.insert_slice ops into one.
   ///
   /// `OpTy` is the consuming insert op (`nextOp`); its source must be produced
   /// by a tensor.insert_slice (`prevOp`). On success, `nextOp` is replaced by an
   /// `OpTy` that inserts `prevOp`'s source directly into `nextOp`'s destination
   /// using `nextOp`'s own offsets, sizes and strides.
46 // TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
47 template <typename OpTy>
48 struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {
   // NOTE(review): listing line 49 is absent from this extract; presumably the
   // inherited-constructor using-declaration -- confirm against upstream.
50 
51  LogicalResult matchAndRewrite(OpTy nextOp,
52  PatternRewriter &rewriter) const override {
    // The producer is always a plain InsertSliceOp, whatever `OpTy` is.
53  auto prevOp = nextOp.getSource().template getDefiningOp<InsertSliceOp>();
54  if (!prevOp)
55  return failure();
56 
    // Bail out unless both ops use unit strides.
57  if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
58  return failure();
59 
60  // The first insert_slice op should be rank reducing to make sure we cover
61  // the full source tensor to be inserted in the second insert_slice op.
    // NOTE(review): listing lines 62 and 64 are absent from this extract.
    // Presumably line 62 declares a SliceVerificationResult that receives the
    // value below and line 64 is the `if` guarding the following
    // `return failure();` -- TODO confirm against upstream.
63  isRankReducedType(prevOp.getDestType(), prevOp.getSourceType());
65  return failure();
66 
67  // Dynamic dimensions can pass the rank-reducing check above, e.g.,
68  // inserting <?xf32> into <1x?x1xf32>. For such cases we cannot be certain
69  // the dynamic size covers the full tensor.
70  if (!prevOp.getSourceType().hasStaticShape() ||
71  !prevOp.getDestType().hasStaticShape())
72  return failure();
73 
    // Skip the intermediate insert: insert prevOp's source straight into
    // nextOp's destination.
74  rewriter.replaceOpWithNewOp<OpTy>(
75  nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(),
76  nextOp.getMixedSizes(), nextOp.getMixedStrides());
77  return success();
78  }
79 };
80 
81 /// Drops the redundant rank expansion of an insert_slice that is directly
82 /// followed by an extract_slice. E.g.:
83 /// %0 = tensor.insert_slice ... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
84 /// %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1]
85 /// : tensor<1x1x5x10xf32> to tensor<2x2xf32>
   ///
   /// When the pattern applies, the extract_slice is rebuilt to read directly
   /// from the insert_slice's source, using only the offsets/sizes/strides of the
   /// dims that are not dropped, and the insert_slice is erased.
86 struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
87  : public OpRewritePattern<ExtractSliceOp> {
   // NOTE(review): listing line 88 is absent from this extract; presumably the
   // inherited-constructor using-declaration -- confirm against upstream.
89 
90  LogicalResult matchAndRewrite(ExtractSliceOp extractSliceOp,
91  PatternRewriter &rewriter) const override {
92  // Nothing to do if no dims are dropped.
93  llvm::SmallBitVector droppedDims = extractSliceOp.getDroppedDims();
94  if (droppedDims.none())
95  return failure();
96 
97  // Look for tensor.insert_slice op that has an inverse rank expansion.
98  auto insertSliceOp =
99  extractSliceOp.getSource().getDefiningOp<InsertSliceOp>();
100  if (!insertSliceOp)
101  return failure();
102  llvm::SmallBitVector expandedDims = insertSliceOp.getDroppedDims();
103 
104  // TODO: This could be extended to support cases where the dropped dims are
105  // a subset of the expanded dims.
106  if (expandedDims != droppedDims)
107  return failure();
108 
109  // The tensor.insert_slice may not be redundant if it has multiple users.
110  if (!insertSliceOp->hasOneUse())
111  return failure();
112 
113  // Only consider tensor.insert_slice ops that are cast-like, i.e., no
114  // elements are taken from the destination. NOTE(review): the original comment said "pure rank-reductions"; confirm the intended wording.
115  if (!isCastLikeInsertSliceOp(insertSliceOp))
116  return failure();
117 
118  // Extract directly from the source.
    // The guard restores the rewriter's insertion point when it goes out of
    // scope.
119  OpBuilder::InsertionGuard g(rewriter);
120  rewriter.setInsertionPoint(extractSliceOp);
121  SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
    // Keep only the slice parameters of the dims that are not dropped.
122  for (int64_t i = 0, e = extractSliceOp.getSourceType().getRank(); i < e;
123  ++i) {
124  if (droppedDims.test(i))
125  continue;
126  newOffsets.push_back(extractSliceOp.getMixedOffsets()[i]);
127  newSizes.push_back(extractSliceOp.getMixedSizes()[i]);
128  newStrides.push_back(extractSliceOp.getMixedStrides()[i]);
129  }
    // No explicit result type is passed to the new extract_slice here.
130  rewriter.replaceOpWithNewOp<ExtractSliceOp>(
131  extractSliceOp, /*source=*/insertSliceOp.getSource(), newOffsets,
132  newSizes, newStrides);
    // The insert_slice had a single use (checked above), which was just
    // replaced, so it can be erased.
133  rewriter.eraseOp(insertSliceOp);
134  return success();
135  }
136 };
137 
138 /// Drops the redundant rank expansion of an insert_slice that directly follows
139 /// an extract_slice.
140 ///
141 /// This can be done when the insert_slice op purely expands ranks (adds unit
142 /// dims) and the extract_slice drops corresponding unit dims. For example:
143 ///
144 /// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
145 /// : tensor<2x8xf32> to tensor<8xf32>
146 /// %inserted_slice = tensor.insert_slice %extracted_slice
147 /// into %dest[0, 0] [1, 8] [1, 1]
148 /// : tensor<8xf32> into tensor<1x8xf32>
149 ///
150 /// can be folded into:
151 ///
152 /// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
153 /// : tensor<2x8xf32> to tensor<1x8xf32>
154 struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
155  : public OpRewritePattern<tensor::InsertSliceOp> {
    // NOTE(review): listing line 156 is absent from this extract; presumably
    // the inherited-constructor using-declaration -- confirm against upstream.
157 
158  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
159  PatternRewriter &rewriter) const override {
     // The inserted value must come straight from an extract_slice.
160  auto extractSliceOp =
161  insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
162  if (!extractSliceOp) {
163  return rewriter.notifyMatchFailure(insertSliceOp,
164  "source is not extract_slice");
165  }
166 
167  // Can't fold if the extract_slice op has other users.
168  if (!extractSliceOp->hasOneUse()) {
169  return rewriter.notifyMatchFailure(insertSliceOp,
170  "source has multi-uses");
171  }
172 
173  // Check if the insert_slice op purely expands ranks (adds unit dims).
174  if (!isCastLikeInsertSliceOp(insertSliceOp)) {
175  return rewriter.notifyMatchFailure(insertSliceOp,
176  "insert_slice is not cast-like");
177  }
178 
179  llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
180  llvm::SmallBitVector insertDroppedDims = insertSliceOp.getDroppedDims();
181  // Can't fold if the insert_slice op expands to more dims.
182  if (extractDroppedDims.size() < insertDroppedDims.size()) {
183  return rewriter.notifyMatchFailure(insertSliceOp,
184  "insert_slice expands more dims");
185  }
186 
187  // Try to match the extract dropped dims to the insert dropped dims. This is
188  // done by scanning the dims of extract_slice and finding the left-most one
189  // that can match the current dim of insert_slice. If a match is found,
190  // advance the dim of insert_slice to match the next one.
191  unsigned insertDimPos = 0;
192  for (unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
193  ++extractDimPos) {
194  // Matched all dims.
195  if (insertDimPos == insertDroppedDims.size())
196  break;
197 
198  bool isExtractDropped = extractDroppedDims[extractDimPos];
199  bool isInsertDropped = insertDroppedDims[insertDimPos];
200  // Match if both sides drop/keep the dim. Advance and match the next dim
201  // of insert_slice.
202  if (isExtractDropped == isInsertDropped) {
203  insertDimPos += 1;
204  } else if (!isExtractDropped && isInsertDropped) {
205  // Not enough extract dropped dims to match the insert dropped dims.
206  return rewriter.notifyMatchFailure(insertSliceOp,
207  "insert_slice drops more unit dims");
208  }
209  // If the dim is dropped by extract_slice and not by insert_slice, look at
210  // the next dim of extract_slice to see if it can match the current dim of
211  // insert_slice.
212  }
213  // Can't match some insert dims.
214  if (insertDimPos != insertDroppedDims.size()) {
215  return rewriter.notifyMatchFailure(insertSliceOp,
216  "insert_slice has unmatched dims");
217  }
218 
     // Replace the insert_slice with an extract_slice that reuses the original
     // extract's source and slice parameters but produces the insert_slice's
     // result type.
219  rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
220  insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
221  extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
222  extractSliceOp.getMixedStrides());
     // The old extract_slice had a single use (checked above), which was just
     // replaced, so it can be erased.
223  rewriter.eraseOp(extractSliceOp);
224 
225  return success();
226  }
227 };
228 } // namespace
229 
231  RewritePatternSet &patterns) {
     // NOTE(review): listing line 230 (the first line of this function's
     // signature) is absent from this extract; presumably this is
     // populateMergeConsecutiveInsertExtractSlicePatterns -- confirm against
     // upstream.
     // Registers the extract_slice merge pattern plus the insert_slice merge
     // pattern instantiated for InsertSliceOp and ParallelInsertSliceOp.
232  patterns.add<MergeConsecutiveExtractSlice,
233  MergeConsecutiveInsertSlice<InsertSliceOp>,
234  MergeConsecutiveInsertSlice<ParallelInsertSliceOp>>(
235  patterns.getContext());
236 }
237 
239  RewritePatternSet &patterns) {
     // NOTE(review): listing line 238 (the first line of this function's
     // signature) is absent from this extract; presumably this is
     // populateDropRedundantInsertSliceRankExpansionPatterns -- confirm against
     // upstream.
     // Registers both rank-expansion-dropping patterns: extract_slice of
     // insert_slice, and insert_slice of extract_slice.
240  patterns.add<DropRedundantRankExpansionOnExtractSliceOfInsertSlice,
241  DropRedundantRankExpansionOnInsertSliceOfExtractSlice>(
242  patterns.getContext());
243 }
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:775
MLIRContext * getContext() const
Definition: PatternMatch.h:812
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:836
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:708
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:534
LogicalResult mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > producerOffsets, ArrayRef< OpFoldResult > producerSizes, ArrayRef< OpFoldResult > producerStrides, const llvm::SmallBitVector &droppedProducerDims, ArrayRef< OpFoldResult > consumerOffsets, ArrayRef< OpFoldResult > consumerSizes, ArrayRef< OpFoldResult > consumerStrides, SmallVector< OpFoldResult > &combinedOffsets, SmallVector< OpFoldResult > &combinedSizes, SmallVector< OpFoldResult > &combinedStrides)
Fills the combinedOffsets, combinedSizes and combinedStrides to use when combining a producer slice i...
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
bool isCastLikeInsertSliceOp(InsertSliceOp op)
A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor or inserts...
Definition: Utils.cpp:142
void populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that drop redundant tensor.insert_slice rank expansions.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:369
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361