MLIR  17.0.0git
TosaToArith.cpp
Go to the documentation of this file.
1 //===- TosaToArith.cpp - Lowering Tosa to Arith Dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Arith dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
19 
20 using namespace mlir;
21 using namespace tosa;
22 
23 namespace {
24 
25 class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
26 public:
28 
29  LogicalResult matchAndRewrite(tosa::ConstOp op,
30  PatternRewriter &rewriter) const final {
31  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValue());
32  return success();
33  }
34 };
35 
36 Type matchContainerType(Type element, Type container) {
37  if (auto shapedTy = container.dyn_cast<ShapedType>())
38  return shapedTy.clone(element);
39 
40  return element;
41 }
42 
43 Attribute getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
44  if (auto shapedTy = type.dyn_cast<ShapedType>()) {
45  Type eTy = shapedTy.getElementType();
46  APInt valueInt(eTy.getIntOrFloatBitWidth(), value);
47  return DenseIntElementsAttr::get(shapedTy, valueInt);
48  }
49 
50  return rewriter.getIntegerAttr(type, value);
51 }
52 
53 Value getConstantValue(Location loc, Type type, int64_t value,
54  PatternRewriter &rewriter) {
55  return rewriter.create<arith::ConstantOp>(
56  loc, getConstantAttr(type, value, rewriter));
57 }
58 
59 // This converts the TOSA ApplyScale operator to a set of arithmetic ops,
60 // using 64-bit operations to perform the necessary multiply, bias, and shift.
61 class ApplyScaleGenericOpConverter
62  : public OpRewritePattern<tosa::ApplyScaleOp> {
63 public:
65 
66  LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
67  PatternRewriter &rewriter) const final {
68  Location loc = op.getLoc();
69  Value value = op.getValue();
70  Value multiplier32 = op.getMultiplier();
71 
72  Type resultTy = op.getType();
73  Type valueTy = value.getType();
74  Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
75  Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
76 
77  Value zero = getConstantValue(loc, valueTy, 0, rewriter);
78  Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
79  Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
80 
81  Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
82 
83  // Compute the multiplication in 64-bits then select the high / low parts.
84  Value value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
85  Value multiplier64 =
86  rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
87  Value multiply64 =
88  rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
89 
90  // Apply normal rounding.
91  Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32);
92  Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64);
93  round = rewriter.create<arith::ShRUIOp>(loc, round, one64);
94  multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
95 
96  // Apply double rounding if necessary.
97  if (op.getDoubleRound()) {
98  int64_t roundInt = 1 << 30;
99  Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
100  Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
101  Value positive = rewriter.create<arith::CmpIOp>(
102  loc, arith::CmpIPredicate::sge, value, zero);
103  Value dir =
104  rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown);
105  Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64);
106  Value valid = rewriter.create<arith::CmpIOp>(
107  loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
108  multiply64 =
109  rewriter.create<arith::SelectOp>(loc, valid, val, multiply64);
110  }
111 
112  Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64);
113  Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64);
114 
115  rewriter.replaceOp(op, result32);
116  return success();
117  }
118 };
119 
120 class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
121 public:
123 
124  LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
125  PatternRewriter &rewriter) const final {
126  Location loc = op.getLoc();
127 
128  Type resultTy = op.getType();
129  Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
130 
131  Value value = op.getValue();
132  if (getElementTypeOrSelf(value.getType()).getIntOrFloatBitWidth() > 32) {
133  return failure();
134  }
135 
136  Value value32 = op.getValue();
137  Value multiplier32 = op.getMultiplier();
138  Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
139 
140  // Constants used during the scaling operation.
141  Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
142  Value one32 = getConstantValue(loc, i32Ty, 1, rewriter);
143  Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
144  Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
145  Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
146 
147  // Compute the multiplication in 64-bits then select the high / low parts.
148  // Grab out the high/low of the computation
149  auto value64 =
150  rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32);
151  Value low32 = value64.getLow();
152  Value high32 = value64.getHigh();
153 
154  // Determine the direction and amount to shift the high bits.
155  Value shiftOver32 = rewriter.create<arith::CmpIOp>(
156  loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
157  Value roundHighBits = rewriter.create<arith::CmpIOp>(
158  loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
159 
160  Value shiftHighL =
161  rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32);
162  Value shiftHighR =
163  rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32);
164 
165  shiftHighL =
166  rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL);
167  shiftHighR =
168  rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
169 
170  // Conditionally perform our double round.
171  if (op.getDoubleRound()) {
172  Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
173  Value valuePositive = rewriter.create<arith::CmpIOp>(
174  loc, arith::CmpIPredicate::sge, value32, zero32);
175 
176  Value roundDir =
177  rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32);
178  roundDir =
179  rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32);
180 
181  Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32);
182  Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir);
183  Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32);
184 
185  Value shiftRound =
186  rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32);
187 
188  low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound);
189  high32 = rewriter.create<arith::AddIOp>(loc, high32, carry);
190  }
191 
192  // Conditionally apply rounding in the low bits.
193  {
194  Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32);
195  Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
196  roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32,
197  roundBit);
198 
199  Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit);
200  Value wasRounded = rewriter.create<arith::CmpIOp>(
201  loc, arith::CmpIPredicate::ugt, low32, newLow32);
202  low32 = newLow32;
203 
204  Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded);
205  high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32);
206  }
207 
208  // Conditionally apply rounding in the high bits.
209  {
210  Value shiftSubOne =
211  rewriter.create<arith::SubIOp>(loc, shiftHighR, one32);
212  Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
213  roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit,
214  zero32);
215  high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit);
216  }
217 
218  // Combine the correct high/low bits into the final rescale result.
219  high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL);
220  high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR);
221  low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32);
222  low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32);
223 
224  // Apply the rounding behavior and shift to the final alignment.
225  Value result = rewriter.create<arith::AddIOp>(loc, low32, high32);
226 
227  // Truncate if necessary.
228  if (!getElementTypeOrSelf(resultTy).isInteger(32)) {
229  result = rewriter.create<arith::TruncIOp>(loc, resultTy, result);
230  }
231 
232  rewriter.replaceOp(op, result);
233  return success();
234  }
235 };
236 
237 } // namespace
238 
240  RewritePatternSet *patterns) {
241  patterns->add<ConstOpConverter>(patterns->getContext());
242 }
243 
245  RewritePatternSet *patterns, bool include32Bit) {
246  patterns->add<ApplyScaleGenericOpConverter>(patterns->getContext(), 100);
247  if (include32Bit) {
248  patterns->add<ApplyScale32BitOpConverter>(patterns->getContext(), 200);
249  }
250 }
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:212
IntegerType getI64Type()
Definition: Builders.cpp:70
IntegerType getI32Type()
Definition: Builders.cpp:68
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:621
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
Definition: Types.cpp:49
U dyn_cast() const
Definition: Types.h:311
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:109
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns, bool include32Bit=false)
void populateTosaToArithConversionPatterns(RewritePatternSet *patterns)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357