29 LogicalResult matchAndRewrite(tosa::ConstOp op,
31 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValues());
37 if (
auto shapedTy = dyn_cast<ShapedType>(container))
38 return shapedTy.clone(element);
44 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
45 Type eTy = shapedTy.getElementType();
55 return rewriter.
create<arith::ConstantOp>(
56 loc, getConstantAttr(type, value, rewriter));
61 class ApplyScaleGenericOpConverter
66 LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
68 StringRef roundingMode = op.getRoundingMode();
69 if (roundingMode !=
"DOUBLE_ROUND" && roundingMode !=
"SINGLE_ROUND") {
74 Value value = op.getValue();
75 Value multiplier32 = op.getMultiplier();
77 Type resultTy = op.getType();
82 Value zero = getConstantValue(loc, valueTy, 0, rewriter);
83 Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
84 Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
86 Value shift32 = rewriter.
create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
89 Value value64 = value;
91 value64 = rewriter.
create<arith::ExtSIOp>(loc, i64Ty, value);
93 rewriter.
create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
95 rewriter.
create<arith::MulIOp>(loc, value64, multiplier64);
98 Value shift64 = rewriter.
create<arith::ExtUIOp>(loc, i64Ty, shift32);
101 multiply64 = rewriter.
create<arith::AddIOp>(loc, multiply64,
round);
104 if (op.getRoundingMode() ==
"DOUBLE_ROUND") {
105 int64_t roundInt = 1 << 30;
106 Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
107 Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
109 loc, arith::CmpIPredicate::sge, value, zero);
111 rewriter.
create<arith::SelectOp>(loc, positive, roundUp, roundDown);
112 Value val = rewriter.
create<arith::AddIOp>(loc, dir, multiply64);
114 loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
116 rewriter.
create<arith::SelectOp>(loc, valid, val, multiply64);
119 Value result64 = rewriter.
create<arith::ShRSIOp>(loc, multiply64, shift64);
120 Value result32 = rewriter.
create<arith::TruncIOp>(loc, i32Ty, result64);
127 class ApplyScale32BitOpConverter :
public OpRewritePattern<tosa::ApplyScaleOp> {
131 LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
133 StringRef roundingMode = op.getRoundingMode();
134 if (roundingMode !=
"DOUBLE_ROUND" && roundingMode !=
"SINGLE_ROUND") {
140 Type resultTy = op.getType();
143 Value value = op.getValue();
148 Value value32 = op.getValue();
149 Value multiplier32 = op.getMultiplier();
150 Value shift32 = rewriter.
create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
153 Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
154 Value one32 = getConstantValue(loc, i32Ty, 1, rewriter);
155 Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
156 Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
157 Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
162 rewriter.
create<arith::MulSIExtendedOp>(loc, value32, multiplier32);
163 Value low32 = value64.getLow();
164 Value high32 = value64.getHigh();
167 Value shiftOver32 = rewriter.
create<arith::CmpIOp>(
168 loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
169 Value roundHighBits = rewriter.
create<arith::CmpIOp>(
170 loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
173 rewriter.
create<arith::SubIOp>(loc, thirtyTwo32, shift32);
175 rewriter.
create<arith::SubIOp>(loc, shift32, thirtyTwo32);
178 rewriter.
create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL);
180 rewriter.
create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
183 if (op.getRoundingMode() ==
"DOUBLE_ROUND") {
184 Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
185 Value valuePositive = rewriter.
create<arith::CmpIOp>(
186 loc, arith::CmpIPredicate::sge, value32, zero32);
189 rewriter.
create<arith::SelectOp>(loc, valuePositive, one32, negOne32);
191 rewriter.
create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32);
193 Value shiftLow = rewriter.
create<arith::ShRUIOp>(loc, low32, thirty32);
194 Value rounded = rewriter.
create<arith::AddIOp>(loc, shiftLow, roundDir);
195 Value carry = rewriter.
create<arith::ShRSIOp>(loc, rounded, two32);
198 rewriter.
create<arith::ShLIOp>(loc, roundDir, thirty32);
200 low32 = rewriter.
create<arith::AddIOp>(loc, low32, shiftRound);
201 high32 = rewriter.
create<arith::AddIOp>(loc, high32, carry);
206 Value shiftSubOne = rewriter.
create<arith::SubIOp>(loc, shift32, one32);
207 Value roundBit = rewriter.
create<arith::ShLIOp>(loc, one32, shiftSubOne);
208 roundBit = rewriter.
create<arith::SelectOp>(loc, roundHighBits, zero32,
211 Value newLow32 = rewriter.
create<arith::AddIOp>(loc, low32, roundBit);
213 loc, arith::CmpIPredicate::ugt, low32, newLow32);
216 Value rounded32 = rewriter.
create<arith::ExtUIOp>(loc, i32Ty, wasRounded);
217 high32 = rewriter.
create<arith::AddIOp>(loc, high32, rounded32);
223 rewriter.
create<arith::SubIOp>(loc, shiftHighR, one32);
224 Value roundBit = rewriter.
create<arith::ShLIOp>(loc, one32, shiftSubOne);
225 roundBit = rewriter.
create<arith::SelectOp>(loc, roundHighBits, roundBit,
227 high32 = rewriter.
create<arith::AddIOp>(loc, high32, roundBit);
231 high32 = rewriter.
create<arith::ShLIOp>(loc, high32, shiftHighL);
232 high32 = rewriter.
create<arith::ShRSIOp>(loc, high32, shiftHighR);
233 low32 = rewriter.
create<arith::ShRUIOp>(loc, low32, shift32);
234 low32 = rewriter.
create<arith::SelectOp>(loc, shiftOver32, zero32, low32);
237 Value result = rewriter.
create<arith::AddIOp>(loc, low32, high32);
241 result = rewriter.
create<arith::TruncIOp>(loc, resultTy, result);
IntegerAttr getIntegerAttr(Type type, int64_t value)
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
DynamicAPInt round(const Fraction &f)
void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns, bool include32Bit=false)
void populateTosaToArithConversionPatterns(RewritePatternSet *patterns)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.