MLIR 22.0.0git
EmptyOpPatterns.cpp
Go to the documentation of this file.
1//===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
12
13using namespace mlir;
14using namespace mlir::tensor;
15
16namespace {
17
18template <typename ReshapeOp>
19struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
20 FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1,
21 bool foldSingleUseOnly = false)
22 : OpRewritePattern<ReshapeOp>(ctx, benefit),
23 foldSingleUseOnly(foldSingleUseOnly) {}
24
25 LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
26 PatternRewriter &rewriter) const override {
27 // Check for tensor.empty source.
28 auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>();
29 if (!emptyOp)
30 return failure();
31
32 // Check for single use.
33 if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
34 return failure();
35
36 // Reify result shape.
37 Location loc = reshapeOp.getLoc();
38 ReifiedRankedShapedTypeDims resultShapes;
39 if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) ||
40 !llvm::hasSingleElement(resultShapes))
41 return failure();
42
43 // Create new tensor.empty op.
44 // TODO: Do not drop tensor type encoding.
45 Value emptyTensor =
46 EmptyOp::create(rewriter, loc, resultShapes[0],
47 reshapeOp.getResultType().getElementType());
48 if (emptyTensor.getType() != reshapeOp.getResultType()) {
49 rewriter.replaceOpWithNewOp<tensor::CastOp>(
50 reshapeOp, reshapeOp.getResultType(), emptyTensor);
51 } else {
52 rewriter.replaceOp(reshapeOp, emptyTensor);
53 }
54 return success();
55 }
56
57private:
58 bool foldSingleUseOnly = false;
59};
60
61/// tensor.empty does not define any tensor contents, so a slice of a
62/// tensor.empty can be folded to a smaller tensor.empty.
63struct FoldEmptyTensorWithExtractSliceOp
64 : public OpRewritePattern<ExtractSliceOp> {
65 FoldEmptyTensorWithExtractSliceOp(MLIRContext *ctx,
66 PatternBenefit benefit = 1,
67 bool foldSingleUseOnly = false)
68 : OpRewritePattern<ExtractSliceOp>(ctx, benefit),
69 foldSingleUseOnly(foldSingleUseOnly) {}
70
71 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
72 PatternRewriter &rewriter) const override {
73 // Check for tensor.empty source.
74 auto emptyOp = sliceOp.getSource().template getDefiningOp<EmptyOp>();
75 if (!emptyOp)
76 return failure();
77
78 // Check for single use.
79 if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
80 return failure();
81
82 // Create new tensor.empty op. tensor.extract_slice may be rank-reducing;
83 // its dynamic sizes must be preserved as well as its result type.
84 auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
85 sliceOp.getType().getElementType(),
86 sliceOp.getType().getEncoding());
87 rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
88 sliceOp.getSizes());
89 return success();
90 }
91
92private:
93 bool foldSingleUseOnly = false;
94};
95
96// Fold concat operation where all the operands are empty.
97struct FoldConcatsOfEmpty : public OpRewritePattern<ConcatOp> {
98 using OpRewritePattern<ConcatOp>::OpRewritePattern;
99
100 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
101 PatternRewriter &rewriter) const override {
102 auto concatOperands = concatOp.getInputs();
103 if (concatOperands.empty()) {
104 return failure();
105 }
106 auto firstEmptyOp = concatOperands.front().getDefiningOp<tensor::EmptyOp>();
107 if (!firstEmptyOp) {
108 return failure();
109 }
110 auto isDefinedByEmptyOp = [](Value v) -> bool {
111 return v.getDefiningOp<tensor::EmptyOp>();
112 };
113 if (!llvm::all_of(concatOperands.drop_front(), isDefinedByEmptyOp)) {
114 return rewriter.notifyMatchFailure(
115 concatOp, "not all operands are defined by an empty op");
116 }
117 SmallVector<SmallVector<OpFoldResult>> resultShape;
118 if (failed(concatOp.reifyResultShapes(rewriter, resultShape))) {
119 return rewriter.notifyMatchFailure(concatOp,
120 "failed to get result shape");
121 }
122 rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
123 concatOp, resultShape[0], concatOp.getResultType().getElementType());
124 return success();
125 }
126};
127
128} // namespace
129
131 bool foldSingleUseOnly) {
132 patterns.add<FoldEmptyTensorWithExtractSliceOp,
133 FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
134 FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
135 patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
136 patterns.add<FoldConcatsOfEmpty>(patterns.getContext(),
137 /*benefit=*/1);
138}
return success()
static Type getElementType(Type type)
Determine the element type of type.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type getType() const
Return the type of this value.
Definition Value.h:105
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with its consumers.
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...