28 LogicalResult matchAndRewrite(tosa::ConstOp op,
30 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValues());
36 if (
auto shapedTy = dyn_cast<ShapedType>(container))
37 return shapedTy.clone(element);
43 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
44 Type eTy = shapedTy.getElementType();
54 return arith::ConstantOp::create(rewriter, loc,
55 getConstantAttr(type, value, rewriter));
60 class ApplyScaleGenericOpConverter
65 LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
67 RoundingMode roundingMode = op.getRoundingMode();
68 if (roundingMode != RoundingMode::DOUBLE_ROUND &&
69 roundingMode != RoundingMode::SINGLE_ROUND) {
74 Value value = op.getValue();
75 Value multiplier32 = op.getMultiplier();
77 Type resultTy = op.getType();
82 Value zero = getConstantValue(loc, valueTy, 0, rewriter);
83 Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
84 Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
86 Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift());
89 Value value64 = value;
91 value64 = arith::ExtSIOp::create(rewriter, loc, i64Ty, value);
93 arith::ExtSIOp::create(rewriter, loc, i64Ty, multiplier32);
95 arith::MulIOp::create(rewriter, loc, value64, multiplier64);
98 Value shift64 = arith::ExtUIOp::create(rewriter, loc, i64Ty, shift32);
99 Value round = arith::ShLIOp::create(rewriter, loc, one64, shift64);
100 round = arith::ShRUIOp::create(rewriter, loc,
round, one64);
101 multiply64 = arith::AddIOp::create(rewriter, loc, multiply64,
round);
104 if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND) {
105 int64_t roundInt = 1 << 30;
106 Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
107 Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
108 Value positive = arith::CmpIOp::create(
109 rewriter, loc, arith::CmpIPredicate::sge, value, zero);
111 arith::SelectOp::create(rewriter, loc, positive, roundUp, roundDown);
112 Value val = arith::AddIOp::create(rewriter, loc, dir, multiply64);
113 Value valid = arith::CmpIOp::create(
114 rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
116 arith::SelectOp::create(rewriter, loc, valid, val, multiply64);
119 Value result64 = arith::ShRSIOp::create(rewriter, loc, multiply64, shift64);
120 Value result32 = arith::TruncIOp::create(rewriter, loc, i32Ty, result64);
127 class ApplyScale32BitOpConverter :
public OpRewritePattern<tosa::ApplyScaleOp> {
131 LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
133 RoundingMode roundingMode = op.getRoundingMode();
134 if (roundingMode != RoundingMode::DOUBLE_ROUND &&
135 roundingMode != RoundingMode::SINGLE_ROUND) {
141 Type resultTy = op.getType();
144 Value value = op.getValue();
149 Value value32 = op.getValue();
150 Value multiplier32 = op.getMultiplier();
151 Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift());
154 Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
155 Value one32 = getConstantValue(loc, i32Ty, 1, rewriter);
156 Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
157 Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
158 Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
163 arith::MulSIExtendedOp::create(rewriter, loc, value32, multiplier32);
164 Value low32 = value64.getLow();
165 Value high32 = value64.getHigh();
168 Value shiftOver32 = arith::CmpIOp::create(
169 rewriter, loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
170 Value roundHighBits = arith::CmpIOp::create(
171 rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
174 arith::SubIOp::create(rewriter, loc, thirtyTwo32, shift32);
176 arith::SubIOp::create(rewriter, loc, shift32, thirtyTwo32);
179 arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, shiftHighL);
181 arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32);
184 if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND) {
185 Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
186 Value valuePositive = arith::CmpIOp::create(
187 rewriter, loc, arith::CmpIPredicate::sge, value32, zero32);
189 Value roundDir = arith::SelectOp::create(rewriter, loc, valuePositive,
192 arith::SelectOp::create(rewriter, loc, shiftOver32, roundDir, zero32);
194 Value shiftLow = arith::ShRUIOp::create(rewriter, loc, low32, thirty32);
195 Value rounded = arith::AddIOp::create(rewriter, loc, shiftLow, roundDir);
196 Value carry = arith::ShRSIOp::create(rewriter, loc, rounded, two32);
199 arith::ShLIOp::create(rewriter, loc, roundDir, thirty32);
201 low32 = arith::AddIOp::create(rewriter, loc, low32, shiftRound);
202 high32 = arith::AddIOp::create(rewriter, loc, high32, carry);
207 Value shiftSubOne = arith::SubIOp::create(rewriter, loc, shift32, one32);
208 Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne);
209 roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, zero32,
212 Value newLow32 = arith::AddIOp::create(rewriter, loc, low32, roundBit);
213 Value wasRounded = arith::CmpIOp::create(
214 rewriter, loc, arith::CmpIPredicate::ugt, low32, newLow32);
218 arith::ExtUIOp::create(rewriter, loc, i32Ty, wasRounded);
219 high32 = arith::AddIOp::create(rewriter, loc, high32, rounded32);
225 arith::SubIOp::create(rewriter, loc, shiftHighR, one32);
226 Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne);
227 roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, roundBit,
229 high32 = arith::AddIOp::create(rewriter, loc, high32, roundBit);
233 high32 = arith::ShLIOp::create(rewriter, loc, high32, shiftHighL);
234 high32 = arith::ShRSIOp::create(rewriter, loc, high32, shiftHighR);
235 low32 = arith::ShRUIOp::create(rewriter, loc, low32, shift32);
236 low32 = arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, low32);
239 Value result = arith::AddIOp::create(rewriter, loc, low32, high32);
243 result = arith::TruncIOp::create(rewriter, loc, resultTy, result);
IntegerAttr getIntegerAttr(Type type, int64_t value)
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
DynamicAPInt round(const Fraction &f)
void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns, bool include32Bit=false)
void populateTosaToArithConversionPatterns(RewritePatternSet *patterns)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...