MLIR  20.0.0git
MergeConsecutiveInsertExtractSlicePatterns.cpp
Go to the documentation of this file.
1 //===- MergeConsecutiveInsertExtractSlicePatterns.cpp ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/OpDefinition.h"
15 #include "mlir/IR/PatternMatch.h"
16 
17 using namespace mlir;
18 using namespace mlir::tensor;
19 
20 namespace {
21 /// Merges consecutive tensor.extract_slice ops into one.
22 // TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
23 struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
25 
26  LogicalResult matchAndRewrite(ExtractSliceOp nextOp,
27  PatternRewriter &rewriter) const override {
28  auto prevOp = nextOp.getSource().getDefiningOp<ExtractSliceOp>();
29  if (!prevOp)
30  return failure();
31 
32  SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
34  rewriter, nextOp.getLoc(), prevOp, nextOp, prevOp.getDroppedDims(),
35  newOffsets, newSizes, newStrides)))
36  return failure();
37 
38  rewriter.replaceOpWithNewOp<ExtractSliceOp>(nextOp, nextOp.getType(),
39  prevOp.getSource(), newOffsets,
40  newSizes, newStrides);
41  return success();
42  }
43 };
44 
45 /// Merges consecutive tensor.insert_slice ops into one.
46 // TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
47 template <typename OpTy>
48 struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {
50 
51  LogicalResult matchAndRewrite(OpTy nextOp,
52  PatternRewriter &rewriter) const override {
53  auto prevOp = nextOp.getSource().template getDefiningOp<InsertSliceOp>();
54  if (!prevOp)
55  return failure();
56 
57  if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
58  return failure();
59 
60  // The first insert_slice op should be rank reducing to make sure we cover
61  // the full source tensor to be inserted in the second insert_slice op.
63  isRankReducedType(prevOp.getDestType(), prevOp.getSourceType());
65  return failure();
66 
67  // Dynamic dimensions can pass rank reducing check in the above, e.g,
68  // inserting <?xf32> into <1x?x1xf32>. For such cases we cannot be certain
69  // the dynamic size covers the full tensor.
70  if (!prevOp.getSourceType().hasStaticShape() ||
71  !prevOp.getDestType().hasStaticShape())
72  return failure();
73 
74  rewriter.replaceOpWithNewOp<OpTy>(
75  nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(),
76  nextOp.getMixedSizes(), nextOp.getMixedStrides());
77  return success();
78  }
79 };
80 
81 /// Drop redundant rank expansion of insert_slice that are directly followed
82 /// by extract_slice. E.g.:
83 /// %0 = tensor.insert_slice ... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
84 /// %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1]
85 /// : tensor<1x1x5x10xf32> to tensor<2x2xf32>
86 struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
87  : public OpRewritePattern<ExtractSliceOp> {
89 
90  LogicalResult matchAndRewrite(ExtractSliceOp extractSliceOp,
91  PatternRewriter &rewriter) const override {
92  // Nothing to do if no dims are dropped.
93  llvm::SmallBitVector droppedDims = extractSliceOp.getDroppedDims();
94  if (droppedDims.none())
95  return failure();
96 
97  // Look for tensor.insert_slice op that has an inverse rank expansion.
98  auto insertSliceOp =
99  extractSliceOp.getSource().getDefiningOp<InsertSliceOp>();
100  if (!insertSliceOp)
101  return failure();
102  llvm::SmallBitVector expandedDims = insertSliceOp.getDroppedDims();
103 
104  // TODO: This could be extended to support cases where the dropped dims are
105  // a subset of the expanded dims.
106  if (expandedDims != droppedDims)
107  return failure();
108 
109  // The tensor.insert_slice may not be redundant if it has multiple users.
110  if (!insertSliceOp->hasOneUse())
111  return failure();
112 
113  // Only consider tensor.insert_slice ops that are pure rank-reductions.
114  // I.e., no elements are taken from the destination.
115  if (!isCastLikeInsertSliceOp(insertSliceOp))
116  return failure();
117 
118  // Extract directly from the source.
119  OpBuilder::InsertionGuard g(rewriter);
120  rewriter.setInsertionPoint(extractSliceOp);
121  SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
122  for (int64_t i = 0, e = extractSliceOp.getSourceType().getRank(); i < e;
123  ++i) {
124  if (droppedDims.test(i))
125  continue;
126  newOffsets.push_back(extractSliceOp.getMixedOffsets()[i]);
127  newSizes.push_back(extractSliceOp.getMixedSizes()[i]);
128  newStrides.push_back(extractSliceOp.getMixedStrides()[i]);
129  }
130  rewriter.replaceOpWithNewOp<ExtractSliceOp>(
131  extractSliceOp, /*source=*/insertSliceOp.getSource(), newOffsets,
132  newSizes, newStrides);
133  rewriter.eraseOp(insertSliceOp);
134  return success();
135  }
136 };
137 
138 /// Drop redundant rank expansion of insert_slice that direclty follows
139 /// extract_slice.
140 ///
141 /// This can be done when the insert_slice op purely expands ranks (adds unit
142 /// dims) and the extrace_slice drops corresponding unit dims. For example:
143 ///
144 /// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
145 /// : tensor<2x8xf32> to tensor<8xf32>
146 /// %inserted_slice = tensor.insert_slice %extracted_slice
147 /// into %dest[0, 0] [1, 8] [1, 1]
148 /// : tensor<8xf32> into tensor<1x8xf32>
149 ///
150 /// can be folded into:
151 ///
152 /// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
153 /// : tensor<2x8xf32> to tensor<1x8xf32>
154 struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
155  : public OpRewritePattern<tensor::InsertSliceOp> {
157 
158  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
159  PatternRewriter &rewriter) const override {
160  auto extractSliceOp =
161  insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
162  if (!extractSliceOp) {
163  return rewriter.notifyMatchFailure(insertSliceOp,
164  "source is not extract_slice");
165  }
166 
167  // Can't fold if the extract_slice op has other users.
168  if (!extractSliceOp->hasOneUse()) {
169  return rewriter.notifyMatchFailure(insertSliceOp,
170  "source has multi-uses");
171  }
172 
173  // Check if the insert_slice op purely expands ranks (add unit dims).
174  if (!isCastLikeInsertSliceOp(insertSliceOp)) {
175  return rewriter.notifyMatchFailure(insertSliceOp,
176  "insert_slice is not cast-like");
177  }
178 
179  llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
180  llvm::SmallBitVector insertDroppedDims = insertSliceOp.getDroppedDims();
181  // Can't fold if the insert_slice op expands to more dims.
182  if (extractDroppedDims.size() < insertDroppedDims.size()) {
183  return rewriter.notifyMatchFailure(insertSliceOp,
184  "insert_slice expands more dims");
185  }
186 
187  // Try to match the extract dropped dims to the insert dropped dims. This is
188  // done by scanning the dims of extract_slice and find the left-most one can
189  // match the dim of insert_slice. If a match is found, advance the dim of
190  // insert_slice to match the next one.
191  unsigned insertDimPos = 0;
192  for (unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
193  ++extractDimPos) {
194  // Matched all dims.
195  if (insertDimPos == insertDroppedDims.size())
196  break;
197 
198  bool isExtractDropped = extractDroppedDims[extractDimPos];
199  bool isInsertDropped = insertDroppedDims[insertDimPos];
200  // Match if both sides drop/keep the dim. Advance and match the next dim
201  // of insert_slice.
202  if (isExtractDropped == isInsertDropped) {
203  insertDimPos += 1;
204  } else if (!isExtractDropped && isInsertDropped) {
205  // Not enough extract dropped dims to match the insert dropped dims.
206  return rewriter.notifyMatchFailure(insertSliceOp,
207  "insert_slice drops more unit dims");
208  }
209  // If the dim is dropped by extract_slice and not by insert_slice, look
210  // the next dim of extract_slice to see if it can match the current dim of
211  // insert_slice.
212  }
213  // Can't match some insert dims.
214  if (insertDimPos != insertDroppedDims.size()) {
215  return rewriter.notifyMatchFailure(insertSliceOp,
216  "insert_slice has unmatched dims");
217  }
218 
219  rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
220  insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
221  extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
222  extractSliceOp.getMixedStrides());
223  rewriter.eraseOp(extractSliceOp);
224 
225  return success();
226  }
227 };
228 } // namespace
229 
231  RewritePatternSet &patterns) {
232  patterns.add<MergeConsecutiveExtractSlice,
233  MergeConsecutiveInsertSlice<InsertSliceOp>,
234  MergeConsecutiveInsertSlice<ParallelInsertSliceOp>>(
235  patterns.getContext());
236 }
237 
239  RewritePatternSet &patterns) {
240  patterns.add<DropRedundantRankExpansionOnExtractSliceOfInsertSlice,
241  DropRedundantRankExpansionOnInsertSliceOfExtractSlice>(
242  patterns.getContext());
243 }
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:401
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
LogicalResult mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > producerOffsets, ArrayRef< OpFoldResult > producerSizes, ArrayRef< OpFoldResult > producerStrides, const llvm::SmallBitVector &droppedProducerDims, ArrayRef< OpFoldResult > consumerOffsets, ArrayRef< OpFoldResult > consumerSizes, ArrayRef< OpFoldResult > consumerStrides, SmallVector< OpFoldResult > &combinedOffsets, SmallVector< OpFoldResult > &combinedSizes, SmallVector< OpFoldResult > &combinedStrides)
Fills the combinedOffsets, combinedSizes and combinedStrides to use when combining a producer slice i...
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
bool isCastLikeInsertSliceOp(InsertSliceOp op)
A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor or inserts...
Definition: Utils.cpp:131
void populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that drop redundant tensor.insert_slice rank expansions.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:381
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362