MLIR  22.0.0git
TosaMakeBroadcastable.cpp
Go to the documentation of this file.
1 //===- TosaMakeBroadcastable.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Insert reshapes on elementwise op inputs where needed to make ranks match
10 //
11 //===----------------------------------------------------------------------===//
12 
18 
19 namespace mlir {
20 namespace tosa {
21 #define GEN_PASS_DEF_TOSAMAKEBROADCASTABLEPASS
22 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
23 } // namespace tosa
24 } // namespace mlir
25 
26 using namespace mlir;
27 using namespace mlir::tosa;
28 
29 namespace {
30 
31 /// Common code to create the reshape op where necessary to make the rank of the
32 /// operations equal. input1 and input2 will be updated when the rank has
33 /// changed. The caller is expected to use these to rewrite the original
34 /// operator with the RESHAPE now in the graph.
35 /// return failure when (1) no reshape needed, or (2) output_type is specified
36 /// and it has different rank
37 LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
38  RankedTensorType outputType, Value &input1,
39  Value &input2) {
40  auto input1Ty = dyn_cast<RankedTensorType>(input1.getType());
41  auto input2Ty = dyn_cast<RankedTensorType>(input2.getType());
42 
43  if (!input1Ty || !input2Ty) {
44  return rewriter.notifyMatchFailure(loc, "input not a ranked tensor");
45  }
46 
47  int64_t input1Rank = input1Ty.getRank();
48  int64_t input2Rank = input2Ty.getRank();
49 
50  if (input1Rank == input2Rank)
51  return rewriter.notifyMatchFailure(loc,
52  "cannot rewrite as its already correct");
53 
54  Value input1Copy = input1;
55  Value input2Copy = input2;
56  if (EqualizeRanks(rewriter, loc, input1Copy, input2Copy).failed()) {
57  return rewriter.notifyMatchFailure(loc, "failed to reshape inputs");
58  }
59 
60  // Verify the rank agrees with the output type if the output type is ranked.
61  if (outputType) {
62  if (outputType.getRank() !=
63  llvm::cast<RankedTensorType>(input1Copy.getType()).getRank() ||
64  outputType.getRank() !=
65  llvm::cast<RankedTensorType>(input2Copy.getType()).getRank())
66  return rewriter.notifyMatchFailure(
67  loc, "the reshaped type doesn't agrees with the ranked output type");
68  }
69 
70  input1 = input1Copy;
71  input2 = input2Copy;
72 
73  return success();
74 }
75 
76 template <typename OpTy>
77 struct ConvertTosaOp : public OpRewritePattern<OpTy> {
79 
80  LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
81  PatternRewriter &rewriter) const override {
82 
83  Value input1 = tosaBinaryOp.getInput1();
84  Value input2 = tosaBinaryOp.getInput2();
85  Value output = tosaBinaryOp.getResult();
86 
87  auto outputType = dyn_cast<RankedTensorType>(output.getType());
88  if (!outputType)
89  return failure();
90 
91  if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
92  input1, input2)
93  .failed())
94  return failure();
95 
96  rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, input1, input2);
97 
98  return success();
99  }
100 };
101 
102 // The MulOp has an extra parameter 'shift' not present in other elementwise
103 // binary ops, that necessitates special handling of its builder.
104 template <>
105 struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
107 
108  LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
109  PatternRewriter &rewriter) const override {
110 
111  Value input1 = tosaBinaryOp.getInput1();
112  Value input2 = tosaBinaryOp.getInput2();
113  Value shift = tosaBinaryOp.getShift();
114  Value output = tosaBinaryOp.getResult();
115  auto outputType = dyn_cast<RankedTensorType>(output.getType());
116  if (!outputType)
117  return failure();
118 
119  if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
120  input1, input2)
121  .failed())
122  return failure();
123 
124  rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType, input1,
125  input2, shift);
126 
127  return success();
128  }
129 };
130 
131 // The ArithmeticRightShiftOp has an extra parameter 'round' not present in
132 // other elementwise binary ops, that necessitates special handling of its
133 // builder.
134 template <>
135 struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
136  : public OpRewritePattern<tosa::ArithmeticRightShiftOp> {
138 
139  LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
140  PatternRewriter &rewriter) const override {
141 
142  Value input1 = tosaBinaryOp.getInput1();
143  Value input2 = tosaBinaryOp.getInput2();
144  int32_t round = tosaBinaryOp.getRound();
145  Value output = tosaBinaryOp.getResult();
146  auto outputType = dyn_cast<RankedTensorType>(output.getType());
147  if (!outputType)
148  return failure();
149 
150  if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
151  input1, input2)
152  .failed())
153  return failure();
154 
155  rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
156  tosaBinaryOp, outputType, input1, input2, round);
157 
158  return success();
159  }
160 };
161 
162 template <>
163 struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
165 
166  LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
167  PatternRewriter &rewriter) const override {
168 
169  Value input1 = tosaOp.getPred();
170  Value input2 = tosaOp.getOnTrue();
171  Value input3 = tosaOp.getOnFalse();
172  Value output = tosaOp.getResult();
173 
174  auto outputType = dyn_cast<RankedTensorType>(output.getType());
175  if (!outputType)
176  return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor");
177 
178  // Apply broadcasting to each pair of inputs separately, and chain them as
179  // compound as below so that the broadcasting happens all at once.
180  bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
181  input1, input2)
182  .succeeded();
183 
184  bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
185  input1, input3)
186  .succeeded();
187 
188  bool reshaped3 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
189  input2, input3)
190  .succeeded();
191 
192  if (!reshaped1 && !reshaped2 && !reshaped3)
193  return rewriter.notifyMatchFailure(
194  tosaOp,
195  "cannot rewrite as the rank of all operands is already aligned");
196 
197  int32_t result1Rank = cast<RankedTensorType>(input1.getType()).getRank();
198  int32_t result2Rank = cast<RankedTensorType>(input2.getType()).getRank();
199  int32_t result3Rank = cast<RankedTensorType>(input3.getType()).getRank();
200  int32_t outputRank = outputType.getRank();
201 
202  if ((result1Rank != result2Rank) || (result2Rank != result3Rank) ||
203  (result1Rank != outputRank))
204  return rewriter.notifyMatchFailure(
205  tosaOp, "not all ranks are aligned with each other");
206 
207  rewriter.replaceOpWithNewOp<tosa::SelectOp>(tosaOp, outputType, input1,
208  input2, input3);
209 
210  return success();
211  }
212 };
213 } // namespace
214 
215 namespace {
216 /// Pass that enables broadcast by making all input arrays have the same
217 /// number of dimensions. Insert RESHAPE operations to lower rank operand
218 struct TosaMakeBroadcastable
219  : public tosa::impl::TosaMakeBroadcastablePassBase<TosaMakeBroadcastable> {
220 public:
221  void runOnOperation() override {
222  auto func = getOperation();
223  RewritePatternSet patterns(func.getContext());
224  MLIRContext *ctx = func.getContext();
225  // Add the generated patterns to the list.
226  patterns.add<ConvertTosaOp<tosa::BitwiseAndOp>>(ctx);
227  patterns.add<ConvertTosaOp<tosa::BitwiseOrOp>>(ctx);
228  patterns.add<ConvertTosaOp<tosa::BitwiseXorOp>>(ctx);
229  patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx);
230  patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx);
231  patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx);
232  patterns.add<ConvertTosaOp<tosa::IntDivOp>>(ctx);
233  patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx);
234  patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx);
235  patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx);
236  patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx);
237  patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
238  patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
239  patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
240  patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
241  patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
242  patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
243  patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
244  patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
245  patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
246  (void)applyPatternsGreedily(func, std::move(patterns));
247  }
248 };
249 } // namespace
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:769
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:702
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314