24#define GEN_PASS_DEF_TOSADOWNGRADE1P1TO1P0PASS
25#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
38 LogicalResult matchAndRewrite(tosa::CastOp op,
39 PatternRewriter &rewriter)
const override {
40 const Value input = op.getInput();
47 const bool isFp32ToBool =
48 inputElemType == f32Type && outputElemType == i1Type;
49 const bool isBoolToFp32 =
50 inputElemType == i1Type && outputElemType == f32Type;
52 if (!isFp32ToBool && !isBoolToFp32)
54 "expected cast between bool and f32");
56 const Type outputType = op.getType();
58 const Type intermediateType = cast<TensorType>(outputType).clone(i8Type);
61 tosa::CastOp::create(rewriter, op.getLoc(), intermediateType, input);
62 auto outer = tosa::CastOp::create(rewriter, op.getLoc(), outputType,
64 rewriter.
replaceOp(op, outer.getOutput());
73 LogicalResult matchAndRewrite(tosa::GatherOp op,
74 PatternRewriter &rewriter)
const override {
75 const Value values = op.getValues();
76 const Value
indices = op.getIndices();
78 const Type valuesType = values.
getType();
79 const Type resultType = op.getType();
86 op,
"expected values of bool type and indices of i32 type");
89 const Type valuesI8Type = cast<TensorType>(valuesType).clone(i8Type);
90 const Type resultI8Type = cast<TensorType>(resultType).clone(i8Type);
93 tosa::CastOp::create(rewriter, op.getLoc(), valuesI8Type, values);
94 auto gatherI8 = tosa::GatherOp::create(rewriter, op.getLoc(), resultI8Type,
95 valuesToI8.getOutput(),
indices);
96 auto i8ToBool = tosa::CastOp::create(rewriter, op.getLoc(), resultType,
97 gatherI8.getOutput());
98 rewriter.
replaceOp(op, i8ToBool.getOutput());
107 LogicalResult matchAndRewrite(tosa::ScatterOp op,
108 PatternRewriter &rewriter)
const override {
109 const Value valuesIn = op.getValuesIn();
110 const Value
indices = op.getIndices();
112 const Type valuesInType = valuesIn.
getType();
113 const Type i1Type = rewriter.
getI1Type();
118 op,
"expected values of bool type and indices of i32 type");
120 const Value input = op.getInput();
121 const Type inputType = input.
getType();
122 const Type resultType = op.getType();
124 const Type i8Type = rewriter.
getI8Type();
125 const Type valuesInI8Type = cast<TensorType>(valuesInType).clone(i8Type);
126 const Type inputI8Type = cast<TensorType>(inputType).clone(i8Type);
127 const Type resultI8Type = cast<TensorType>(resultType).clone(i8Type);
130 tosa::CastOp::create(rewriter, op.getLoc(), valuesInI8Type, valuesIn);
132 tosa::CastOp::create(rewriter, op.getLoc(), inputI8Type, input);
133 auto scatterI8 = tosa::ScatterOp::create(
134 rewriter, op.getLoc(), resultI8Type, valuesInToI8.getOutput(),
indices,
135 inputToI8.getOutput());
136 auto i8ToBool = tosa::CastOp::create(rewriter, op.getLoc(), resultType,
137 scatterI8.getValuesOut());
138 rewriter.
replaceOp(op, i8ToBool.getOutput());
143struct TosaDowngrade1p1To1p0Pass
145 TosaDowngrade1p1To1p0Pass> {
148 void runOnOperation()
override {
150 func::FuncOp func = getOperation();
152 RewritePatternSet patterns(&context);
153 patterns.add<BoolFp32CastRewrite, BoolGatherRewrite, BoolScatterRewrite>(
155 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
158 return signalPassFailure();
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Type getType() const
Return the type of this value.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...