MLIR  20.0.0git
ReshapePatterns.cpp
Go to the documentation of this file.
1 //===- ReshapePatterns.cpp - Patterns related to tensor reshapes ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
13 
14 using namespace mlir;
15 using namespace mlir::tensor;
16 
17 namespace {
18 /// Fold expand_shape(extract_slice) ops that cancel itself out.
19 struct FoldExpandOfRankReducingExtract
20  : public OpRewritePattern<ExpandShapeOp> {
22 
23  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
24  PatternRewriter &rewriter) const override {
25  RankedTensorType resultType = expandShapeOp.getResultType();
26  auto extractSliceOp =
27  expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
28  if (!extractSliceOp)
29  return failure();
30  RankedTensorType srcType = extractSliceOp.getSourceType();
31 
32  // Only cases where the ExpandShapeOp can be folded away entirely are
33  // supported. Moreover, only simple cases where the resulting ExtractSliceOp
34  // has no rank-reduction anymore are supported at the moment.
35  RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
36  srcType, extractSliceOp.getStaticOffsets(),
37  extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
38  if (nonReducingExtractType != resultType)
39  return failure();
40 
41  SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
42  SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
43  SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
44  rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
45  expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
46  mixedStrides);
47  return success();
48  }
49 };
50 
51 /// Fold collapse_shape which only removes static dimensions of size `1`
52 /// into extract_slice.
53 struct FoldUnPaddingCollapseIntoExtract
54  : public OpRewritePattern<tensor::CollapseShapeOp> {
56 
57  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
58  PatternRewriter &rewriter) const override {
59  auto extractSliceOp =
60  collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
61  // Collapse cannot be folded away with multiple users of the extract slice
62  // and it is not necessarily beneficial to only convert the collapse into
63  // another extract slice.
64  if (!extractSliceOp || !extractSliceOp->hasOneUse())
65  return failure();
66 
67  // Only fold away simple collapse where all removed dimensions have static
68  // size `1`.
70  collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
72  return rewriter.notifyMatchFailure(collapseShapeOp,
73  "expected unpadding collapse");
74 
75  Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
76  extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
77  extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
78  extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
79  rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
80  return success();
81  }
82 };
83 
84 /// Fold insert_slice(collapse_shape) ops that cancel itself out.
85 template <typename OpTy>
86 struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
88 
89  LogicalResult matchAndRewrite(OpTy insertSliceOp,
90  PatternRewriter &rewriter) const override {
91  auto collapseShapeOp =
92  insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
93  if (!collapseShapeOp)
94  return failure();
95  RankedTensorType srcType = collapseShapeOp.getSrcType();
96 
97  // Only cases where the CollapseShapeOp can be folded away entirely are
98  // supported. Moreover, only simple cases where the resulting InsertSliceOp
99  // has no rank-reduction anymore are supported at the moment.
100  RankedTensorType nonReducingInsertType =
101  RankedTensorType::get(insertSliceOp.getStaticSizes(),
102  insertSliceOp.getDestType().getElementType());
103  if (nonReducingInsertType != srcType)
104  return failure();
105 
106  SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
107  SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
108  SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
109  rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
110  insertSliceOp.getDest(), mixedOffsets,
111  mixedSizes, mixedStrides);
112  return success();
113  }
114 };
115 
116 /// Fold expand_shape which only adds static dimensions of size `1`
117 /// into insert_slice.
118 template <typename OpTy>
119 struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
121 
122  LogicalResult matchAndRewrite(OpTy insertSliceOp,
123  PatternRewriter &rewriter) const override {
124  auto expandShapeOp = insertSliceOp.getSource()
125  .template getDefiningOp<tensor::ExpandShapeOp>();
126  if (!expandShapeOp)
127  return failure();
128 
129  // Only fold away simple expansion where all added dimensions have static
130  // size `1`.
132  expandShapeOp.getResultType(), expandShapeOp.getSrcType());
134  return rewriter.notifyMatchFailure(insertSliceOp,
135  "expected rank increasing expansion");
136 
137  rewriter.modifyOpInPlace(insertSliceOp, [&]() {
138  insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
139  });
140  return success();
141  }
142 };
143 } // namespace
144 
146  RewritePatternSet &patterns) {
147  patterns
148  .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
149  FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
150  FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
151  FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
152  FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
153  patterns.getContext());
154 }
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:381
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358