24 #define GEN_PASS_DEF_TOSAMAKEBROADCASTABLE
25 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
41 RankedTensorType outputType,
Value &input1,
43 auto input1Ty = dyn_cast<RankedTensorType>(input1.
getType());
44 auto input2Ty = dyn_cast<RankedTensorType>(input2.
getType());
46 if (!input1Ty || !input2Ty) {
50 int64_t input1Rank = input1Ty.getRank();
51 int64_t input2Rank = input2Ty.getRank();
53 if (input1Rank == input2Rank)
55 "cannot rewrite as its already correct");
57 Value input1Copy = input1;
58 Value input2Copy = input2;
65 if (outputType.getRank() !=
66 llvm::cast<RankedTensorType>(input1Copy.
getType()).getRank() ||
67 outputType.getRank() !=
68 llvm::cast<RankedTensorType>(input2Copy.
getType()).getRank())
70 loc,
"the reshaped type doesn't agrees with the ranked output type");
79 template <
typename OpTy>
86 Value input1 = tosaBinaryOp.getInput1();
87 Value input2 = tosaBinaryOp.getInput2();
88 Value output = tosaBinaryOp.getResult();
90 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
94 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
114 Value input1 = tosaBinaryOp.getInput1();
115 Value input2 = tosaBinaryOp.getInput2();
116 int32_t shift = tosaBinaryOp.getShift();
117 Value output = tosaBinaryOp.getResult();
118 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
122 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
138 struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
142 LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
145 Value input1 = tosaBinaryOp.getInput1();
146 Value input2 = tosaBinaryOp.getInput2();
147 int32_t
round = tosaBinaryOp.getRound();
148 Value output = tosaBinaryOp.getResult();
149 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
153 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
159 tosaBinaryOp, outputType, input1, input2,
round);
166 struct ConvertTosaOp<tosa::SelectOp> :
public OpRewritePattern<tosa::SelectOp> {
172 Value input1 = tosaOp.getPred();
173 Value input2 = tosaOp.getOnTrue();
174 Value input3 = tosaOp.getOnFalse();
175 Value output = tosaOp.getResult();
177 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
183 bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
187 bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
191 bool reshaped3 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
195 if (!reshaped1 && !reshaped2 && !reshaped3)
198 "cannot rewrite as the rank of all operands is already aligned");
200 int32_t result1Rank = cast<RankedTensorType>(input1.
getType()).getRank();
201 int32_t result2Rank = cast<RankedTensorType>(input2.
getType()).getRank();
202 int32_t result3Rank = cast<RankedTensorType>(input3.
getType()).getRank();
203 int32_t outputRank = outputType.getRank();
205 if ((result1Rank != result2Rank) || (result2Rank != result3Rank) ||
206 (result1Rank != outputRank))
208 tosaOp,
"not all ranks are aligned with each other");
221 struct TosaMakeBroadcastable
222 :
public tosa::impl::TosaMakeBroadcastableBase<TosaMakeBroadcastable> {
224 void runOnOperation()
override {
225 auto func = getOperation();
229 patterns.add<ConvertTosaOp<tosa::BitwiseAndOp>>(ctx);
230 patterns.add<ConvertTosaOp<tosa::BitwiseOrOp>>(ctx);
231 patterns.add<ConvertTosaOp<tosa::BitwiseXorOp>>(ctx);
232 patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx);
233 patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx);
234 patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx);
235 patterns.add<ConvertTosaOp<tosa::DivOp>>(ctx);
236 patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx);
237 patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx);
238 patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx);
239 patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx);
240 patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
241 patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
242 patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
243 patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
244 patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
245 patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
246 patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
247 patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
248 patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
255 return std::make_unique<TosaMakeBroadcastable>();
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
MPInt round(const Fraction &f)
std::unique_ptr< Pass > createTosaMakeBroadcastablePass()
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...