MLIR 23.0.0git
TosaDowngrade1p1To1p0.cpp
Go to the documentation of this file.
1//===- TosaDowngrade1p1To1p0.cpp -----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Rewrites constructs which are only compatible in TOSA specification 1.1 and
10// above to their TOSA 1.0 counterparts where possible. Downgrading is
11// best-effort and validation should be performed afterwards to ensure
12// compatibility with the TOSA 1.0 specification.
13//
14//===----------------------------------------------------------------------===//
15
17
23
24namespace mlir {
25namespace tosa {
26#define GEN_PASS_DEF_TOSADOWNGRADE1P1TO1P0PASS
27#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
28} // namespace tosa
29} // namespace mlir
30
31using namespace mlir;
32using namespace mlir::tosa;
33
34namespace {
35
36class BoolFp32CastRewrite : public OpRewritePattern<tosa::CastOp> {
37public:
39
40 LogicalResult matchAndRewrite(tosa::CastOp op,
41 PatternRewriter &rewriter) const override {
42 const Value input = op.getInput();
43
44 const Type i1Type = rewriter.getI1Type();
45 const Type f32Type = rewriter.getF32Type();
46
47 const Type inputElemType = getElementTypeOrSelf(input.getType());
48 const Type outputElemType = getElementTypeOrSelf(op.getType());
49 const bool isFp32ToBool =
50 inputElemType == f32Type && outputElemType == i1Type;
51 const bool isBoolToFp32 =
52 inputElemType == i1Type && outputElemType == f32Type;
53
54 if (!isFp32ToBool && !isBoolToFp32)
55 return rewriter.notifyMatchFailure(op,
56 "expected cast between bool and f32");
57
58 const Type outputType = op.getType();
59 const Type i8Type = rewriter.getI8Type();
60 const Type intermediateType = cast<TensorType>(outputType).clone(i8Type);
61
62 auto inner =
63 tosa::CastOp::create(rewriter, op.getLoc(), intermediateType, input);
64 auto outer = tosa::CastOp::create(rewriter, op.getLoc(), outputType,
65 inner.getOutput());
66 rewriter.replaceOp(op, outer.getOutput());
67 return success();
68 }
69};
70
71class BoolGatherRewrite : public OpRewritePattern<tosa::GatherOp> {
72public:
74
75 LogicalResult matchAndRewrite(tosa::GatherOp op,
76 PatternRewriter &rewriter) const override {
77 const Value values = op.getValues();
78 const Value indices = op.getIndices();
79
80 const Type valuesType = values.getType();
81 const Type resultType = op.getType();
82
83 const Type i1Type = rewriter.getI1Type();
84 const Type i32Type = rewriter.getI32Type();
85 if (getElementTypeOrSelf(valuesType) != i1Type ||
86 getElementTypeOrSelf(indices.getType()) != i32Type)
87 return rewriter.notifyMatchFailure(
88 op, "expected values of bool type and indices of i32 type");
89
90 const Type i8Type = rewriter.getI8Type();
91 const Type valuesI8Type = cast<TensorType>(valuesType).clone(i8Type);
92 const Type resultI8Type = cast<TensorType>(resultType).clone(i8Type);
93
94 auto valuesToI8 =
95 tosa::CastOp::create(rewriter, op.getLoc(), valuesI8Type, values);
96 auto gatherI8 = tosa::GatherOp::create(rewriter, op.getLoc(), resultI8Type,
97 valuesToI8.getOutput(), indices);
98 auto i8ToBool = tosa::CastOp::create(rewriter, op.getLoc(), resultType,
99 gatherI8.getOutput());
100 rewriter.replaceOp(op, i8ToBool.getOutput());
101 return success();
102 }
103};
104
105class BoolScatterRewrite : public OpRewritePattern<tosa::ScatterOp> {
106public:
108
109 LogicalResult matchAndRewrite(tosa::ScatterOp op,
110 PatternRewriter &rewriter) const override {
111 const Value valuesIn = op.getValuesIn();
112 const Value indices = op.getIndices();
113
114 const Type valuesInType = valuesIn.getType();
115 const Type i1Type = rewriter.getI1Type();
116 const Type i32Type = rewriter.getI32Type();
117 if (getElementTypeOrSelf(valuesInType) != i1Type ||
118 getElementTypeOrSelf(indices.getType()) != i32Type)
119 return rewriter.notifyMatchFailure(
120 op, "expected values of bool type and indices of i32 type");
121
122 const Value input = op.getInput();
123 const Type inputType = input.getType();
124 const Type resultType = op.getType();
125
126 const Type i8Type = rewriter.getI8Type();
127 const Type valuesInI8Type = cast<TensorType>(valuesInType).clone(i8Type);
128 const Type inputI8Type = cast<TensorType>(inputType).clone(i8Type);
129 const Type resultI8Type = cast<TensorType>(resultType).clone(i8Type);
130
131 auto valuesInToI8 =
132 tosa::CastOp::create(rewriter, op.getLoc(), valuesInI8Type, valuesIn);
133 auto inputToI8 =
134 tosa::CastOp::create(rewriter, op.getLoc(), inputI8Type, input);
135 auto scatterI8 = tosa::ScatterOp::create(
136 rewriter, op.getLoc(), resultI8Type, valuesInToI8.getOutput(), indices,
137 inputToI8.getOutput());
138 auto i8ToBool = tosa::CastOp::create(rewriter, op.getLoc(), resultType,
139 scatterI8.getValuesOut());
140 rewriter.replaceOp(op, i8ToBool.getOutput());
141 return success();
142 }
143};
144
145static LogicalResult isMatMulTTypeCompatibleForDowngrade(tosa::MatMulTOp op) {
146 const Type aElementType = getStorageElementTypeOrSelf(op.getA().getType());
147 const Type bElementType = getStorageElementTypeOrSelf(op.getB().getType());
148 const Type outputElementType =
149 getStorageElementTypeOrSelf(op.getOutput().getType());
150
151 if (aElementType != bElementType)
152 return failure();
153
154 if ((aElementType.isF16() && outputElementType.isF16()) ||
155 (aElementType.isF16() && outputElementType.isF32()) ||
156 (aElementType.isF32() && outputElementType.isF32()) ||
157 (aElementType.isBF16() && outputElementType.isF32()) ||
158 (aElementType.isInteger(8) && outputElementType.isInteger(32)) ||
159 (aElementType.isInteger(16) && outputElementType.isInteger(48)) ||
160 (isa<Float8E5M2Type>(aElementType) && outputElementType.isF16()) ||
161 (isa<Float8E4M3FNType>(aElementType) && outputElementType.isF16()))
162 return success();
163
164 return failure();
165}
166
167class MatMulTRewrite : public OpRewritePattern<tosa::MatMulTOp> {
168public:
170
171 LogicalResult matchAndRewrite(tosa::MatMulTOp op,
172 PatternRewriter &rewriter) const override {
173 if (failed(isMatMulTTypeCompatibleForDowngrade(op)))
174 return rewriter.notifyMatchFailure(
175 op, "expected 1.0-compatible matmul_t element types");
176
177 const Type aType = op.getA().getType();
178 const Type bType = op.getB().getType();
179 const ShapeAdaptor aShape(aType);
180 const ShapeAdaptor bShape(bType);
181 if (!aShape.hasRank() || !bShape.hasRank())
182 return rewriter.notifyMatchFailure(op, "expected ranked A and B tensors");
183
184 const int64_t dSize = bShape.getDimSize(0);
185 const int64_t nSize = aShape.getDimSize(0);
186
187 // To convert broadcasting behaviour to TOSA 1.0, we're required to tile the
188 // input. TOSA 1.0 does not support shape expressions, so the batch size
189 // must be known at compile time.
190 if (ShapedType::isDynamic(dSize) ||
191 (dSize == 1 && ShapedType::isDynamic(nSize)))
192 return rewriter.notifyMatchFailure(
193 op, "expected known batch size for broadcast");
194
195 const int64_t wSize = bShape.getDimSize(1);
196 const int64_t cSize = bShape.getDimSize(2);
197 const Location loc = op.getLoc();
198 const RankedTensorType transposedBType =
199 cast<RankedTensorType>(bType).clone({dSize, cSize, wSize});
200 auto transpose =
201 tosa::TransposeOp::create(rewriter, loc, transposedBType, op.getB(),
202 rewriter.getDenseI32ArrayAttr({0, 2, 1}));
203 Value matMulB = transpose.getOutput();
204
205 // Matmul does not support broadcasting, so tile b if required
206 if (dSize == 1 && nSize != 1) {
207 const RankedTensorType tiledBType =
208 cast<RankedTensorType>(bType).clone({nSize, cSize, wSize});
209 const Value multiples = getTosaConstShape(rewriter, loc, {nSize, 1, 1});
210 auto tile =
211 tosa::TileOp::create(rewriter, loc, tiledBType, matMulB, multiples);
212 matMulB = tile.getOutput();
213 }
214
215 auto matmul = tosa::MatMulOp::create(rewriter, loc, op.getType(), op.getA(),
216 matMulB, op.getAZp(), op.getBZp());
217 rewriter.replaceOp(op, matmul.getOutput());
218 return success();
219 }
220};
221
222struct TosaDowngrade1p1To1p0Pass
224 TosaDowngrade1p1To1p0Pass> {
225 using Base::Base;
226
227 void runOnOperation() override {
228 MLIRContext &context = getContext();
229 func::FuncOp func = getOperation();
230
231 RewritePatternSet patterns(&context);
232 patterns.add<BoolFp32CastRewrite, BoolGatherRewrite, BoolScatterRewrite,
233 MatMulTRewrite>(&context);
234 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
235
236 if (failed(applyPatternsGreedily(func, frozenPatterns)))
237 return signalPassFailure();
238 }
239};
240
241} // namespace
return success()
b getContext())
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
FloatType getF32Type()
Definition Builders.cpp:47
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerType getI1Type()
Definition Builders.cpp:57
IntegerType getI8Type()
Definition Builders.cpp:63
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
bool isBF16() const
Definition Types.cpp:37
Type getType() const
Return the type of this value.
Definition Value.h:105
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Type getStorageElementTypeOrSelf(Type type)
Definition TosaOps.cpp:584
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...