MLIR  22.0.0git
TosaToArith.cpp
Go to the documentation of this file.
1 //===- TosaToArith.cpp - Lowering Tosa to Arith Dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Arith dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 
19 using namespace mlir;
20 using namespace tosa;
21 
22 namespace {
23 
24 class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
25 public:
27 
28  LogicalResult matchAndRewrite(tosa::ConstOp op,
29  PatternRewriter &rewriter) const final {
30  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValues());
31  return success();
32  }
33 };
34 
35 Type matchContainerType(Type element, Type container) {
36  if (auto shapedTy = dyn_cast<ShapedType>(container))
37  return shapedTy.clone(element);
38 
39  return element;
40 }
41 
42 TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
43  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
44  Type eTy = shapedTy.getElementType();
45  APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
46  return DenseIntElementsAttr::get(shapedTy, valueInt);
47  }
48 
49  return rewriter.getIntegerAttr(type, value);
50 }
51 
52 Value getConstantValue(Location loc, Type type, int64_t value,
53  PatternRewriter &rewriter) {
54  return arith::ConstantOp::create(rewriter, loc,
55  getConstantAttr(type, value, rewriter));
56 }
57 
58 // This converts the TOSA ApplyScale operator to a set of arithmetic ops,
59 // using 64-bit operations to perform the necessary multiply, bias, and shift.
60 class ApplyScaleGenericOpConverter
61  : public OpRewritePattern<tosa::ApplyScaleOp> {
62 public:
64 
65  LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
66  PatternRewriter &rewriter) const final {
67  RoundingMode roundingMode = op.getRoundingMode();
68  if (roundingMode != RoundingMode::DOUBLE_ROUND &&
69  roundingMode != RoundingMode::SINGLE_ROUND) {
70  return failure();
71  }
72 
73  Location loc = op.getLoc();
74  Value value = op.getValue();
75  Value multiplier32 = op.getMultiplier();
76 
77  Type resultTy = op.getType();
78  Type valueTy = value.getType();
79  Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
80  Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
81 
82  Value zero = getConstantValue(loc, valueTy, 0, rewriter);
83  Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
84  Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
85 
86  Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift());
87 
88  // Compute the multiplication in 64-bits then select the high / low parts.
89  Value value64 = value;
90  if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type())
91  value64 = arith::ExtSIOp::create(rewriter, loc, i64Ty, value);
92  Value multiplier64 =
93  arith::ExtSIOp::create(rewriter, loc, i64Ty, multiplier32);
94  Value multiply64 =
95  arith::MulIOp::create(rewriter, loc, value64, multiplier64);
96 
97  // Apply normal rounding.
98  Value shift64 = arith::ExtUIOp::create(rewriter, loc, i64Ty, shift32);
99  Value round = arith::ShLIOp::create(rewriter, loc, one64, shift64);
100  round = arith::ShRUIOp::create(rewriter, loc, round, one64);
101  multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round);
102 
103  // Apply double rounding if necessary.
104  if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND) {
105  int64_t roundInt = 1 << 30;
106  Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
107  Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
108  Value positive = arith::CmpIOp::create(
109  rewriter, loc, arith::CmpIPredicate::sge, value, zero);
110  Value dir =
111  arith::SelectOp::create(rewriter, loc, positive, roundUp, roundDown);
112  Value val = arith::AddIOp::create(rewriter, loc, dir, multiply64);
113  Value valid = arith::CmpIOp::create(
114  rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
115  multiply64 =
116  arith::SelectOp::create(rewriter, loc, valid, val, multiply64);
117  }
118 
119  Value result64 = arith::ShRSIOp::create(rewriter, loc, multiply64, shift64);
120  Value result32 = arith::TruncIOp::create(rewriter, loc, i32Ty, result64);
121 
122  rewriter.replaceOp(op, result32);
123  return success();
124  }
125 };
126 
127 class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
128 public:
130 
131  LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
132  PatternRewriter &rewriter) const final {
133  RoundingMode roundingMode = op.getRoundingMode();
134  if (roundingMode != RoundingMode::DOUBLE_ROUND &&
135  roundingMode != RoundingMode::SINGLE_ROUND) {
136  return failure();
137  }
138 
139  Location loc = op.getLoc();
140 
141  Type resultTy = op.getType();
142  Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
143 
144  Value value = op.getValue();
145  if (getElementTypeOrSelf(value.getType()).getIntOrFloatBitWidth() > 32) {
146  return failure();
147  }
148 
149  Value value32 = op.getValue();
150  Value multiplier32 = op.getMultiplier();
151  Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift());
152 
153  // Constants used during the scaling operation.
154  Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
155  Value one32 = getConstantValue(loc, i32Ty, 1, rewriter);
156  Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
157  Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
158  Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
159 
160  // Compute the multiplication in 64-bits then select the high / low parts.
161  // Grab out the high/low of the computation
162  auto value64 =
163  arith::MulSIExtendedOp::create(rewriter, loc, value32, multiplier32);
164  Value low32 = value64.getLow();
165  Value high32 = value64.getHigh();
166 
167  // Determine the direction and amount to shift the high bits.
168  Value shiftOver32 = arith::CmpIOp::create(
169  rewriter, loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
170  Value roundHighBits = arith::CmpIOp::create(
171  rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
172 
173  Value shiftHighL =
174  arith::SubIOp::create(rewriter, loc, thirtyTwo32, shift32);
175  Value shiftHighR =
176  arith::SubIOp::create(rewriter, loc, shift32, thirtyTwo32);
177 
178  shiftHighL =
179  arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, shiftHighL);
180  shiftHighR =
181  arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32);
182 
183  // Conditionally perform our double round.
184  if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND) {
185  Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
186  Value valuePositive = arith::CmpIOp::create(
187  rewriter, loc, arith::CmpIPredicate::sge, value32, zero32);
188 
189  Value roundDir = arith::SelectOp::create(rewriter, loc, valuePositive,
190  one32, negOne32);
191  roundDir =
192  arith::SelectOp::create(rewriter, loc, shiftOver32, roundDir, zero32);
193 
194  Value shiftLow = arith::ShRUIOp::create(rewriter, loc, low32, thirty32);
195  Value rounded = arith::AddIOp::create(rewriter, loc, shiftLow, roundDir);
196  Value carry = arith::ShRSIOp::create(rewriter, loc, rounded, two32);
197 
198  Value shiftRound =
199  arith::ShLIOp::create(rewriter, loc, roundDir, thirty32);
200 
201  low32 = arith::AddIOp::create(rewriter, loc, low32, shiftRound);
202  high32 = arith::AddIOp::create(rewriter, loc, high32, carry);
203  }
204 
205  // Conditionally apply rounding in the low bits.
206  {
207  Value shiftSubOne = arith::SubIOp::create(rewriter, loc, shift32, one32);
208  Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne);
209  roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, zero32,
210  roundBit);
211 
212  Value newLow32 = arith::AddIOp::create(rewriter, loc, low32, roundBit);
213  Value wasRounded = arith::CmpIOp::create(
214  rewriter, loc, arith::CmpIPredicate::ugt, low32, newLow32);
215  low32 = newLow32;
216 
217  Value rounded32 =
218  arith::ExtUIOp::create(rewriter, loc, i32Ty, wasRounded);
219  high32 = arith::AddIOp::create(rewriter, loc, high32, rounded32);
220  }
221 
222  // Conditionally apply rounding in the high bits.
223  {
224  Value shiftSubOne =
225  arith::SubIOp::create(rewriter, loc, shiftHighR, one32);
226  Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne);
227  roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, roundBit,
228  zero32);
229  high32 = arith::AddIOp::create(rewriter, loc, high32, roundBit);
230  }
231 
232  // Combine the correct high/low bits into the final rescale result.
233  high32 = arith::ShLIOp::create(rewriter, loc, high32, shiftHighL);
234  high32 = arith::ShRSIOp::create(rewriter, loc, high32, shiftHighR);
235  low32 = arith::ShRUIOp::create(rewriter, loc, low32, shift32);
236  low32 = arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, low32);
237 
238  // Apply the rounding behavior and shift to the final alignment.
239  Value result = arith::AddIOp::create(rewriter, loc, low32, high32);
240 
241  // Truncate if necessary.
242  if (!getElementTypeOrSelf(resultTy).isInteger(32)) {
243  result = arith::TruncIOp::create(rewriter, loc, resultTy, result);
244  }
245 
246  rewriter.replaceOp(op, result);
247  return success();
248  }
249 };
250 
251 } // namespace
252 
255  patterns->add<ConstOpConverter>(patterns->getContext());
256 }
257 
259  RewritePatternSet *patterns, bool include32Bit) {
260  patterns->add<ApplyScaleGenericOpConverter>(patterns->getContext(), 100);
261  if (include32Bit) {
262  patterns->add<ApplyScale32BitOpConverter>(patterns->getContext(), 200);
263  }
264 }
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:227
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerType getI32Type()
Definition: Builders.cpp:62
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns, bool include32Bit=false)
void populateTosaToArithConversionPatterns(RewritePatternSet *patterns)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314