MLIR  22.0.0git
TosaToArith.cpp
Go to the documentation of this file.
1 //===- TosaToArith.cpp - Lowering Tosa to Arith Dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Arith dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 
19 using namespace mlir;
20 using namespace tosa;
21 
22 namespace {
23 
24 class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
25 public:
27 
28  LogicalResult matchAndRewrite(tosa::ConstOp op,
29  PatternRewriter &rewriter) const final {
30  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValues());
31  return success();
32  }
33 };
34 
35 Type matchContainerType(Type element, Type container) {
36  if (auto shapedTy = dyn_cast<ShapedType>(container))
37  return shapedTy.clone(element);
38 
39  return element;
40 }
41 
42 TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
43  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
44  Type eTy = shapedTy.getElementType();
45  APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
46  return DenseIntElementsAttr::get(shapedTy, valueInt);
47  }
48 
49  return rewriter.getIntegerAttr(type, value);
50 }
51 
52 Value getConstantValue(Location loc, Type type, int64_t value,
53  PatternRewriter &rewriter) {
54  return rewriter.create<arith::ConstantOp>(
55  loc, getConstantAttr(type, value, rewriter));
56 }
57 
58 // This converts the TOSA ApplyScale operator to a set of arithmetic ops,
59 // using 64-bit operations to perform the necessary multiply, bias, and shift.
60 class ApplyScaleGenericOpConverter
61  : public OpRewritePattern<tosa::ApplyScaleOp> {
62 public:
64 
65  LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
66  PatternRewriter &rewriter) const final {
67  StringRef roundingMode = op.getRoundingMode();
68  if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
69  return failure();
70  }
71 
72  Location loc = op.getLoc();
73  Value value = op.getValue();
74  Value multiplier32 = op.getMultiplier();
75 
76  Type resultTy = op.getType();
77  Type valueTy = value.getType();
78  Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
79  Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
80 
81  Value zero = getConstantValue(loc, valueTy, 0, rewriter);
82  Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
83  Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
84 
85  Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
86 
87  // Compute the multiplication in 64-bits then select the high / low parts.
88  Value value64 = value;
89  if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type())
90  value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
91  Value multiplier64 =
92  rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
93  Value multiply64 =
94  rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
95 
96  // Apply normal rounding.
97  Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32);
98  Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64);
99  round = rewriter.create<arith::ShRUIOp>(loc, round, one64);
100  multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
101 
102  // Apply double rounding if necessary.
103  if (op.getRoundingMode() == "DOUBLE_ROUND") {
104  int64_t roundInt = 1 << 30;
105  Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
106  Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
107  Value positive = rewriter.create<arith::CmpIOp>(
108  loc, arith::CmpIPredicate::sge, value, zero);
109  Value dir =
110  rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown);
111  Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64);
112  Value valid = rewriter.create<arith::CmpIOp>(
113  loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
114  multiply64 =
115  rewriter.create<arith::SelectOp>(loc, valid, val, multiply64);
116  }
117 
118  Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64);
119  Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64);
120 
121  rewriter.replaceOp(op, result32);
122  return success();
123  }
124 };
125 
126 class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
127 public:
129 
130  LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
131  PatternRewriter &rewriter) const final {
132  StringRef roundingMode = op.getRoundingMode();
133  if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
134  return failure();
135  }
136 
137  Location loc = op.getLoc();
138 
139  Type resultTy = op.getType();
140  Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
141 
142  Value value = op.getValue();
143  if (getElementTypeOrSelf(value.getType()).getIntOrFloatBitWidth() > 32) {
144  return failure();
145  }
146 
147  Value value32 = op.getValue();
148  Value multiplier32 = op.getMultiplier();
149  Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
150 
151  // Constants used during the scaling operation.
152  Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
153  Value one32 = getConstantValue(loc, i32Ty, 1, rewriter);
154  Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
155  Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
156  Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
157 
158  // Compute the multiplication in 64-bits then select the high / low parts.
159  // Grab out the high/low of the computation
160  auto value64 =
161  rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32);
162  Value low32 = value64.getLow();
163  Value high32 = value64.getHigh();
164 
165  // Determine the direction and amount to shift the high bits.
166  Value shiftOver32 = rewriter.create<arith::CmpIOp>(
167  loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
168  Value roundHighBits = rewriter.create<arith::CmpIOp>(
169  loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
170 
171  Value shiftHighL =
172  rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32);
173  Value shiftHighR =
174  rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32);
175 
176  shiftHighL =
177  rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL);
178  shiftHighR =
179  rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
180 
181  // Conditionally perform our double round.
182  if (op.getRoundingMode() == "DOUBLE_ROUND") {
183  Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
184  Value valuePositive = rewriter.create<arith::CmpIOp>(
185  loc, arith::CmpIPredicate::sge, value32, zero32);
186 
187  Value roundDir =
188  rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32);
189  roundDir =
190  rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32);
191 
192  Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32);
193  Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir);
194  Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32);
195 
196  Value shiftRound =
197  rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32);
198 
199  low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound);
200  high32 = rewriter.create<arith::AddIOp>(loc, high32, carry);
201  }
202 
203  // Conditionally apply rounding in the low bits.
204  {
205  Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32);
206  Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
207  roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32,
208  roundBit);
209 
210  Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit);
211  Value wasRounded = rewriter.create<arith::CmpIOp>(
212  loc, arith::CmpIPredicate::ugt, low32, newLow32);
213  low32 = newLow32;
214 
215  Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded);
216  high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32);
217  }
218 
219  // Conditionally apply rounding in the high bits.
220  {
221  Value shiftSubOne =
222  rewriter.create<arith::SubIOp>(loc, shiftHighR, one32);
223  Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
224  roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit,
225  zero32);
226  high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit);
227  }
228 
229  // Combine the correct high/low bits into the final rescale result.
230  high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL);
231  high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR);
232  low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32);
233  low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32);
234 
235  // Apply the rounding behavior and shift to the final alignment.
236  Value result = rewriter.create<arith::AddIOp>(loc, low32, high32);
237 
238  // Truncate if necessary.
239  if (!getElementTypeOrSelf(resultTy).isInteger(32)) {
240  result = rewriter.create<arith::TruncIOp>(loc, resultTy, result);
241  }
242 
243  rewriter.replaceOp(op, result);
244  return success();
245  }
246 };
247 
248 } // namespace
249 
252  patterns->add<ConstOpConverter>(patterns->getContext());
253 }
254 
256  RewritePatternSet *patterns, bool include32Bit) {
257  patterns->add<ApplyScaleGenericOpConverter>(patterns->getContext(), 100);
258  if (include32Bit) {
259  patterns->add<ApplyScale32BitOpConverter>(patterns->getContext(), 200);
260  }
261 }
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerType getI32Type()
Definition: Builders.cpp:62
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:767
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns, bool include32Bit=false)
void populateTosaToArithConversionPatterns(RewritePatternSet *patterns)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314