21 #define GEN_PASS_DEF_TOSAMAKEBROADCASTABLEPASS
22 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
// NOTE(review): fragmentary listing. This appears to be the tail of the
// parameter list and part of the body of reshapeLowerToHigher. The name is
// inferred from the call sites further down, whose leading arguments are
// rewriter, loc and outputType. Several original lines are absent here,
// e.g. 39, 44-46, 49, 51, 53, 57-61, 66 and everything after 67.
38 RankedTensorType outputType,
Value &input1,
// Both operands are expected to be ranked tensors. dyn_cast yields a null
// type otherwise, and that is what the check below tests for. The check's
// failure branch is not shown in this extract.
40 auto input1Ty = dyn_cast<RankedTensorType>(input1.
getType());
41 auto input2Ty = dyn_cast<RankedTensorType>(input2.
getType());
43 if (!input1Ty || !input2Ty) {
47 int64_t input1Rank = input1Ty.getRank();
48 int64_t input2Rank = input2Ty.getRank();
// If the ranks are already equal there is nothing to rewrite. The message
// below is presumably passed to notifyMatchFailure, but that call line is
// not shown.
// NOTE(review): the message reads "its" where "it's" is meant. Fixing it
// would change a runtime string.
50 if (input1Rank == input2Rank)
52 "cannot rewrite as its already correct");
// EqualizeRanks takes both Values by reference. Its declaration note says it
// creates reshape ops where necessary to make the two ranks equal. It is
// handed copies, so input1/input2 are not modified by this call.
// TODO confirm how and when the originals are updated; lines 57-61 and 68+
// are not visible.
54 Value input1Copy = input1;
55 Value input2Copy = input2;
56 if (
EqualizeRanks(rewriter, loc, input1Copy, input2Copy).failed()) {
// After equalization, both operands must have the same rank as the ranked
// output type. Otherwise the rewrite is rejected with the message below,
// presumably via notifyMatchFailure(loc, ...); the call line is not visible.
// NOTE(review): "doesn't agrees" should read "doesn't agree". Fixing it
// would change a runtime string.
62 if (outputType.getRank() !=
63 llvm::cast<RankedTensorType>(input1Copy.
getType()).getRank() ||
64 outputType.getRank() !=
65 llvm::cast<RankedTensorType>(input2Copy.
getType()).getRank())
67 loc,
"the reshaped type doesn't agrees with the ranked output type");
// Generic rewrite pattern for two-input TOSA ops, i.e. ops exposing
// getInput1()/getInput2()/getResult(). It calls reshapeLowerToHigher so the
// operand ranks line up with the ranked result type.
// NOTE(review): fragmentary. The class head, the rewriter parameter and the
// rest of matchAndRewrite are not in this extract.
76 template <
typename OpTy>
80 LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
83 Value input1 = tosaBinaryOp.getInput1();
84 Value input2 = tosaBinaryOp.getInput2();
85 Value output = tosaBinaryOp.getResult();
// The result is expected to be a ranked tensor. dyn_cast yields null
// otherwise; presumably that is checked on the missing lines 88-90.
87 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
91 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
// matchAndRewrite of the tosa::MulOp specialization; the struct head is not
// visible. It follows the same flow as the generic pattern but also reads the
// `shift` operand, presumably so it can be forwarded to the replacement
// MulOp. That code is not in this extract.
108 LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
111 Value input1 = tosaBinaryOp.getInput1();
112 Value input2 = tosaBinaryOp.getInput2();
113 Value shift = tosaBinaryOp.getShift();
114 Value output = tosaBinaryOp.getResult();
115 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
119 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
// Specialization for tosa::ArithmeticRightShiftOp. Besides the two inputs it
// reads the op's `round` value (held as int32_t). That value is passed as the
// trailing argument on original line 156, after outputType and the inputs;
// presumably this is the replacement-op construction.
// NOTE(review): fragment. The base-class list, the rewriter parameter and
// several body lines are not in this extract.
135 struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
139 LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
142 Value input1 = tosaBinaryOp.getInput1();
143 Value input2 = tosaBinaryOp.getInput2();
144 int32_t
round = tosaBinaryOp.getRound();
145 Value output = tosaBinaryOp.getResult();
146 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
150 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
156 tosaBinaryOp, outputType, input1, input2,
round);
// Specialization for tosa::SelectOp, which has three operands: pred,
// on_true and on_false. It issues three reshapeLowerToHigher calls, then
// checks that every operand ends up with the output's rank.
// NOTE(review): fragment. Several lines are not in this extract, including
// the full arguments of the three calls.
163 struct ConvertTosaOp<tosa::SelectOp> :
public OpRewritePattern<tosa::SelectOp> {
166 LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
169 Value input1 = tosaOp.getPred();
170 Value input2 = tosaOp.getOnTrue();
171 Value input3 = tosaOp.getOnFalse();
172 Value output = tosaOp.getResult();
174 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
// Each flag records whether the corresponding call did something. Presumably
// it is true when the call succeeded; the conversion to bool is on lines not
// shown here.
180 bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
184 bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
188 bool reshaped3 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
// If no call changed anything, the operand ranks are already aligned and the
// pattern does not apply.
192 if (!reshaped1 && !reshaped2 && !reshaped3)
195 "cannot rewrite as the rank of all operands is already aligned");
// Final sanity check: all three operands and the output must share one rank,
// otherwise the match fails with the message below.
// NOTE(review): getRank() yields int64_t (the reshape helper stores it that
// way). Storing it in int32_t narrows, which is harmless for realistic
// tensor ranks.
197 int32_t result1Rank = cast<RankedTensorType>(input1.
getType()).getRank();
198 int32_t result2Rank = cast<RankedTensorType>(input2.
getType()).getRank();
199 int32_t result3Rank = cast<RankedTensorType>(input3.
getType()).getRank();
200 int32_t outputRank = outputType.getRank();
202 if ((result1Rank != result2Rank) || (result2Rank != result3Rank) ||
203 (result1Rank != outputRank))
205 tosaOp,
"not all ranks are aligned with each other");
// Pass that registers one ConvertTosaOp rewrite pattern for each TOSA op
// listed below: the two-input bitwise, arithmetic, comparison, shift and
// logical ops, plus select and pow. The goal is to align their operand ranks
// with the result rank.
// NOTE(review): fragment. The declarations of `ctx` and `patterns`
// (original lines 223-225) are absent from this extract. So is the code that
// applies the patterns, presumably applyPatternsGreedily.
218 struct TosaMakeBroadcastable
219 :
public tosa::impl::TosaMakeBroadcastablePassBase<TosaMakeBroadcastable> {
221 void runOnOperation()
override {
// The concrete op kind returned here is not visible in this extract.
// TODO confirm it is a function-like op, as the variable name suggests.
222 auto func = getOperation();
226 patterns.add<ConvertTosaOp<tosa::BitwiseAndOp>>(ctx);
227 patterns.add<ConvertTosaOp<tosa::BitwiseOrOp>>(ctx);
228 patterns.add<ConvertTosaOp<tosa::BitwiseXorOp>>(ctx);
229 patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx);
230 patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx);
231 patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx);
232 patterns.add<ConvertTosaOp<tosa::IntDivOp>>(ctx);
233 patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx);
234 patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx);
235 patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx);
236 patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx);
237 patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
238 patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
239 patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
240 patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
241 patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
242 patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
243 patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
244 patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
245 patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
DynamicAPInt round(const Fraction &f)
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...