MLIR  19.0.0git
TosaMakeBroadcastable.cpp
Go to the documentation of this file.
1 //===- TosaMakeBroadcastable.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Insert RESHAPE ops on the inputs of elementwise TOSA ops (binary ops and the ternary select) where needed, so that all operand ranks match.
10 //
11 //===----------------------------------------------------------------------===//
12 
19 #include "mlir/Pass/Pass.h"
21 
22 namespace mlir {
23 namespace tosa {
24 #define GEN_PASS_DEF_TOSAMAKEBROADCASTABLE
25 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
26 } // namespace tosa
27 } // namespace mlir
28 
29 using namespace mlir;
30 using namespace mlir::tosa;
31 
32 namespace {
33 
34 /// Common code to create the reshape op where necessary to make the rank of the
35 /// operations equal. input1 and input2 will be updated when the rank has
36 /// changed. The caller is expected to use these to rewrite the original
37 /// operator with the RESHAPE now in the graph.
38 /// return failure when (1) no reshape needed, or (2) output_type is specified
39 /// and it has different rank
40 LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
41  RankedTensorType outputType, Value &input1,
42  Value &input2) {
43  auto input1Ty = dyn_cast<RankedTensorType>(input1.getType());
44  auto input2Ty = dyn_cast<RankedTensorType>(input2.getType());
45 
46  if (!input1Ty || !input2Ty) {
47  return rewriter.notifyMatchFailure(loc, "input not a ranked tensor");
48  }
49 
50  int64_t input1Rank = input1Ty.getRank();
51  int64_t input2Rank = input2Ty.getRank();
52 
53  if (input1Rank == input2Rank)
54  return rewriter.notifyMatchFailure(loc,
55  "cannot rewrite as its already correct");
56 
57  Value input1Copy = input1;
58  Value input2Copy = input2;
59  if (EqualizeRanks(rewriter, loc, input1Copy, input2Copy).failed()) {
60  return rewriter.notifyMatchFailure(loc, "failed to reshape inputs");
61  }
62 
63  // Verify the rank agrees with the output type if the output type is ranked.
64  if (outputType) {
65  if (outputType.getRank() !=
66  llvm::cast<RankedTensorType>(input1Copy.getType()).getRank() ||
67  outputType.getRank() !=
68  llvm::cast<RankedTensorType>(input2Copy.getType()).getRank())
69  return rewriter.notifyMatchFailure(
70  loc, "the reshaped type doesn't agrees with the ranked output type");
71  }
72 
73  input1 = input1Copy;
74  input2 = input2Copy;
75 
76  return success();
77 }
78 
79 template <typename OpTy>
80 struct ConvertTosaOp : public OpRewritePattern<OpTy> {
82 
83  LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
84  PatternRewriter &rewriter) const override {
85 
86  Value input1 = tosaBinaryOp.getInput1();
87  Value input2 = tosaBinaryOp.getInput2();
88  Value output = tosaBinaryOp.getResult();
89 
90  auto outputType = dyn_cast<RankedTensorType>(output.getType());
91  if (!outputType)
92  return failure();
93 
94  if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
95  input1, input2)
96  .failed())
97  return failure();
98 
99  rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, input1, input2);
100 
101  return success();
102  }
103 };
104 
105 // The MulOp has an extra parameter 'shift' not present in other elementwise
106 // binary ops, that necessitates special handling of its builder.
107 template <>
108 struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
110 
111  LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
112  PatternRewriter &rewriter) const override {
113 
114  Value input1 = tosaBinaryOp.getInput1();
115  Value input2 = tosaBinaryOp.getInput2();
116  int32_t shift = tosaBinaryOp.getShift();
117  Value output = tosaBinaryOp.getResult();
118  auto outputType = dyn_cast<RankedTensorType>(output.getType());
119  if (!outputType)
120  return failure();
121 
122  if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
123  input1, input2)
124  .failed())
125  return failure();
126 
127  rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType, input1,
128  input2, shift);
129 
130  return success();
131  }
132 };
133 
134 // The ArithmeticRightShiftOp has an extra parameter 'round' not present in
135 // other elementwise binary ops, that necessitates special handling of its
136 // builder.
137 template <>
138 struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
139  : public OpRewritePattern<tosa::ArithmeticRightShiftOp> {
141 
142  LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
143  PatternRewriter &rewriter) const override {
144 
145  Value input1 = tosaBinaryOp.getInput1();
146  Value input2 = tosaBinaryOp.getInput2();
147  int32_t round = tosaBinaryOp.getRound();
148  Value output = tosaBinaryOp.getResult();
149  auto outputType = dyn_cast<RankedTensorType>(output.getType());
150  if (!outputType)
151  return failure();
152 
153  if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
154  input1, input2)
155  .failed())
156  return failure();
157 
158  rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
159  tosaBinaryOp, outputType, input1, input2, round);
160 
161  return success();
162  }
163 };
164 
165 template <>
166 struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
168 
169  LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
170  PatternRewriter &rewriter) const override {
171 
172  Value input1 = tosaOp.getPred();
173  Value input2 = tosaOp.getOnTrue();
174  Value input3 = tosaOp.getOnFalse();
175  Value output = tosaOp.getResult();
176 
177  auto outputType = dyn_cast<RankedTensorType>(output.getType());
178  if (!outputType)
179  return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor");
180 
181  // Apply broadcasting to each pair of inputs separately, and chain them as
182  // compound as below so that the broadcasting happens all at once.
183  bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
184  input1, input2)
185  .succeeded();
186 
187  bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
188  input1, input3)
189  .succeeded();
190 
191  bool reshaped3 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
192  input2, input3)
193  .succeeded();
194 
195  if (!reshaped1 && !reshaped2 && !reshaped3)
196  return rewriter.notifyMatchFailure(
197  tosaOp,
198  "cannot rewrite as the rank of all operands is already aligned");
199 
200  int32_t result1Rank = cast<RankedTensorType>(input1.getType()).getRank();
201  int32_t result2Rank = cast<RankedTensorType>(input2.getType()).getRank();
202  int32_t result3Rank = cast<RankedTensorType>(input3.getType()).getRank();
203  int32_t outputRank = outputType.getRank();
204 
205  if ((result1Rank != result2Rank) || (result2Rank != result3Rank) ||
206  (result1Rank != outputRank))
207  return rewriter.notifyMatchFailure(
208  tosaOp, "not all ranks are aligned with each other");
209 
210  rewriter.replaceOpWithNewOp<tosa::SelectOp>(tosaOp, outputType, input1,
211  input2, input3);
212 
213  return success();
214  }
215 };
216 } // namespace
217 
218 namespace {
219 /// Pass that enables broadcast by making all input arrays have the same
220 /// number of dimensions. Insert RESHAPE operations to lower rank operand
221 struct TosaMakeBroadcastable
222  : public tosa::impl::TosaMakeBroadcastableBase<TosaMakeBroadcastable> {
223 public:
224  void runOnOperation() override {
225  auto func = getOperation();
226  RewritePatternSet patterns(func.getContext());
227  MLIRContext *ctx = func.getContext();
228  // Add the generated patterns to the list.
229  patterns.add<ConvertTosaOp<tosa::BitwiseAndOp>>(ctx);
230  patterns.add<ConvertTosaOp<tosa::BitwiseOrOp>>(ctx);
231  patterns.add<ConvertTosaOp<tosa::BitwiseXorOp>>(ctx);
232  patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx);
233  patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx);
234  patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx);
235  patterns.add<ConvertTosaOp<tosa::DivOp>>(ctx);
236  patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx);
237  patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx);
238  patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx);
239  patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx);
240  patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
241  patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
242  patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
243  patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
244  patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
245  patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
246  patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
247  patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
248  patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
249  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
250  }
251 };
252 } // namespace
253 
// Hands the caller ownership of a newly constructed TosaMakeBroadcastable pass.
255  return std::make_unique<TosaMakeBroadcastable>();
256 }
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
MPInt round(const Fraction &f)
Definition: Fraction.h:133
std::unique_ptr< Pass > createTosaMakeBroadcastablePass()
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358