MLIR  19.0.0git
ReshapePatterns.cpp
Go to the documentation of this file.
1 //===- ReshapePatterns.cpp - Patterns related to rank reductions ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/PatternMatch.h"
12 #include "llvm/Support/Debug.h"
13 
14 using namespace mlir;
15 using namespace mlir::tensor;
16 
17 namespace {
18 /// Fold expand_shape(extract_slice) ops that cancel itself out.
19 struct FoldExpandOfRankReducingExtract
20  : public OpRewritePattern<ExpandShapeOp> {
22 
23  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
24  PatternRewriter &rewriter) const override {
25  RankedTensorType resultType = expandShapeOp.getResultType();
26  auto extractSliceOp =
27  expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
28  if (!extractSliceOp)
29  return failure();
30  RankedTensorType srcType = extractSliceOp.getSourceType();
31 
32  // Only cases where the ExpandShapeOp can be folded away entirely are
33  // supported. Moreover, only simple cases where the resulting ExtractSliceOp
34  // has no rank-reduction anymore are supported at the moment.
35  RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
36  srcType, extractSliceOp.getStaticOffsets(),
37  extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
38  if (nonReducingExtractType != resultType)
39  return failure();
40 
41  SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
42  SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
43  SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
44  rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
45  expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
46  mixedStrides);
47  return success();
48  }
49 };
50 
51 /// Fold insert_slice(collapse_shape) ops that cancel itself out.
52 template <typename OpTy>
53 struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
55 
56  LogicalResult matchAndRewrite(OpTy insertSliceOp,
57  PatternRewriter &rewriter) const override {
58  auto collapseShapeOp =
59  insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
60  if (!collapseShapeOp)
61  return failure();
62  RankedTensorType srcType = collapseShapeOp.getSrcType();
63 
64  // Only cases where the CollapseShapeOp can be folded away entirely are
65  // supported. Moreover, only simple cases where the resulting InsertSliceOp
66  // has no rank-reduction anymore are supported at the moment.
67  RankedTensorType nonReducingInsertType =
68  RankedTensorType::get(insertSliceOp.getStaticSizes(),
69  insertSliceOp.getDestType().getElementType());
70  if (nonReducingInsertType != srcType)
71  return failure();
72 
73  SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
74  SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
75  SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
76  rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
77  insertSliceOp.getDest(), mixedOffsets,
78  mixedSizes, mixedStrides);
79  return success();
80  }
81 };
82 } // namespace
83 
85  RewritePatternSet &patterns) {
86  patterns.add<FoldExpandOfRankReducingExtract,
87  FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
88  FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>>(
89  patterns.getContext());
90 }
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:775
MLIRContext * getContext() const
Definition: PatternMatch.h:812
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:836
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:534
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357