MLIR  20.0.0git
ReshapePatterns.cpp
Go to the documentation of this file.
1 //===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
13 
14 using namespace mlir;
15 using namespace mlir::tensor;
16 
17 namespace {
18 /// Fold expand_shape(extract_slice) ops that cancel itself out.
19 struct FoldExpandOfRankReducingExtract
20  : public OpRewritePattern<ExpandShapeOp> {
22 
23  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
24  PatternRewriter &rewriter) const override {
25  RankedTensorType resultType = expandShapeOp.getResultType();
26  auto extractSliceOp =
27  expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
28  if (!extractSliceOp)
29  return failure();
30  RankedTensorType srcType = extractSliceOp.getSourceType();
31 
32  // Only cases where the ExpandShapeOp can be folded away entirely are
33  // supported. Moreover, only simple cases where the resulting ExtractSliceOp
34  // has no rank-reduction anymore are supported at the moment.
35  RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
36  srcType, extractSliceOp.getStaticOffsets(),
37  extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
38  if (nonReducingExtractType != resultType)
39  return failure();
40 
41  SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
42  SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
43  SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
44  rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
45  expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
46  mixedStrides);
47  return success();
48  }
49 };
50 
51 /// Fold collapse_shape which only removes static dimensions of size `1`
52 /// into extract_slice.
53 struct FoldUnPaddingCollapseIntoExtract
54  : public OpRewritePattern<tensor::CollapseShapeOp> {
56 
57  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
58  PatternRewriter &rewriter) const override {
59  auto extractSliceOp =
60  collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
61  // Collapse cannot be folded away with multiple users of the extract slice
62  // and it is not necessarily beneficial to only convert the collapse into
63  // another extract slice.
64  if (!extractSliceOp || !extractSliceOp->hasOneUse())
65  return failure();
66 
67  // Only fold away simple collapse where all removed dimensions have static
68  // size `1`.
70  collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
72  return rewriter.notifyMatchFailure(collapseShapeOp,
73  "expected unpadding collapse");
74 
75  Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
76  extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
77  extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
78  extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
79  rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
80  return success();
81  }
82 };
83 
84 /// Fold insert_slice(collapse_shape) ops that cancel itself out.
85 template <typename OpTy>
86 struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
88 
89  LogicalResult matchAndRewrite(OpTy insertSliceOp,
90  PatternRewriter &rewriter) const override {
91  auto collapseShapeOp =
92  insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
93  if (!collapseShapeOp)
94  return failure();
95  RankedTensorType srcType = collapseShapeOp.getSrcType();
96 
97  // Only cases where the CollapseShapeOp can be folded away entirely are
98  // supported. Moreover, only simple cases where the resulting InsertSliceOp
99  // has no rank-reduction anymore are supported at the moment.
100  RankedTensorType nonReducingInsertType =
101  RankedTensorType::get(insertSliceOp.getStaticSizes(),
102  insertSliceOp.getDestType().getElementType());
103  if (nonReducingInsertType != srcType)
104  return failure();
105 
106  SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
107  SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
108  SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
109  rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
110  insertSliceOp.getDest(), mixedOffsets,
111  mixedSizes, mixedStrides);
112  return success();
113  }
114 };
115 
116 /// Fold expand_shape which only adds static dimensions of size `1`
117 /// into insert_slice.
118 template <typename OpTy>
119 struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
121 
122  LogicalResult matchAndRewrite(OpTy insertSliceOp,
123  PatternRewriter &rewriter) const override {
124  auto expandShapeOp = insertSliceOp.getSource()
125  .template getDefiningOp<tensor::ExpandShapeOp>();
126  if (!expandShapeOp)
127  return failure();
128 
129  // Only fold away simple expansion where all added dimensions have static
130  // size `1`.
132  expandShapeOp.getResultType(), expandShapeOp.getSrcType());
134  return rewriter.notifyMatchFailure(insertSliceOp,
135  "expected rank increasing expansion");
136 
137  rewriter.modifyOpInPlace(insertSliceOp, [&]() {
138  insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
139  });
140  return success();
141  }
142 };
143 
144 /// Pattern to bubble up a tensor.expand_shape op through a producer
145 /// tensor.collapse_shape op that has non intersecting reassociations.
146 struct BubbleUpExpandThroughParallelCollapse
147  : public OpRewritePattern<tensor::ExpandShapeOp> {
149 
150  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
151  PatternRewriter &rewriter) const override {
152  auto collapseOp =
153  expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
154  if (!collapseOp)
155  return failure();
156  auto expandReInds = expandOp.getReassociationIndices();
157  auto collapseReInds = collapseOp.getReassociationIndices();
158 
159  // Reshapes are parallel to each other if none of the reassociation indices
160  // have greater than 1 index for both reshapes.
161  for (auto [expandReassociation, collapseReassociation] :
162  llvm::zip_equal(expandReInds, collapseReInds)) {
163  if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
164  return failure();
165  }
166 
167  // Compute new reassociation indices and expanded/collaped shapes.
168  SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
169  Location loc = expandOp->getLoc();
170  SmallVector<OpFoldResult> collapseSizes =
171  tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
173  expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
174  SmallVector<OpFoldResult> newExpandSizes;
175  int64_t index = 0, expandIndex = 0, collapseIndex = 0;
176  for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
177  if (collapseReassociation.size() != 1) {
178  ReassociationIndices newCollapseReassociation;
179  for (size_t i = 0; i < collapseReassociation.size(); ++i) {
180  newCollapseReassociation.push_back(index);
181  newExpandReInds.push_back({index++});
182  newExpandSizes.push_back(collapseSizes[collapseIndex++]);
183  }
184  newCollapseReInds.push_back(newCollapseReassociation);
185  expandIndex++;
186  continue;
187  }
188  ReassociationIndices newExpandReassociation;
189  auto expandReassociation = expandReInds[idx];
190  for (size_t i = 0; i < expandReassociation.size(); ++i) {
191  newExpandReassociation.push_back(index);
192  newCollapseReInds.push_back({index++});
193  newExpandSizes.push_back(expandSizes[expandIndex++]);
194  }
195  newExpandReInds.push_back(newExpandReassociation);
196  collapseIndex++;
197  }
198 
199  // Swap reshape order.
200  SmallVector<Value> dynamicSizes;
201  SmallVector<int64_t> staticSizes;
202  dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
203  auto expandResultType = expandOp.getResultType().clone(staticSizes);
204  auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
205  loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
206  newExpandSizes);
207  rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
208  expandOp, newExpand.getResult(), newCollapseReInds);
209  return success();
210  }
211 };
212 
213 } // namespace
214 
217  patterns
218  .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
219  FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
220  FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
221  FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
222  FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
223  patterns.getContext());
224 }
225 
228  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
229 }
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:66
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:387
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358