26#define GEN_PASS_DEF_TOSADOWNGRADE1P1TO1P0PASS
27#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
40 LogicalResult matchAndRewrite(tosa::CastOp op,
41 PatternRewriter &rewriter)
const override {
42 const Value input = op.getInput();
49 const bool isFp32ToBool =
50 inputElemType == f32Type && outputElemType == i1Type;
51 const bool isBoolToFp32 =
52 inputElemType == i1Type && outputElemType == f32Type;
54 if (!isFp32ToBool && !isBoolToFp32)
56 "expected cast between bool and f32");
58 const Type outputType = op.getType();
60 const Type intermediateType = cast<TensorType>(outputType).clone(i8Type);
63 tosa::CastOp::create(rewriter, op.getLoc(), intermediateType, input);
64 auto outer = tosa::CastOp::create(rewriter, op.getLoc(), outputType,
66 rewriter.
replaceOp(op, outer.getOutput());
75 LogicalResult matchAndRewrite(tosa::GatherOp op,
76 PatternRewriter &rewriter)
const override {
77 const Value values = op.getValues();
78 const Value
indices = op.getIndices();
80 const Type valuesType = values.
getType();
81 const Type resultType = op.getType();
88 op,
"expected values of bool type and indices of i32 type");
91 const Type valuesI8Type = cast<TensorType>(valuesType).clone(i8Type);
92 const Type resultI8Type = cast<TensorType>(resultType).clone(i8Type);
95 tosa::CastOp::create(rewriter, op.getLoc(), valuesI8Type, values);
96 auto gatherI8 = tosa::GatherOp::create(rewriter, op.getLoc(), resultI8Type,
97 valuesToI8.getOutput(),
indices);
98 auto i8ToBool = tosa::CastOp::create(rewriter, op.getLoc(), resultType,
99 gatherI8.getOutput());
100 rewriter.
replaceOp(op, i8ToBool.getOutput());
109 LogicalResult matchAndRewrite(tosa::ScatterOp op,
110 PatternRewriter &rewriter)
const override {
111 const Value valuesIn = op.getValuesIn();
112 const Value
indices = op.getIndices();
114 const Type valuesInType = valuesIn.
getType();
115 const Type i1Type = rewriter.
getI1Type();
120 op,
"expected values of bool type and indices of i32 type");
122 const Value input = op.getInput();
123 const Type inputType = input.
getType();
124 const Type resultType = op.getType();
126 const Type i8Type = rewriter.
getI8Type();
127 const Type valuesInI8Type = cast<TensorType>(valuesInType).clone(i8Type);
128 const Type inputI8Type = cast<TensorType>(inputType).clone(i8Type);
129 const Type resultI8Type = cast<TensorType>(resultType).clone(i8Type);
132 tosa::CastOp::create(rewriter, op.getLoc(), valuesInI8Type, valuesIn);
134 tosa::CastOp::create(rewriter, op.getLoc(), inputI8Type, input);
135 auto scatterI8 = tosa::ScatterOp::create(
136 rewriter, op.getLoc(), resultI8Type, valuesInToI8.getOutput(),
indices,
137 inputToI8.getOutput());
138 auto i8ToBool = tosa::CastOp::create(rewriter, op.getLoc(), resultType,
139 scatterI8.getValuesOut());
140 rewriter.
replaceOp(op, i8ToBool.getOutput());
145static LogicalResult isMatMulTTypeCompatibleForDowngrade(tosa::MatMulTOp op) {
148 const Type outputElementType =
151 if (aElementType != bElementType)
154 if ((aElementType.
isF16() && outputElementType.
isF16()) ||
155 (aElementType.
isF16() && outputElementType.
isF32()) ||
156 (aElementType.
isF32() && outputElementType.
isF32()) ||
157 (aElementType.
isBF16() && outputElementType.
isF32()) ||
160 (isa<Float8E5M2Type>(aElementType) && outputElementType.
isF16()) ||
161 (isa<Float8E4M3FNType>(aElementType) && outputElementType.
isF16()))
171 LogicalResult matchAndRewrite(tosa::MatMulTOp op,
172 PatternRewriter &rewriter)
const override {
173 if (
failed(isMatMulTTypeCompatibleForDowngrade(op)))
175 op,
"expected 1.0-compatible matmul_t element types");
177 const Type aType = op.getA().getType();
178 const Type bType = op.getB().getType();
179 const ShapeAdaptor aShape(aType);
180 const ShapeAdaptor bShape(bType);
181 if (!aShape.hasRank() || !bShape.hasRank())
184 const int64_t dSize = bShape.getDimSize(0);
185 const int64_t nSize = aShape.getDimSize(0);
190 if (ShapedType::isDynamic(dSize) ||
191 (dSize == 1 && ShapedType::isDynamic(nSize)))
193 op,
"expected known batch size for broadcast");
195 const int64_t wSize = bShape.getDimSize(1);
196 const int64_t cSize = bShape.getDimSize(2);
197 const Location loc = op.getLoc();
198 const RankedTensorType transposedBType =
199 cast<RankedTensorType>(bType).clone({dSize, cSize, wSize});
201 tosa::TransposeOp::create(rewriter, loc, transposedBType, op.getB(),
203 Value matMulB = transpose.getOutput();
206 if (dSize == 1 && nSize != 1) {
207 const RankedTensorType tiledBType =
208 cast<RankedTensorType>(bType).clone({nSize, cSize, wSize});
211 tosa::TileOp::create(rewriter, loc, tiledBType, matMulB, multiples);
212 matMulB =
tile.getOutput();
215 auto matmul = tosa::MatMulOp::create(rewriter, loc, op.getType(), op.getA(),
216 matMulB, op.getAZp(), op.getBZp());
217 rewriter.
replaceOp(op, matmul.getOutput());
222struct TosaDowngrade1p1To1p0Pass
224 TosaDowngrade1p1To1p0Pass> {
227 void runOnOperation()
override {
229 func::FuncOp func = getOperation();
231 RewritePatternSet patterns(&context);
232 patterns.add<BoolFp32CastRewrite, BoolGatherRewrite, BoolScatterRewrite,
233 MatMulTRewrite>(&context);
234 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
237 return signalPassFailure();
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
Type getType() const
Return the type of this value.
Type getStorageElementTypeOrSelf(Type type)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling for imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...