MLIR  20.0.0git
TosaToArith.cpp
Go to the documentation of this file.
1 //===- TosaToArith.cpp - Lowering Tosa to Arith Dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Arith dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
19 
20 using namespace mlir;
21 using namespace tosa;
22 
23 namespace {
24 
25 class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
26 public:
28 
29  LogicalResult matchAndRewrite(tosa::ConstOp op,
30  PatternRewriter &rewriter) const final {
31  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValue());
32  return success();
33  }
34 };
35 
36 Type matchContainerType(Type element, Type container) {
37  if (auto shapedTy = dyn_cast<ShapedType>(container))
38  return shapedTy.clone(element);
39 
40  return element;
41 }
42 
43 TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
44  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
45  Type eTy = shapedTy.getElementType();
46  APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
47  return DenseIntElementsAttr::get(shapedTy, valueInt);
48  }
49 
50  return rewriter.getIntegerAttr(type, value);
51 }
52 
53 Value getConstantValue(Location loc, Type type, int64_t value,
54  PatternRewriter &rewriter) {
55  return rewriter.create<arith::ConstantOp>(
56  loc, getConstantAttr(type, value, rewriter));
57 }
58 
59 // This converts the TOSA ApplyScale operator to a set of arithmetic ops,
60 // using 64-bit operations to perform the necessary multiply, bias, and shift.
61 class ApplyScaleGenericOpConverter
62  : public OpRewritePattern<tosa::ApplyScaleOp> {
63 public:
65 
66  LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
67  PatternRewriter &rewriter) const final {
68  Location loc = op.getLoc();
69  Value value = op.getValue();
70  Value multiplier32 = op.getMultiplier();
71 
72  Type resultTy = op.getType();
73  Type valueTy = value.getType();
74  Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
75  Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
76 
77  Value zero = getConstantValue(loc, valueTy, 0, rewriter);
78  Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
79  Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
80 
81  Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
82 
83  // Compute the multiplication in 64-bits then select the high / low parts.
84  Value value64 = value;
85  if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type())
86  value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
87  Value multiplier64 =
88  rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
89  Value multiply64 =
90  rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
91 
92  // Apply normal rounding.
93  Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32);
94  Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64);
95  round = rewriter.create<arith::ShRUIOp>(loc, round, one64);
96  multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
97 
98  // Apply double rounding if necessary.
99  if (op.getDoubleRound()) {
100  int64_t roundInt = 1 << 30;
101  Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
102  Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
103  Value positive = rewriter.create<arith::CmpIOp>(
104  loc, arith::CmpIPredicate::sge, value, zero);
105  Value dir =
106  rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown);
107  Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64);
108  Value valid = rewriter.create<arith::CmpIOp>(
109  loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
110  multiply64 =
111  rewriter.create<arith::SelectOp>(loc, valid, val, multiply64);
112  }
113 
114  Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64);
115  Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64);
116 
117  rewriter.replaceOp(op, result32);
118  return success();
119  }
120 };
121 
122 class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
123 public:
125 
126  LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
127  PatternRewriter &rewriter) const final {
128  Location loc = op.getLoc();
129 
130  Type resultTy = op.getType();
131  Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
132 
133  Value value = op.getValue();
134  if (getElementTypeOrSelf(value.getType()).getIntOrFloatBitWidth() > 32) {
135  return failure();
136  }
137 
138  Value value32 = op.getValue();
139  Value multiplier32 = op.getMultiplier();
140  Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
141 
142  // Constants used during the scaling operation.
143  Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
144  Value one32 = getConstantValue(loc, i32Ty, 1, rewriter);
145  Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
146  Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
147  Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
148 
149  // Compute the multiplication in 64-bits then select the high / low parts.
150  // Grab out the high/low of the computation
151  auto value64 =
152  rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32);
153  Value low32 = value64.getLow();
154  Value high32 = value64.getHigh();
155 
156  // Determine the direction and amount to shift the high bits.
157  Value shiftOver32 = rewriter.create<arith::CmpIOp>(
158  loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
159  Value roundHighBits = rewriter.create<arith::CmpIOp>(
160  loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
161 
162  Value shiftHighL =
163  rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32);
164  Value shiftHighR =
165  rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32);
166 
167  shiftHighL =
168  rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL);
169  shiftHighR =
170  rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
171 
172  // Conditionally perform our double round.
173  if (op.getDoubleRound()) {
174  Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
175  Value valuePositive = rewriter.create<arith::CmpIOp>(
176  loc, arith::CmpIPredicate::sge, value32, zero32);
177 
178  Value roundDir =
179  rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32);
180  roundDir =
181  rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32);
182 
183  Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32);
184  Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir);
185  Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32);
186 
187  Value shiftRound =
188  rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32);
189 
190  low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound);
191  high32 = rewriter.create<arith::AddIOp>(loc, high32, carry);
192  }
193 
194  // Conditionally apply rounding in the low bits.
195  {
196  Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32);
197  Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
198  roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32,
199  roundBit);
200 
201  Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit);
202  Value wasRounded = rewriter.create<arith::CmpIOp>(
203  loc, arith::CmpIPredicate::ugt, low32, newLow32);
204  low32 = newLow32;
205 
206  Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded);
207  high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32);
208  }
209 
210  // Conditionally apply rounding in the high bits.
211  {
212  Value shiftSubOne =
213  rewriter.create<arith::SubIOp>(loc, shiftHighR, one32);
214  Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
215  roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit,
216  zero32);
217  high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit);
218  }
219 
220  // Combine the correct high/low bits into the final rescale result.
221  high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL);
222  high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR);
223  low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32);
224  low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32);
225 
226  // Apply the rounding behavior and shift to the final alignment.
227  Value result = rewriter.create<arith::AddIOp>(loc, low32, high32);
228 
229  // Truncate if necessary.
230  if (!getElementTypeOrSelf(resultTy).isInteger(32)) {
231  result = rewriter.create<arith::TruncIOp>(loc, resultTy, result);
232  }
233 
234  rewriter.replaceOp(op, result);
235  return success();
236  }
237 };
238 
239 } // namespace
240 
242  RewritePatternSet *patterns) {
243  patterns->add<ConstOpConverter>(patterns->getContext());
244 }
245 
247  RewritePatternSet *patterns, bool include32Bit) {
248  patterns->add<ApplyScaleGenericOpConverter>(patterns->getContext(), 100);
249  if (include32Bit) {
250  patterns->add<ApplyScale32BitOpConverter>(patterns->getContext(), 200);
251  }
252 }
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
IntegerType getI64Type()
Definition: Builders.cpp:109
IntegerType getI32Type()
Definition: Builders.cpp:107
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:66
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns, bool include32Bit=false)
void populateTosaToArithConversionPatterns(RewritePatternSet *patterns)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358