MLIR  19.0.0git
EmptyOpPatterns.cpp
Go to the documentation of this file.
1 //===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
11 #include "mlir/IR/PatternMatch.h"
12 #include "llvm/Support/Debug.h"
13 
14 using namespace mlir;
15 using namespace mlir::tensor;
16 
17 namespace {
18 
19 template <typename ReshapeOp>
20 struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
21  FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1,
22  bool foldSingleUseOnly = false)
23  : OpRewritePattern<ReshapeOp>(ctx, benefit),
24  foldSingleUseOnly(foldSingleUseOnly) {}
25 
26  LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
27  PatternRewriter &rewriter) const override {
28  // Check for tensor.empty source.
29  auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>();
30  if (!emptyOp)
31  return failure();
32 
33  // Check for single use.
34  if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
35  return failure();
36 
37  // Reify result shape.
38  Location loc = reshapeOp.getLoc();
39  ReifiedRankedShapedTypeDims resultShapes;
40  if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) ||
41  !llvm::hasSingleElement(resultShapes))
42  return failure();
43 
44  // Create new tensor.empty op.
45  // TODO: Do not drop tensor type encoding.
46  Value emptyTensor = rewriter.create<EmptyOp>(
47  loc, resultShapes[0], reshapeOp.getResultType().getElementType());
48  if (emptyTensor.getType() != reshapeOp.getResultType()) {
49  rewriter.replaceOpWithNewOp<tensor::CastOp>(
50  reshapeOp, reshapeOp.getResultType(), emptyTensor);
51  } else {
52  rewriter.replaceOp(reshapeOp, emptyTensor);
53  }
54  return success();
55  }
56 
57 private:
58  bool foldSingleUseOnly = false;
59 };
60 
61 /// tensor.empty does not define any tensor contents, so a slice of a
62 /// tensor.empty can be folded to a smaller tensor.empty.
63 struct FoldEmptyTensorWithExtractSliceOp
64  : public OpRewritePattern<ExtractSliceOp> {
65  FoldEmptyTensorWithExtractSliceOp(MLIRContext *ctx,
66  PatternBenefit benefit = 1,
67  bool foldSingleUseOnly = false)
68  : OpRewritePattern<ExtractSliceOp>(ctx, benefit),
69  foldSingleUseOnly(foldSingleUseOnly) {}
70 
71  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
72  PatternRewriter &rewriter) const override {
73  // Check for tensor.empty source.
74  auto emptyOp = sliceOp.getSource().template getDefiningOp<EmptyOp>();
75  if (!emptyOp)
76  return failure();
77 
78  // Check for single use.
79  if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
80  return failure();
81 
82  // Create new tensor.empty op. tensor.extract_slice may be rank-reducing;
83  // its dynamic sizes must be preserved as well as its result type.
84  auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
85  sliceOp.getType().getElementType(),
86  sliceOp.getType().getEncoding());
87  rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
88  sliceOp.getSizes());
89  return success();
90  }
91 
92 private:
93  bool foldSingleUseOnly = false;
94 };
95 
96 /// tensor.empty does not define any tensor contents, so an unpadded pack
97 /// can be folded away.
98 struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
100 
101  LogicalResult matchAndRewrite(PackOp packOp,
102  PatternRewriter &rewriter) const override {
103  // Check for tensor.empty source.
104  auto emptyOp = packOp.getSource().getDefiningOp<EmptyOp>();
105  if (!emptyOp)
106  return failure();
107 
108  // Check for padding.
109  // Packing with padding cannot be simply removed.
110  if (packOp.getPaddingValue())
111  return rewriter.notifyMatchFailure(packOp, "expects no padding value");
112 
113  // Replace the pack directly with its destination.
114  rewriter.replaceOp(packOp, packOp.getDest());
115 
116  return success();
117  }
118 };
119 
120 /// tensor.empty does not define any tensor contents, so an unpack
121 /// can be folded away.
122 struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
124 
125  LogicalResult matchAndRewrite(UnPackOp unPackOp,
126  PatternRewriter &rewriter) const override {
127  // Check for tensor.empty source.
128  auto emptyOp = unPackOp.getSource().getDefiningOp<EmptyOp>();
129  if (!emptyOp)
130  return failure();
131 
132  // Replace the unpack directly with its destination.
133  rewriter.replaceOp(unPackOp, unPackOp.getDest());
134 
135  return success();
136  }
137 };
138 
139 } // namespace
140 
142  bool foldSingleUseOnly) {
143  patterns.add<FoldEmptyTensorWithExtractSliceOp,
144  FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
145  FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
146  patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
147  patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
148  patterns.getContext(), /*benefit=*/1);
149 }
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with its consumers.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358