TosaMakeBroadcastable.cpp (MLIR 14.0.0git)
//===- TosaMakeBroadcastable.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Insert reshape to binary op's input if needed to match rank
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::tosa;

/// There are two potential ways of implementing broadcast:
/// a. https://www.tensorflow.org/xla/broadcasting#formal_definition
/// b. https://numpy.org/doc/stable/user/basics.broadcasting.html
/// This pass currently implements b (numpy style).

/// In this pass, we insert RESHAPE operators to increase the rank of the
/// lower-rank operand as a first step in the broadcasting process. The TOSA
/// operators that support broadcast require that the ranks of the operands
/// be equal.

// Examples:
// If lower=[c], higher=[a, b, c], [c] is reshaped into [1, 1, c].
// If lower=[b, c], higher=[a, b, c], [b, c] is reshaped into [1, b, c].
// If lower=[a], higher=[a, a], [a] is reshaped into [1, a].
// If lower=[a], higher=[a, b, a], [a] is reshaped into [1, 1, a].
// If lower=[], higher=[a, b, c], [] is reshaped into [1, 1, 1].

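// As an illustration (hypothetical IR, not taken from the upstream tests), a
// tosa.add whose operands differ in rank would be rewritten roughly as follows,
// assuming the tosa.reshape attribute is spelled new_shape in this dialect
// version:
//
//   %0 = "tosa.add"(%a, %b)
//            : (tensor<2x3x4xf32>, tensor<4xf32>) -> tensor<2x3x4xf32>
//
// becomes
//
//   %b1 = "tosa.reshape"(%b) {new_shape = [1, 1, 4]}
//            : (tensor<4xf32>) -> tensor<1x1x4xf32>
//   %0  = "tosa.add"(%a, %b1)
//            : (tensor<2x3x4xf32>, tensor<1x1x4xf32>) -> tensor<2x3x4xf32>
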
static LogicalResult
computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
                     ArrayRef<int64_t> lowerRankShape,
                     SmallVectorImpl<int64_t> &reshapeOutputShape) {
  // Initialize new shapes with [1] * higherRank.
  int64_t higherRank = higherRankShape.size();
  int64_t lowerRank = lowerRankShape.size();

  reshapeOutputShape.assign(higherRank, 1);

  int64_t higherRankDim;
  int64_t lowerRankDim;

  for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0;
       i--, j--) {
    higherRankDim = higherRankShape[i];
    lowerRankDim = lowerRankShape[j];

    if (lowerRankDim == 1 && higherRankDim > 1)
      reshapeOutputShape[i] = 1;
    else if ((lowerRankDim > 1 && higherRankDim == 1) ||
             (lowerRankDim == higherRankDim))
      reshapeOutputShape[i] = lowerRankDim;
    else if (higherRankDim != lowerRankDim)
      return failure();
  }
  return success();
}
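
// A minimal usage sketch of the helper above (values chosen purely for
// illustration): trailing dimensions are matched right-to-left and the
// remaining leading dimensions of the result stay 1.
//
//   SmallVector<int64_t, 4> out;
//   (void)computeReshapeOutput({5, 4, 3}, {4, 1}, out); // out == {1, 4, 1}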

/// Common code to create the reshape op where necessary to make the ranks of
/// the two operands equal. Returns the updated input1 and input2 for the
/// original input. The caller is expected to use these to rewrite the original
/// operator with the RESHAPE now in the graph.
static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
                                          Location loc,
                                          RankedTensorType outputType,
                                          Value input1, Value input2,
                                          Value &outInput1, Value &outInput2) {
  auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
  auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();

  if (!input1Ty || !input2Ty)
    return failure();

  int64_t input1Rank = input1Ty.getRank();
  int64_t input2Rank = input2Ty.getRank();

  Value higherTensorValue, lowerTensorValue;
  // Cannot rewrite as it's already correct.
  if (input1Rank == input2Rank)
    return failure();

  if (input1Rank > input2Rank) {
    higherTensorValue = input1;
    lowerTensorValue = input2;
  } else {
    higherTensorValue = input2;
    lowerTensorValue = input1;
  }

  ArrayRef<int64_t> higherRankShape =
      higherTensorValue.getType().cast<RankedTensorType>().getShape();
  (void)higherRankShape;
  ArrayRef<int64_t> lowerRankShape =
      lowerTensorValue.getType().cast<RankedTensorType>().getShape();

  SmallVector<int64_t, 4> reshapeOutputShape;

  if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
          .failed())
    return failure();

  auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
  auto reshapeOutputType = RankedTensorType::get(
      ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());

  // Verify the rank agrees with the output type if the output type is ranked.
  if (outputType) {
    if (outputType.getShape().size() != reshapeOutputShape.size() ||
        outputType.getShape().size() != higherRankShape.size())
      return failure();
  }

  auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
      loc, reshapeOutputType, lowerTensorValue,
      rewriter.getI64ArrayAttr(reshapeOutputShape));

  if (input1Rank > input2Rank) {
    outInput1 = higherTensorValue;
    outInput2 = reshapeLower.getResult();
  } else {
    outInput1 = reshapeLower.getResult();
    outInput2 = higherTensorValue;
  }

  return success();
}
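
// A concrete trace of the helper above (shapes chosen for illustration): given
// input1 : tensor<2x3x4xf32> and input2 : tensor<3x1xf32>, input1 is the
// higher-rank operand, computeReshapeOutput yields {1, 3, 1}, a tosa.reshape of
// input2 to tensor<1x3x1xf32> is created, and the helper returns
// outInput1 = input1 and outInput2 = the new reshape result.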

namespace {
template <typename OpTy>
struct ConvertTosaOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
                                PatternRewriter &rewriter) const override {

    Value input1 = tosaBinaryOp.input1();
    Value input2 = tosaBinaryOp.input2();
    Value output = tosaBinaryOp.getResult();

    auto outputType = output.getType().dyn_cast<RankedTensorType>();
    if (!outputType)
      return failure();

    Value outInput1, outInput2;
    if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
                             input1, input2, outInput1, outInput2)
            .failed())
      return failure();

    rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1,
                                      outInput2);

    return success();
  }
};

// The MulOp has an extra parameter 'shift' not present in other elementwise
// binary ops, which necessitates special handling of its builder.
template <>
struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
  using OpRewritePattern<tosa::MulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
                                PatternRewriter &rewriter) const override {

    Value input1 = tosaBinaryOp.input1();
    Value input2 = tosaBinaryOp.input2();
    int32_t shift = tosaBinaryOp.shift();
    Value output = tosaBinaryOp.getResult();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();
    if (!outputType)
      return failure();

    Value outInput1, outInput2;
    if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
                             input1, input2, outInput1, outInput2)
            .failed())
      return failure();

    rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType,
                                             outInput1, outInput2, shift);

    return success();
  }
};

// The ArithmeticRightShiftOp has an extra parameter 'round' not present in
// other elementwise binary ops, which necessitates special handling of its
// builder.
template <>
struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
    : public OpRewritePattern<tosa::ArithmeticRightShiftOp> {
  using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
                                PatternRewriter &rewriter) const override {

    Value input1 = tosaBinaryOp.input1();
    Value input2 = tosaBinaryOp.input2();
    int32_t round = tosaBinaryOp.round();
    Value output = tosaBinaryOp.getResult();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();
    if (!outputType)
      return failure();

    Value outInput1, outInput2;
    if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
                             input1, input2, outInput1, outInput2)
            .failed())
      return failure();

    rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
        tosaBinaryOp, outputType, outInput1, outInput2, round);

    return success();
  }
};
} // namespace

namespace {
/// Pass that enables broadcast by making all input arrays have the same
/// number of dimensions. Inserts RESHAPE operations on the lower-rank operand.
struct TosaMakeBroadcastable
    : public TosaMakeBroadcastableBase<TosaMakeBroadcastable> {
public:
  void runOnOperation() override {
    auto func = getOperation();
    RewritePatternSet patterns(func.getContext());
    MLIRContext *ctx = func.getContext();
    // Add the generated patterns to the list.
    patterns.add<ConvertTosaOp<tosa::BitwiseAndOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::BitwiseOrOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::BitwiseXorOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::DivOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
  }
};
} // namespace

std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() {
  return std::make_unique<TosaMakeBroadcastable>();
}
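
// Usage sketch (assuming the pass keeps its usual registration name in this
// MLIR version): the rewrite can be exercised standalone with mlir-opt, e.g.
//
//   mlir-opt --tosa-make-broadcastable input.mlir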