24 #define GEN_PASS_DEF_TOSAMAKEBROADCASTABLE
25 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
41 RankedTensorType outputType,
Value &input1,
43 auto input1Ty = dyn_cast<RankedTensorType>(input1.
getType());
44 auto input2Ty = dyn_cast<RankedTensorType>(input2.
getType());
46 if (!input1Ty || !input2Ty) {
50 int64_t input1Rank = input1Ty.getRank();
51 int64_t input2Rank = input2Ty.getRank();
53 if (input1Rank == input2Rank)
55 "cannot rewrite as its already correct");
57 Value input1Copy = input1;
58 Value input2Copy = input2;
59 if (
EqualizeRanks(rewriter, loc, input1Copy, input2Copy).failed()) {
65 if (outputType.getRank() !=
66 llvm::cast<RankedTensorType>(input1Copy.
getType()).getRank() ||
67 outputType.getRank() !=
68 llvm::cast<RankedTensorType>(input2Copy.
getType()).getRank())
70 loc,
"the reshaped type doesn't agrees with the ranked output type");
79 template <
typename OpTy>
83 LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
86 Value input1 = tosaBinaryOp.getInput1();
87 Value input2 = tosaBinaryOp.getInput2();
88 Value output = tosaBinaryOp.getResult();
90 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
94 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
111 LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
114 Value input1 = tosaBinaryOp.getInput1();
115 Value input2 = tosaBinaryOp.getInput2();
116 int32_t shift = tosaBinaryOp.getShift();
117 Value output = tosaBinaryOp.getResult();
118 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
122 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
138 struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
142 LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
145 Value input1 = tosaBinaryOp.getInput1();
146 Value input2 = tosaBinaryOp.getInput2();
147 int32_t
round = tosaBinaryOp.getRound();
148 Value output = tosaBinaryOp.getResult();
149 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
153 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
159 tosaBinaryOp, outputType, input1, input2,
round);
166 struct ConvertTosaOp<tosa::SelectOp> :
public OpRewritePattern<tosa::SelectOp> {
169 LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
172 Value input1 = tosaOp.getPred();
173 Value input2 = tosaOp.getOnTrue();
174 Value input3 = tosaOp.getOnFalse();
175 Value output = tosaOp.getResult();
177 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
183 bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
187 bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
191 bool reshaped3 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
195 if (!reshaped1 && !reshaped2 && !reshaped3)
198 "cannot rewrite as the rank of all operands is already aligned");
200 int32_t result1Rank = cast<RankedTensorType>(input1.
getType()).getRank();
201 int32_t result2Rank = cast<RankedTensorType>(input2.
getType()).getRank();
202 int32_t result3Rank = cast<RankedTensorType>(input3.
getType()).getRank();
203 int32_t outputRank = outputType.getRank();
205 if ((result1Rank != result2Rank) || (result2Rank != result3Rank) ||
206 (result1Rank != outputRank))
208 tosaOp,
"not all ranks are aligned with each other");
221 struct TosaMakeBroadcastable
222 :
public tosa::impl::TosaMakeBroadcastableBase<TosaMakeBroadcastable> {
224 void runOnOperation()
override {
225 auto func = getOperation();
229 patterns.add<ConvertTosaOp<tosa::BitwiseAndOp>>(ctx);
230 patterns.add<ConvertTosaOp<tosa::BitwiseOrOp>>(ctx);
231 patterns.add<ConvertTosaOp<tosa::BitwiseXorOp>>(ctx);
232 patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx);
233 patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx);
234 patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx);
235 patterns.add<ConvertTosaOp<tosa::IntDivOp>>(ctx);
236 patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx);
237 patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx);
238 patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx);
239 patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx);
240 patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
241 patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
242 patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
243 patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
244 patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
245 patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
246 patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
247 patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
248 patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
255 return std::make_unique<TosaMakeBroadcastable>();
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
DynamicAPInt round(const Fraction &f)
std::unique_ptr< Pass > createTosaMakeBroadcastablePass()
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
Include the generated interface declarations.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...