MLIR 23.0.0git
TosaDowngrade1p1To1p0.cpp
Go to the documentation of this file.
1//===- TosaDowngrade1p1To1p0.cpp -----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Rewrites constructs which are only compatible in TOSA specification 1.1 and
10// above to their TOSA 1.0 counterparts where possible. Downgrading is
11// best-effort and validation should be performed afterwards to ensure
12// compatibility with the TOSA 1.0 specification.
13//
14//===----------------------------------------------------------------------===//
15
17
21
22namespace mlir {
23namespace tosa {
24#define GEN_PASS_DEF_TOSADOWNGRADE1P1TO1P0PASS
25#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
26} // namespace tosa
27} // namespace mlir
28
29using namespace mlir;
30using namespace mlir::tosa;
31
32namespace {
33
34class BoolFp32CastRewrite : public OpRewritePattern<tosa::CastOp> {
35public:
37
38 LogicalResult matchAndRewrite(tosa::CastOp op,
39 PatternRewriter &rewriter) const override {
40 const Value input = op.getInput();
41
42 const Type i1Type = rewriter.getI1Type();
43 const Type f32Type = rewriter.getF32Type();
44
45 const Type inputElemType = getElementTypeOrSelf(input.getType());
46 const Type outputElemType = getElementTypeOrSelf(op.getType());
47 const bool isFp32ToBool =
48 inputElemType == f32Type && outputElemType == i1Type;
49 const bool isBoolToFp32 =
50 inputElemType == i1Type && outputElemType == f32Type;
51
52 if (!isFp32ToBool && !isBoolToFp32)
53 return rewriter.notifyMatchFailure(op,
54 "expected cast between bool and f32");
55
56 const Type outputType = op.getType();
57 const Type i8Type = rewriter.getI8Type();
58 const Type intermediateType = cast<TensorType>(outputType).clone(i8Type);
59
60 auto inner =
61 tosa::CastOp::create(rewriter, op.getLoc(), intermediateType, input);
62 auto outer = tosa::CastOp::create(rewriter, op.getLoc(), outputType,
63 inner.getOutput());
64 rewriter.replaceOp(op, outer.getOutput());
65 return success();
66 }
67};
68
69class BoolGatherRewrite : public OpRewritePattern<tosa::GatherOp> {
70public:
72
73 LogicalResult matchAndRewrite(tosa::GatherOp op,
74 PatternRewriter &rewriter) const override {
75 const Value values = op.getValues();
76 const Value indices = op.getIndices();
77
78 const Type valuesType = values.getType();
79 const Type resultType = op.getType();
80
81 const Type i1Type = rewriter.getI1Type();
82 const Type i32Type = rewriter.getI32Type();
83 if (getElementTypeOrSelf(valuesType) != i1Type ||
84 getElementTypeOrSelf(indices.getType()) != i32Type)
85 return rewriter.notifyMatchFailure(
86 op, "expected values of bool type and indices of i32 type");
87
88 const Type i8Type = rewriter.getI8Type();
89 const Type valuesI8Type = cast<TensorType>(valuesType).clone(i8Type);
90 const Type resultI8Type = cast<TensorType>(resultType).clone(i8Type);
91
92 auto valuesToI8 =
93 tosa::CastOp::create(rewriter, op.getLoc(), valuesI8Type, values);
94 auto gatherI8 = tosa::GatherOp::create(rewriter, op.getLoc(), resultI8Type,
95 valuesToI8.getOutput(), indices);
96 auto i8ToBool = tosa::CastOp::create(rewriter, op.getLoc(), resultType,
97 gatherI8.getOutput());
98 rewriter.replaceOp(op, i8ToBool.getOutput());
99 return success();
100 }
101};
102
103class BoolScatterRewrite : public OpRewritePattern<tosa::ScatterOp> {
104public:
106
107 LogicalResult matchAndRewrite(tosa::ScatterOp op,
108 PatternRewriter &rewriter) const override {
109 const Value valuesIn = op.getValuesIn();
110 const Value indices = op.getIndices();
111
112 const Type valuesInType = valuesIn.getType();
113 const Type i1Type = rewriter.getI1Type();
114 const Type i32Type = rewriter.getI32Type();
115 if (getElementTypeOrSelf(valuesInType) != i1Type ||
116 getElementTypeOrSelf(indices.getType()) != i32Type)
117 return rewriter.notifyMatchFailure(
118 op, "expected values of bool type and indices of i32 type");
119
120 const Value input = op.getInput();
121 const Type inputType = input.getType();
122 const Type resultType = op.getType();
123
124 const Type i8Type = rewriter.getI8Type();
125 const Type valuesInI8Type = cast<TensorType>(valuesInType).clone(i8Type);
126 const Type inputI8Type = cast<TensorType>(inputType).clone(i8Type);
127 const Type resultI8Type = cast<TensorType>(resultType).clone(i8Type);
128
129 auto valuesInToI8 =
130 tosa::CastOp::create(rewriter, op.getLoc(), valuesInI8Type, valuesIn);
131 auto inputToI8 =
132 tosa::CastOp::create(rewriter, op.getLoc(), inputI8Type, input);
133 auto scatterI8 = tosa::ScatterOp::create(
134 rewriter, op.getLoc(), resultI8Type, valuesInToI8.getOutput(), indices,
135 inputToI8.getOutput());
136 auto i8ToBool = tosa::CastOp::create(rewriter, op.getLoc(), resultType,
137 scatterI8.getValuesOut());
138 rewriter.replaceOp(op, i8ToBool.getOutput());
139 return success();
140 }
141};
142
// Pass that runs on a func::FuncOp and greedily applies BoolFp32CastRewrite,
// BoolGatherRewrite and BoolScatterRewrite to it. If the greedy driver
// reports failure, the pass signals failure.
143struct TosaDowngrade1p1To1p0Pass
145 TosaDowngrade1p1To1p0Pass> {
// Inherit the constructors of the generated pass base class.
146 using Base::Base;
147
148 void runOnOperation() override {
149 MLIRContext &context = getContext();
150 func::FuncOp func = getOperation();
151
// Collect the downgrade patterns and freeze them for the greedy driver.
152 RewritePatternSet patterns(&context);
153 patterns.add<BoolFp32CastRewrite, BoolGatherRewrite, BoolScatterRewrite>(
154 &context);
155 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
156
// A failed LogicalResult from the greedy driver is turned into a pass failure.
157 if (failed(applyPatternsGreedily(func, frozenPatterns)))
158 return signalPassFailure();
159 }
160};
161
162} // namespace
return success()
b getContext())
FloatType getF32Type()
Definition Builders.cpp:47
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerType getI1Type()
Definition Builders.cpp:57
IntegerType getI8Type()
Definition Builders.cpp:63
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Type getType() const
Return the type of this value.
Definition Value.h:105
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...