MLIR 22.0.0git
TosaMakeBroadcastable.cpp
Go to the documentation of this file.
1//===- TosaMakeBroadcastable.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Insert reshape to binary op's input if needed to match rank
10//
11//===----------------------------------------------------------------------===//
12
18
19namespace mlir {
20namespace tosa {
21#define GEN_PASS_DEF_TOSAMAKEBROADCASTABLEPASS
22#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
23} // namespace tosa
24} // namespace mlir
25
26using namespace mlir;
27using namespace mlir::tosa;
28
29namespace {
30
31/// Common code to create the reshape op where necessary to make the rank of the
32/// operations equal. input1 and input2 will be updated when the rank has
33/// changed. The caller is expected to use these to rewrite the original
34/// operator with the RESHAPE now in the graph.
35/// return failure when (1) no reshape needed, or (2) output_type is specified
36/// and it has different rank
37LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
38 RankedTensorType outputType, Value &input1,
39 Value &input2) {
40 auto input1Ty = dyn_cast<RankedTensorType>(input1.getType());
41 auto input2Ty = dyn_cast<RankedTensorType>(input2.getType());
42
43 if (!input1Ty || !input2Ty) {
44 return rewriter.notifyMatchFailure(loc, "input not a ranked tensor");
45 }
46
47 int64_t input1Rank = input1Ty.getRank();
48 int64_t input2Rank = input2Ty.getRank();
49
50 if (input1Rank == input2Rank)
51 return rewriter.notifyMatchFailure(loc,
52 "cannot rewrite as its already correct");
53
54 Value input1Copy = input1;
55 Value input2Copy = input2;
56 if (EqualizeRanks(rewriter, loc, input1Copy, input2Copy).failed()) {
57 return rewriter.notifyMatchFailure(loc, "failed to reshape inputs");
58 }
59
60 // Verify the rank agrees with the output type if the output type is ranked.
61 if (outputType) {
62 if (outputType.getRank() !=
63 llvm::cast<RankedTensorType>(input1Copy.getType()).getRank() ||
64 outputType.getRank() !=
65 llvm::cast<RankedTensorType>(input2Copy.getType()).getRank())
66 return rewriter.notifyMatchFailure(
67 loc, "the reshaped type doesn't agrees with the ranked output type");
68 }
69
70 input1 = input1Copy;
71 input2 = input2Copy;
72
73 return success();
74}
75
76template <typename OpTy>
77struct ConvertTosaOp : public OpRewritePattern<OpTy> {
78 using OpRewritePattern<OpTy>::OpRewritePattern;
79
80 LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
81 PatternRewriter &rewriter) const override {
82
83 Value input1 = tosaBinaryOp.getInput1();
84 Value input2 = tosaBinaryOp.getInput2();
85 Value output = tosaBinaryOp.getResult();
86
87 auto outputType = dyn_cast<RankedTensorType>(output.getType());
88 if (!outputType)
89 return failure();
90
91 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
92 input1, input2)
93 .failed())
94 return failure();
95
96 rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, input1, input2);
97
98 return success();
99 }
100};
101
102// The MulOp has an extra parameter 'shift' not present in other elementwise
103// binary ops, that necessitates special handling of its builder.
104template <>
105struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
106 using OpRewritePattern<tosa::MulOp>::OpRewritePattern;
107
108 LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
109 PatternRewriter &rewriter) const override {
110
111 Value input1 = tosaBinaryOp.getInput1();
112 Value input2 = tosaBinaryOp.getInput2();
113 Value shift = tosaBinaryOp.getShift();
114 Value output = tosaBinaryOp.getResult();
115 auto outputType = dyn_cast<RankedTensorType>(output.getType());
116 if (!outputType)
117 return failure();
118
119 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
120 input1, input2)
121 .failed())
122 return failure();
123
124 rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType, input1,
125 input2, shift);
126
127 return success();
128 }
129};
130
131// The ArithmeticRightShiftOp has an extra parameter 'round' not present in
132// other elementwise binary ops, that necessitates special handling of its
133// builder.
134template <>
135struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
136 : public OpRewritePattern<tosa::ArithmeticRightShiftOp> {
137 using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern;
138
139 LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
140 PatternRewriter &rewriter) const override {
141
142 Value input1 = tosaBinaryOp.getInput1();
143 Value input2 = tosaBinaryOp.getInput2();
144 int32_t round = tosaBinaryOp.getRound();
145 Value output = tosaBinaryOp.getResult();
146 auto outputType = dyn_cast<RankedTensorType>(output.getType());
147 if (!outputType)
148 return failure();
149
150 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
151 input1, input2)
152 .failed())
153 return failure();
154
155 rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
156 tosaBinaryOp, outputType, input1, input2, round);
157
158 return success();
159 }
160};
161
162template <>
163struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
164 using OpRewritePattern<tosa::SelectOp>::OpRewritePattern;
165
166 LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
167 PatternRewriter &rewriter) const override {
168
169 Value input1 = tosaOp.getPred();
170 Value input2 = tosaOp.getOnTrue();
171 Value input3 = tosaOp.getOnFalse();
172 Value output = tosaOp.getResult();
173
174 auto outputType = dyn_cast<RankedTensorType>(output.getType());
175 if (!outputType)
176 return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor");
177
178 // Apply broadcasting to each pair of inputs separately, and chain them as
179 // compound as below so that the broadcasting happens all at once.
180 bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
181 input1, input2)
182 .succeeded();
183
184 bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
185 input1, input3)
186 .succeeded();
187
188 bool reshaped3 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
189 input2, input3)
190 .succeeded();
191
192 if (!reshaped1 && !reshaped2 && !reshaped3)
193 return rewriter.notifyMatchFailure(
194 tosaOp,
195 "cannot rewrite as the rank of all operands is already aligned");
196
197 int32_t result1Rank = cast<RankedTensorType>(input1.getType()).getRank();
198 int32_t result2Rank = cast<RankedTensorType>(input2.getType()).getRank();
199 int32_t result3Rank = cast<RankedTensorType>(input3.getType()).getRank();
200 int32_t outputRank = outputType.getRank();
201
202 if ((result1Rank != result2Rank) || (result2Rank != result3Rank) ||
203 (result1Rank != outputRank))
204 return rewriter.notifyMatchFailure(
205 tosaOp, "not all ranks are aligned with each other");
206
207 rewriter.replaceOpWithNewOp<tosa::SelectOp>(tosaOp, outputType, input1,
208 input2, input3);
209
210 return success();
211 }
212};
213} // namespace
214
215namespace {
216/// Pass that enables broadcast by making all input arrays have the same
217/// number of dimensions. Insert RESHAPE operations to lower rank operand
218struct TosaMakeBroadcastable
219 : public tosa::impl::TosaMakeBroadcastablePassBase<TosaMakeBroadcastable> {
220public:
221 void runOnOperation() override {
222 auto func = getOperation();
223 RewritePatternSet patterns(func.getContext());
224 MLIRContext *ctx = func.getContext();
225 // Add the generated patterns to the list.
226 patterns.add<ConvertTosaOp<tosa::BitwiseAndOp>>(ctx);
227 patterns.add<ConvertTosaOp<tosa::BitwiseOrOp>>(ctx);
228 patterns.add<ConvertTosaOp<tosa::BitwiseXorOp>>(ctx);
229 patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx);
230 patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx);
231 patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx);
232 patterns.add<ConvertTosaOp<tosa::IntDivOp>>(ctx);
233 patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx);
234 patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx);
235 patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx);
236 patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx);
237 patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
238 patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
239 patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
240 patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
241 patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
242 patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
243 patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
244 patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
245 patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
246 (void)applyPatternsGreedily(func, std::move(patterns));
247 }
248};
249} // namespace
return success()
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
DynamicAPInt round(const Fraction &f)
Definition Fraction.h:136
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...