MLIR 23.0.0git
EmptyOpPatterns.cpp
Go to the documentation of this file.
//===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
12
13using namespace mlir;
14using namespace mlir::tensor;
15
16namespace {
17
18template <typename ReshapeOp>
19struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
20 FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1,
21 bool foldSingleUseOnly = false)
22 : OpRewritePattern<ReshapeOp>(ctx, benefit),
23 foldSingleUseOnly(foldSingleUseOnly) {}
24
25 LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
26 PatternRewriter &rewriter) const override {
27 // Check for tensor.empty source.
28 auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>();
29 if (!emptyOp)
30 return failure();
31
32 // Check for single use.
33 if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
34 return failure();
35
36 // Reify result shape.
37 Location loc = reshapeOp.getLoc();
38 ReifiedRankedShapedTypeDims resultShapes;
39 if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) ||
40 !llvm::hasSingleElement(resultShapes))
41 return failure();
42
43 Attribute encoding;
44 if (auto tensorTy = dyn_cast<RankedTensorType>(reshapeOp.getResultType()))
45 encoding = tensorTy.getEncoding();
46
47 // Create new tensor.empty op.
48 Value emptyTensor =
49 EmptyOp::create(rewriter, loc, resultShapes[0],
50 reshapeOp.getResultType().getElementType(), encoding);
51 if (emptyTensor.getType() != reshapeOp.getResultType()) {
52 rewriter.replaceOpWithNewOp<tensor::CastOp>(
53 reshapeOp, reshapeOp.getResultType(), emptyTensor);
54 } else {
55 rewriter.replaceOp(reshapeOp, emptyTensor);
56 }
57 return success();
58 }
59
60private:
61 bool foldSingleUseOnly = false;
62};
63
64/// tensor.empty does not define any tensor contents, so a slice of a
65/// tensor.empty can be folded to a smaller tensor.empty.
66struct FoldEmptyTensorWithExtractSliceOp
67 : public OpRewritePattern<ExtractSliceOp> {
68 FoldEmptyTensorWithExtractSliceOp(MLIRContext *ctx,
69 PatternBenefit benefit = 1,
70 bool foldSingleUseOnly = false)
71 : OpRewritePattern<ExtractSliceOp>(ctx, benefit),
72 foldSingleUseOnly(foldSingleUseOnly) {}
73
74 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
75 PatternRewriter &rewriter) const override {
76 // Check for tensor.empty source.
77 auto emptyOp = sliceOp.getSource().template getDefiningOp<EmptyOp>();
78 if (!emptyOp)
79 return failure();
80
81 // Check for single use.
82 if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
83 return failure();
84
85 // Create new tensor.empty op. tensor.extract_slice may be rank-reducing;
86 // its dynamic sizes must be preserved as well as its result type.
87 auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
88 sliceOp.getType().getElementType(),
89 sliceOp.getType().getEncoding());
90 rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
91 sliceOp.getSizes());
92 return success();
93 }
94
95private:
96 bool foldSingleUseOnly = false;
97};
98
99// Fold concat operation where all the operands are empty.
100struct FoldConcatsOfEmpty : public OpRewritePattern<ConcatOp> {
101 using OpRewritePattern<ConcatOp>::OpRewritePattern;
102
103 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
104 PatternRewriter &rewriter) const override {
105 auto concatOperands = concatOp.getInputs();
106 if (concatOperands.empty()) {
107 return failure();
108 }
109 auto firstEmptyOp = concatOperands.front().getDefiningOp<tensor::EmptyOp>();
110 if (!firstEmptyOp) {
111 return failure();
112 }
113 auto isDefinedByEmptyOp = [](Value v) -> bool {
114 return v.getDefiningOp<tensor::EmptyOp>();
115 };
116 if (!llvm::all_of(concatOperands.drop_front(), isDefinedByEmptyOp)) {
117 return rewriter.notifyMatchFailure(
118 concatOp, "not all operands are defined by an empty op");
119 }
120 SmallVector<SmallVector<OpFoldResult>> resultShape;
121 if (failed(concatOp.reifyResultShapes(rewriter, resultShape))) {
122 return rewriter.notifyMatchFailure(concatOp,
123 "failed to get result shape");
124 }
125 rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
126 concatOp, resultShape[0], concatOp.getResultType().getElementType());
127 return success();
128 }
129};
130
131} // namespace
132
134 bool foldSingleUseOnly) {
135 patterns.add<FoldEmptyTensorWithExtractSliceOp,
136 FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
137 FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
138 patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
139 patterns.add<FoldConcatsOfEmpty>(patterns.getContext(),
140 /*benefit=*/1);
141}
return success()
static Type getElementType(Type type)
Determine the element type of type.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type getType() const
Return the type of this value.
Definition Value.h:105
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with its consumers.
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...