MLIR  21.0.0git
TosaToArith.cpp
Go to the documentation of this file.
1 //===- TosaToArith.cpp - Lowering Tosa to Arith Dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Arith dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
19 
20 using namespace mlir;
21 using namespace tosa;
22 
23 namespace {
24 
25 class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
26 public:
28 
29  LogicalResult matchAndRewrite(tosa::ConstOp op,
30  PatternRewriter &rewriter) const final {
31  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValues());
32  return success();
33  }
34 };
35 
36 Type matchContainerType(Type element, Type container) {
37  if (auto shapedTy = dyn_cast<ShapedType>(container))
38  return shapedTy.clone(element);
39 
40  return element;
41 }
42 
43 TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
44  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
45  Type eTy = shapedTy.getElementType();
46  APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
47  return DenseIntElementsAttr::get(shapedTy, valueInt);
48  }
49 
50  return rewriter.getIntegerAttr(type, value);
51 }
52 
53 Value getConstantValue(Location loc, Type type, int64_t value,
54  PatternRewriter &rewriter) {
55  return rewriter.create<arith::ConstantOp>(
56  loc, getConstantAttr(type, value, rewriter));
57 }
58 
59 // This converts the TOSA ApplyScale operator to a set of arithmetic ops,
60 // using 64-bit operations to perform the necessary multiply, bias, and shift.
61 class ApplyScaleGenericOpConverter
62  : public OpRewritePattern<tosa::ApplyScaleOp> {
63 public:
65 
66  LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
67  PatternRewriter &rewriter) const final {
68  StringRef roundingMode = op.getRoundingMode();
69  if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
70  return failure();
71  }
72 
73  Location loc = op.getLoc();
74  Value value = op.getValue();
75  Value multiplier32 = op.getMultiplier();
76 
77  Type resultTy = op.getType();
78  Type valueTy = value.getType();
79  Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
80  Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
81 
82  Value zero = getConstantValue(loc, valueTy, 0, rewriter);
83  Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
84  Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
85 
86  Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
87 
88  // Compute the multiplication in 64-bits then select the high / low parts.
89  Value value64 = value;
90  if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type())
91  value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
92  Value multiplier64 =
93  rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
94  Value multiply64 =
95  rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
96 
97  // Apply normal rounding.
98  Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32);
99  Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64);
100  round = rewriter.create<arith::ShRUIOp>(loc, round, one64);
101  multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
102 
103  // Apply double rounding if necessary.
104  if (op.getRoundingMode() == "DOUBLE_ROUND") {
105  int64_t roundInt = 1 << 30;
106  Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
107  Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
108  Value positive = rewriter.create<arith::CmpIOp>(
109  loc, arith::CmpIPredicate::sge, value, zero);
110  Value dir =
111  rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown);
112  Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64);
113  Value valid = rewriter.create<arith::CmpIOp>(
114  loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
115  multiply64 =
116  rewriter.create<arith::SelectOp>(loc, valid, val, multiply64);
117  }
118 
119  Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64);
120  Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64);
121 
122  rewriter.replaceOp(op, result32);
123  return success();
124  }
125 };
126 
127 class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
128 public:
130 
131  LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
132  PatternRewriter &rewriter) const final {
133  StringRef roundingMode = op.getRoundingMode();
134  if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
135  return failure();
136  }
137 
138  Location loc = op.getLoc();
139 
140  Type resultTy = op.getType();
141  Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
142 
143  Value value = op.getValue();
144  if (getElementTypeOrSelf(value.getType()).getIntOrFloatBitWidth() > 32) {
145  return failure();
146  }
147 
148  Value value32 = op.getValue();
149  Value multiplier32 = op.getMultiplier();
150  Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
151 
152  // Constants used during the scaling operation.
153  Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
154  Value one32 = getConstantValue(loc, i32Ty, 1, rewriter);
155  Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
156  Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
157  Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
158 
159  // Compute the multiplication in 64-bits then select the high / low parts.
160  // Grab out the high/low of the computation
161  auto value64 =
162  rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32);
163  Value low32 = value64.getLow();
164  Value high32 = value64.getHigh();
165 
166  // Determine the direction and amount to shift the high bits.
167  Value shiftOver32 = rewriter.create<arith::CmpIOp>(
168  loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
169  Value roundHighBits = rewriter.create<arith::CmpIOp>(
170  loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
171 
172  Value shiftHighL =
173  rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32);
174  Value shiftHighR =
175  rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32);
176 
177  shiftHighL =
178  rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL);
179  shiftHighR =
180  rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
181 
182  // Conditionally perform our double round.
183  if (op.getRoundingMode() == "DOUBLE_ROUND") {
184  Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
185  Value valuePositive = rewriter.create<arith::CmpIOp>(
186  loc, arith::CmpIPredicate::sge, value32, zero32);
187 
188  Value roundDir =
189  rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32);
190  roundDir =
191  rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32);
192 
193  Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32);
194  Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir);
195  Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32);
196 
197  Value shiftRound =
198  rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32);
199 
200  low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound);
201  high32 = rewriter.create<arith::AddIOp>(loc, high32, carry);
202  }
203 
204  // Conditionally apply rounding in the low bits.
205  {
206  Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32);
207  Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
208  roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32,
209  roundBit);
210 
211  Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit);
212  Value wasRounded = rewriter.create<arith::CmpIOp>(
213  loc, arith::CmpIPredicate::ugt, low32, newLow32);
214  low32 = newLow32;
215 
216  Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded);
217  high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32);
218  }
219 
220  // Conditionally apply rounding in the high bits.
221  {
222  Value shiftSubOne =
223  rewriter.create<arith::SubIOp>(loc, shiftHighR, one32);
224  Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
225  roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit,
226  zero32);
227  high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit);
228  }
229 
230  // Combine the correct high/low bits into the final rescale result.
231  high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL);
232  high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR);
233  low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32);
234  low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32);
235 
236  // Apply the rounding behavior and shift to the final alignment.
237  Value result = rewriter.create<arith::AddIOp>(loc, low32, high32);
238 
239  // Truncate if necessary.
240  if (!getElementTypeOrSelf(resultTy).isInteger(32)) {
241  result = rewriter.create<arith::TruncIOp>(loc, resultTy, result);
242  }
243 
244  rewriter.replaceOp(op, result);
245  return success();
246  }
247 };
248 
249 } // namespace
250 
253  patterns->add<ConstOpConverter>(patterns->getContext());
254 }
255 
257  RewritePatternSet *patterns, bool include32Bit) {
258  patterns->add<ApplyScaleGenericOpConverter>(patterns->getContext(), 100);
259  if (include32Bit) {
260  patterns->add<ApplyScale32BitOpConverter>(patterns->getContext(), 200);
261  }
262 }
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getI32Type()
Definition: Builders.cpp:63
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns, bool include32Bit=false)
void populateTosaToArithConversionPatterns(RewritePatternSet *patterns)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358