MLIR  20.0.0git
EmptyOpPatterns.cpp
Go to the documentation of this file.
1 //===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
13 
14 using namespace mlir;
15 using namespace mlir::tensor;
16 
17 namespace {
18 
19 template <typename ReshapeOp>
20 struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
21  FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1,
22  bool foldSingleUseOnly = false)
23  : OpRewritePattern<ReshapeOp>(ctx, benefit),
24  foldSingleUseOnly(foldSingleUseOnly) {}
25 
26  LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
27  PatternRewriter &rewriter) const override {
28  // Check for tensor.empty source.
29  auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>();
30  if (!emptyOp)
31  return failure();
32 
33  // Check for single use.
34  if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
35  return failure();
36 
37  // Reify result shape.
38  Location loc = reshapeOp.getLoc();
39  ReifiedRankedShapedTypeDims resultShapes;
40  if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) ||
41  !llvm::hasSingleElement(resultShapes))
42  return failure();
43 
44  // Create new tensor.empty op.
45  // TODO: Do not drop tensor type encoding.
46  Value emptyTensor = rewriter.create<EmptyOp>(
47  loc, resultShapes[0], reshapeOp.getResultType().getElementType());
48  if (emptyTensor.getType() != reshapeOp.getResultType()) {
49  rewriter.replaceOpWithNewOp<tensor::CastOp>(
50  reshapeOp, reshapeOp.getResultType(), emptyTensor);
51  } else {
52  rewriter.replaceOp(reshapeOp, emptyTensor);
53  }
54  return success();
55  }
56 
57 private:
58  bool foldSingleUseOnly = false;
59 };
60 
61 /// tensor.empty does not define any tensor contents, so a slice of a
62 /// tensor.empty can be folded to a smaller tensor.empty.
63 struct FoldEmptyTensorWithExtractSliceOp
64  : public OpRewritePattern<ExtractSliceOp> {
65  FoldEmptyTensorWithExtractSliceOp(MLIRContext *ctx,
66  PatternBenefit benefit = 1,
67  bool foldSingleUseOnly = false)
68  : OpRewritePattern<ExtractSliceOp>(ctx, benefit),
69  foldSingleUseOnly(foldSingleUseOnly) {}
70 
71  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
72  PatternRewriter &rewriter) const override {
73  // Check for tensor.empty source.
74  auto emptyOp = sliceOp.getSource().template getDefiningOp<EmptyOp>();
75  if (!emptyOp)
76  return failure();
77 
78  // Check for single use.
79  if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
80  return failure();
81 
82  // Create new tensor.empty op. tensor.extract_slice may be rank-reducing;
83  // its dynamic sizes must be preserved as well as its result type.
84  auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
85  sliceOp.getType().getElementType(),
86  sliceOp.getType().getEncoding());
87  rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
88  sliceOp.getSizes());
89  return success();
90  }
91 
92 private:
93  bool foldSingleUseOnly = false;
94 };
95 
96 /// tensor.empty does not define any tensor contents, so an unpadded pack
97 /// can be folded away.
98 struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
100 
101  LogicalResult matchAndRewrite(PackOp packOp,
102  PatternRewriter &rewriter) const override {
103  // Check for tensor.empty source.
104  auto emptyOp = packOp.getSource().getDefiningOp<EmptyOp>();
105  if (!emptyOp)
106  return failure();
107 
108  // Check for padding.
109  // Packing with padding cannot be simply removed.
110  if (packOp.getPaddingValue())
111  return rewriter.notifyMatchFailure(packOp, "expects no padding value");
112 
113  // Replace the pack directly with its destination.
114  rewriter.replaceOp(packOp, packOp.getDest());
115 
116  return success();
117  }
118 };
119 
120 /// tensor.empty does not define any tensor contents, so an unpack
121 /// can be folded away.
122 struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
124 
125  LogicalResult matchAndRewrite(UnPackOp unPackOp,
126  PatternRewriter &rewriter) const override {
127  // Check for tensor.empty source.
128  auto emptyOp = unPackOp.getSource().getDefiningOp<EmptyOp>();
129  if (!emptyOp)
130  return failure();
131 
132  // Replace the unpack directly with its destination.
133  rewriter.replaceOp(unPackOp, unPackOp.getDest());
134 
135  return success();
136  }
137 };
138 
139 // Fold concat operation where all the operands are empty.
140 struct FoldConcatsOfEmpty : public OpRewritePattern<ConcatOp> {
142 
143  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
144  PatternRewriter &rewriter) const override {
145  auto concatOperands = concatOp.getInputs();
146  if (concatOperands.empty()) {
147  return failure();
148  }
149  auto firstEmptyOp = concatOperands.front().getDefiningOp<tensor::EmptyOp>();
150  if (!firstEmptyOp) {
151  return failure();
152  }
153  auto isDefinedByEmptyOp = [](Value v) -> bool {
154  return v.getDefiningOp<tensor::EmptyOp>();
155  };
156  if (!llvm::all_of(concatOperands.drop_front(), isDefinedByEmptyOp)) {
157  return rewriter.notifyMatchFailure(
158  concatOp, "not all operands are defined by an empty op");
159  }
161  if (failed(concatOp.reifyResultShapes(rewriter, resultShape))) {
162  return rewriter.notifyMatchFailure(concatOp,
163  "failed to get result shape");
164  }
165  rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
166  concatOp, resultShape[0], concatOp.getResultType().getElementType());
167  return success();
168  }
169 };
170 
171 } // namespace
172 
174  bool foldSingleUseOnly) {
175  patterns.add<FoldEmptyTensorWithExtractSliceOp,
176  FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
177  FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
178  patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
179  patterns.add<FoldConcatsOfEmpty, FoldEmptyTensorWithPackOp,
180  FoldEmptyTensorWithUnPackOp>(patterns.getContext(),
181  /*benefit=*/1);
182 }
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:791
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with its consumers.
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358