MLIR 22.0.0git
TosaToArith.cpp
Go to the documentation of this file.
//===- TosaToArith.cpp - Lowering Tosa to Arith Dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower operations from the TOSA dialect to the Arith dialect.
//
//===----------------------------------------------------------------------===//
12
18
19using namespace mlir;
20using namespace tosa;
21
22namespace {
23
24class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
25public:
26 using Base::Base;
27
28 LogicalResult matchAndRewrite(tosa::ConstOp op,
29 PatternRewriter &rewriter) const final {
30 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValues());
31 return success();
32 }
33};
34
35Type matchContainerType(Type element, Type container) {
36 if (auto shapedTy = dyn_cast<ShapedType>(container))
37 return shapedTy.clone(element);
38
39 return element;
40}
41
42TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
43 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
44 Type eTy = shapedTy.getElementType();
45 APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
46 return DenseIntElementsAttr::get(shapedTy, valueInt);
47 }
48
49 return rewriter.getIntegerAttr(type, value);
50}
51
52Value getConstantValue(Location loc, Type type, int64_t value,
53 PatternRewriter &rewriter) {
54 return arith::ConstantOp::create(rewriter, loc,
55 getConstantAttr(type, value, rewriter));
56}
57
58// This converts the TOSA ApplyScale operator to a set of arithmetic ops,
59// using 64-bit operations to perform the necessary multiply, bias, and shift.
60class ApplyScaleGenericOpConverter
61 : public OpRewritePattern<tosa::ApplyScaleOp> {
62public:
63 using Base::Base;
64
65 LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
66 PatternRewriter &rewriter) const final {
67 RoundingMode roundingMode = op.getRoundingMode();
68 if (roundingMode != RoundingMode::DOUBLE_ROUND &&
69 roundingMode != RoundingMode::SINGLE_ROUND) {
70 return failure();
71 }
72
73 Location loc = op.getLoc();
74 Value value = op.getValue();
75 Value multiplier32 = op.getMultiplier();
76
77 Type resultTy = op.getType();
78 Type valueTy = value.getType();
79 Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
80 Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
81
82 Value zero = getConstantValue(loc, valueTy, 0, rewriter);
83 Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
84 Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
85
86 Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift());
87
88 // Compute the multiplication in 64-bits then select the high / low parts.
89 Value value64 = value;
90 if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type())
91 value64 = arith::ExtSIOp::create(rewriter, loc, i64Ty, value);
92 Value multiplier64 =
93 arith::ExtSIOp::create(rewriter, loc, i64Ty, multiplier32);
94 Value multiply64 =
95 arith::MulIOp::create(rewriter, loc, value64, multiplier64);
96
97 // Apply normal rounding.
98 Value shift64 = arith::ExtUIOp::create(rewriter, loc, i64Ty, shift32);
99 Value round = arith::ShLIOp::create(rewriter, loc, one64, shift64);
100 round = arith::ShRUIOp::create(rewriter, loc, round, one64);
101 multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round);
102
103 // Apply double rounding if necessary.
104 if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND) {
105 int64_t roundInt = 1 << 30;
106 Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
107 Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
108 Value positive = arith::CmpIOp::create(
109 rewriter, loc, arith::CmpIPredicate::sge, value, zero);
110 Value dir =
111 arith::SelectOp::create(rewriter, loc, positive, roundUp, roundDown);
112 Value val = arith::AddIOp::create(rewriter, loc, dir, multiply64);
113 Value valid = arith::CmpIOp::create(
114 rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
115 multiply64 =
116 arith::SelectOp::create(rewriter, loc, valid, val, multiply64);
117 }
118
119 Value result64 = arith::ShRSIOp::create(rewriter, loc, multiply64, shift64);
120 Value result32 = arith::TruncIOp::create(rewriter, loc, i32Ty, result64);
121
122 rewriter.replaceOp(op, result32);
123 return success();
124 }
125};
126
127class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
128public:
129 using Base::Base;
130
131 LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
132 PatternRewriter &rewriter) const final {
133 RoundingMode roundingMode = op.getRoundingMode();
134 if (roundingMode != RoundingMode::DOUBLE_ROUND &&
135 roundingMode != RoundingMode::SINGLE_ROUND) {
136 return failure();
137 }
138
139 Location loc = op.getLoc();
140
141 Type resultTy = op.getType();
142 Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
143
144 Value value = op.getValue();
146 return failure();
147 }
148
149 Value value32 = op.getValue();
150 Value multiplier32 = op.getMultiplier();
151 Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift());
152
153 // Constants used during the scaling operation.
154 Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
155 Value one32 = getConstantValue(loc, i32Ty, 1, rewriter);
156 Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
157 Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
158 Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
159
160 // Compute the multiplication in 64-bits then select the high / low parts.
161 // Grab out the high/low of the computation
162 auto value64 =
163 arith::MulSIExtendedOp::create(rewriter, loc, value32, multiplier32);
164 Value low32 = value64.getLow();
165 Value high32 = value64.getHigh();
166
167 // Determine the direction and amount to shift the high bits.
168 Value shiftOver32 = arith::CmpIOp::create(
169 rewriter, loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
170 Value roundHighBits = arith::CmpIOp::create(
171 rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
172
173 Value shiftHighL =
174 arith::SubIOp::create(rewriter, loc, thirtyTwo32, shift32);
175 Value shiftHighR =
176 arith::SubIOp::create(rewriter, loc, shift32, thirtyTwo32);
177
178 shiftHighL =
179 arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, shiftHighL);
180 shiftHighR =
181 arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32);
182
183 // Conditionally perform our double round.
184 if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND) {
185 Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
186 Value valuePositive = arith::CmpIOp::create(
187 rewriter, loc, arith::CmpIPredicate::sge, value32, zero32);
188
189 Value roundDir = arith::SelectOp::create(rewriter, loc, valuePositive,
190 one32, negOne32);
191 roundDir =
192 arith::SelectOp::create(rewriter, loc, shiftOver32, roundDir, zero32);
193
194 Value shiftLow = arith::ShRUIOp::create(rewriter, loc, low32, thirty32);
195 Value rounded = arith::AddIOp::create(rewriter, loc, shiftLow, roundDir);
196 Value carry = arith::ShRSIOp::create(rewriter, loc, rounded, two32);
197
198 Value shiftRound =
199 arith::ShLIOp::create(rewriter, loc, roundDir, thirty32);
200
201 low32 = arith::AddIOp::create(rewriter, loc, low32, shiftRound);
202 high32 = arith::AddIOp::create(rewriter, loc, high32, carry);
203 }
204
205 // Conditionally apply rounding in the low bits.
206 {
207 Value shiftSubOne = arith::SubIOp::create(rewriter, loc, shift32, one32);
208 Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne);
209 roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, zero32,
210 roundBit);
211
212 Value newLow32 = arith::AddIOp::create(rewriter, loc, low32, roundBit);
213 Value wasRounded = arith::CmpIOp::create(
214 rewriter, loc, arith::CmpIPredicate::ugt, low32, newLow32);
215 low32 = newLow32;
216
217 Value rounded32 =
218 arith::ExtUIOp::create(rewriter, loc, i32Ty, wasRounded);
219 high32 = arith::AddIOp::create(rewriter, loc, high32, rounded32);
220 }
221
222 // Conditionally apply rounding in the high bits.
223 {
224 Value shiftSubOne =
225 arith::SubIOp::create(rewriter, loc, shiftHighR, one32);
226 Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne);
227 roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, roundBit,
228 zero32);
229 high32 = arith::AddIOp::create(rewriter, loc, high32, roundBit);
230 }
231
232 // Combine the correct high/low bits into the final rescale result.
233 high32 = arith::ShLIOp::create(rewriter, loc, high32, shiftHighL);
234 high32 = arith::ShRSIOp::create(rewriter, loc, high32, shiftHighR);
235 low32 = arith::ShRUIOp::create(rewriter, loc, low32, shift32);
236 low32 = arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, low32);
237
238 // Apply the rounding behavior and shift to the final alignment.
239 Value result = arith::AddIOp::create(rewriter, loc, low32, high32);
240
241 // Truncate if necessary.
242 if (!getElementTypeOrSelf(resultTy).isInteger(32)) {
243 result = arith::TruncIOp::create(rewriter, loc, resultTy, result);
244 }
245
246 rewriter.replaceOp(op, result);
247 return success();
248 }
249};
250
251} // namespace
252
255 patterns->add<ConstOpConverter>(patterns->getContext());
256}
257
259 RewritePatternSet *patterns, bool include32Bit) {
260 patterns->add<ApplyScaleGenericOpConverter>(patterns->getContext(), 100);
261 if (include32Bit) {
262 patterns->add<ApplyScale32BitOpConverter>(patterns->getContext(), 200);
263 }
264}
return success()
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
IntegerType getI64Type()
Definition Builders.cpp:65
IntegerType getI32Type()
Definition Builders.cpp:63
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
DynamicAPInt round(const Fraction &f)
Definition Fraction.h:136
void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns, bool include32Bit=false)
void populateTosaToArithConversionPatterns(RewritePatternSet *patterns)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...