MLIR  19.0.0git
EmptyOpPatterns.cpp
Go to the documentation of this file.
1 //===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
11 #include "mlir/IR/PatternMatch.h"
12 #include "llvm/Support/Debug.h"
13 
14 using namespace mlir;
15 using namespace mlir::tensor;
16 
17 namespace {
18 
19 template <typename ReshapeOp>
20 struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
21  FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1,
22  bool foldSingleUseOnly = false)
23  : OpRewritePattern<ReshapeOp>(ctx, benefit),
24  foldSingleUseOnly(foldSingleUseOnly) {}
25 
26  LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
27  PatternRewriter &rewriter) const override {
28  // Check for tensor.empty source.
29  auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>();
30  if (!emptyOp)
31  return failure();
32 
33  // Check for single use.
34  if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
35  return failure();
36 
37  // Reify result shape.
38  Location loc = reshapeOp.getLoc();
39  ReifiedRankedShapedTypeDims resultShapes;
40  if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) ||
41  !llvm::hasSingleElement(resultShapes))
42  return failure();
43 
44  // Create new tensor.empty op.
45  // TODO: Do not drop tensor type encoding.
46  Value emptyTensor = rewriter.create<EmptyOp>(
47  loc, resultShapes[0], reshapeOp.getResultType().getElementType());
48  if (emptyTensor.getType() != reshapeOp.getResultType()) {
49  rewriter.replaceOpWithNewOp<tensor::CastOp>(
50  reshapeOp, reshapeOp.getResultType(), emptyTensor);
51  } else {
52  rewriter.replaceOp(reshapeOp, emptyTensor);
53  }
54  return success();
55  }
56 
57 private:
58  bool foldSingleUseOnly = false;
59 };
60 
61 /// tensor.empty does not define any tensor contents, so a slice of a
62 /// tensor.empty can be folded to a smaller tensor.empty.
63 struct FoldEmptyTensorWithExtractSliceOp
64  : public OpRewritePattern<ExtractSliceOp> {
65  FoldEmptyTensorWithExtractSliceOp(MLIRContext *ctx,
66  PatternBenefit benefit = 1,
67  bool foldSingleUseOnly = false)
68  : OpRewritePattern<ExtractSliceOp>(ctx, benefit),
69  foldSingleUseOnly(foldSingleUseOnly) {}
70 
71  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
72  PatternRewriter &rewriter) const override {
73  // Check for tensor.empty source.
74  auto emptyOp = sliceOp.getSource().template getDefiningOp<EmptyOp>();
75  if (!emptyOp)
76  return failure();
77 
78  // Check for single use.
79  if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
80  return failure();
81 
82  // Create new tensor.empty op. tensor.extract_slice may be rank-reducing;
83  // its dynamic sizes must be preserved as well as its result type.
84  auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
85  sliceOp.getType().getElementType(),
86  sliceOp.getType().getEncoding());
87  rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
88  sliceOp.getSizes());
89  return success();
90  }
91 
92 private:
93  bool foldSingleUseOnly = false;
94 };
95 
96 } // namespace
97 
99  bool foldSingleUseOnly) {
100  patterns.add<FoldEmptyTensorWithExtractSliceOp,
101  FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
102  FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
103  patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
104 }
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates `patterns` with patterns that fold tensor.empty with tensor.extract_slice, tensor.expand_shape, and tensor.collapse_shape.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358