MLIR  21.0.0git
EmptyOpPatterns.cpp
Go to the documentation of this file.
1 //===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
11 #include "mlir/IR/PatternMatch.h"
12 #include "llvm/Support/Debug.h"
13 
14 using namespace mlir;
15 using namespace mlir::tensor;
16 
17 namespace {
18 
19 template <typename ReshapeOp>
20 struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
21  FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1,
22  bool foldSingleUseOnly = false)
23  : OpRewritePattern<ReshapeOp>(ctx, benefit),
24  foldSingleUseOnly(foldSingleUseOnly) {}
25 
26  LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
27  PatternRewriter &rewriter) const override {
28  // Check for tensor.empty source.
29  auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>();
30  if (!emptyOp)
31  return failure();
32 
33  // Check for single use.
34  if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
35  return failure();
36 
37  // Reify result shape.
38  Location loc = reshapeOp.getLoc();
39  ReifiedRankedShapedTypeDims resultShapes;
40  if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) ||
41  !llvm::hasSingleElement(resultShapes))
42  return failure();
43 
44  // Create new tensor.empty op.
45  // TODO: Do not drop tensor type encoding.
46  Value emptyTensor = rewriter.create<EmptyOp>(
47  loc, resultShapes[0], reshapeOp.getResultType().getElementType());
48  if (emptyTensor.getType() != reshapeOp.getResultType()) {
49  rewriter.replaceOpWithNewOp<tensor::CastOp>(
50  reshapeOp, reshapeOp.getResultType(), emptyTensor);
51  } else {
52  rewriter.replaceOp(reshapeOp, emptyTensor);
53  }
54  return success();
55  }
56 
57 private:
58  bool foldSingleUseOnly = false;
59 };
60 
61 /// tensor.empty does not define any tensor contents, so a slice of a
62 /// tensor.empty can be folded to a smaller tensor.empty.
63 struct FoldEmptyTensorWithExtractSliceOp
64  : public OpRewritePattern<ExtractSliceOp> {
65  FoldEmptyTensorWithExtractSliceOp(MLIRContext *ctx,
66  PatternBenefit benefit = 1,
67  bool foldSingleUseOnly = false)
68  : OpRewritePattern<ExtractSliceOp>(ctx, benefit),
69  foldSingleUseOnly(foldSingleUseOnly) {}
70 
71  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
72  PatternRewriter &rewriter) const override {
73  // Check for tensor.empty source.
74  auto emptyOp = sliceOp.getSource().template getDefiningOp<EmptyOp>();
75  if (!emptyOp)
76  return failure();
77 
78  // Check for single use.
79  if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
80  return failure();
81 
82  // Create new tensor.empty op. tensor.extract_slice may be rank-reducing;
83  // its dynamic sizes must be preserved as well as its result type.
84  auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
85  sliceOp.getType().getElementType(),
86  sliceOp.getType().getEncoding());
87  rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
88  sliceOp.getSizes());
89  return success();
90  }
91 
92 private:
93  bool foldSingleUseOnly = false;
94 };
95 
96 // Fold concat operation where all the operands are empty.
97 struct FoldConcatsOfEmpty : public OpRewritePattern<ConcatOp> {
99 
100  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
101  PatternRewriter &rewriter) const override {
102  auto concatOperands = concatOp.getInputs();
103  if (concatOperands.empty()) {
104  return failure();
105  }
106  auto firstEmptyOp = concatOperands.front().getDefiningOp<tensor::EmptyOp>();
107  if (!firstEmptyOp) {
108  return failure();
109  }
110  auto isDefinedByEmptyOp = [](Value v) -> bool {
111  return v.getDefiningOp<tensor::EmptyOp>();
112  };
113  if (!llvm::all_of(concatOperands.drop_front(), isDefinedByEmptyOp)) {
114  return rewriter.notifyMatchFailure(
115  concatOp, "not all operands are defined by an empty op");
116  }
118  if (failed(concatOp.reifyResultShapes(rewriter, resultShape))) {
119  return rewriter.notifyMatchFailure(concatOp,
120  "failed to get result shape");
121  }
122  rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
123  concatOp, resultShape[0], concatOp.getResultType().getElementType());
124  return success();
125  }
126 };
127 
128 } // namespace
129 
131  bool foldSingleUseOnly) {
132  patterns.add<FoldEmptyTensorWithExtractSliceOp,
133  FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
134  FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
135  patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
136  patterns.add<FoldConcatsOfEmpty>(patterns.getContext(),
137  /*benefit=*/1);
138 }
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:736
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:554
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with its consumers.
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358