28 LogicalResult matchAndRewrite(tosa::ConstOp op,
30 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValues());
36 if (
auto shapedTy = dyn_cast<ShapedType>(container))
37 return shapedTy.clone(element);
43 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
44 Type eTy = shapedTy.getElementType();
54 return rewriter.
create<arith::ConstantOp>(
55 loc, getConstantAttr(type, value, rewriter));
60 class ApplyScaleGenericOpConverter
65 LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
67 StringRef roundingMode = op.getRoundingMode();
68 if (roundingMode !=
"DOUBLE_ROUND" && roundingMode !=
"SINGLE_ROUND") {
73 Value value = op.getValue();
74 Value multiplier32 = op.getMultiplier();
76 Type resultTy = op.getType();
81 Value zero = getConstantValue(loc, valueTy, 0, rewriter);
82 Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
83 Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
85 Value shift32 = rewriter.
create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
88 Value value64 = value;
90 value64 = rewriter.
create<arith::ExtSIOp>(loc, i64Ty, value);
92 rewriter.
create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
94 rewriter.
create<arith::MulIOp>(loc, value64, multiplier64);
97 Value shift64 = rewriter.
create<arith::ExtUIOp>(loc, i64Ty, shift32);
100 multiply64 = rewriter.
create<arith::AddIOp>(loc, multiply64,
round);
103 if (op.getRoundingMode() ==
"DOUBLE_ROUND") {
104 int64_t roundInt = 1 << 30;
105 Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
106 Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
108 loc, arith::CmpIPredicate::sge, value, zero);
110 rewriter.
create<arith::SelectOp>(loc, positive, roundUp, roundDown);
111 Value val = rewriter.
create<arith::AddIOp>(loc, dir, multiply64);
113 loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
115 rewriter.
create<arith::SelectOp>(loc, valid, val, multiply64);
118 Value result64 = rewriter.
create<arith::ShRSIOp>(loc, multiply64, shift64);
119 Value result32 = rewriter.
create<arith::TruncIOp>(loc, i32Ty, result64);
126 class ApplyScale32BitOpConverter :
public OpRewritePattern<tosa::ApplyScaleOp> {
130 LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
132 StringRef roundingMode = op.getRoundingMode();
133 if (roundingMode !=
"DOUBLE_ROUND" && roundingMode !=
"SINGLE_ROUND") {
139 Type resultTy = op.getType();
142 Value value = op.getValue();
147 Value value32 = op.getValue();
148 Value multiplier32 = op.getMultiplier();
149 Value shift32 = rewriter.
create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
152 Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
153 Value one32 = getConstantValue(loc, i32Ty, 1, rewriter);
154 Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
155 Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
156 Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
161 rewriter.
create<arith::MulSIExtendedOp>(loc, value32, multiplier32);
162 Value low32 = value64.getLow();
163 Value high32 = value64.getHigh();
166 Value shiftOver32 = rewriter.
create<arith::CmpIOp>(
167 loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
168 Value roundHighBits = rewriter.
create<arith::CmpIOp>(
169 loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
172 rewriter.
create<arith::SubIOp>(loc, thirtyTwo32, shift32);
174 rewriter.
create<arith::SubIOp>(loc, shift32, thirtyTwo32);
177 rewriter.
create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL);
179 rewriter.
create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
182 if (op.getRoundingMode() ==
"DOUBLE_ROUND") {
183 Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
184 Value valuePositive = rewriter.
create<arith::CmpIOp>(
185 loc, arith::CmpIPredicate::sge, value32, zero32);
188 rewriter.
create<arith::SelectOp>(loc, valuePositive, one32, negOne32);
190 rewriter.
create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32);
192 Value shiftLow = rewriter.
create<arith::ShRUIOp>(loc, low32, thirty32);
193 Value rounded = rewriter.
create<arith::AddIOp>(loc, shiftLow, roundDir);
194 Value carry = rewriter.
create<arith::ShRSIOp>(loc, rounded, two32);
197 rewriter.
create<arith::ShLIOp>(loc, roundDir, thirty32);
199 low32 = rewriter.
create<arith::AddIOp>(loc, low32, shiftRound);
200 high32 = rewriter.
create<arith::AddIOp>(loc, high32, carry);
205 Value shiftSubOne = rewriter.
create<arith::SubIOp>(loc, shift32, one32);
206 Value roundBit = rewriter.
create<arith::ShLIOp>(loc, one32, shiftSubOne);
207 roundBit = rewriter.
create<arith::SelectOp>(loc, roundHighBits, zero32,
210 Value newLow32 = rewriter.
create<arith::AddIOp>(loc, low32, roundBit);
212 loc, arith::CmpIPredicate::ugt, low32, newLow32);
215 Value rounded32 = rewriter.
create<arith::ExtUIOp>(loc, i32Ty, wasRounded);
216 high32 = rewriter.
create<arith::AddIOp>(loc, high32, rounded32);
222 rewriter.
create<arith::SubIOp>(loc, shiftHighR, one32);
223 Value roundBit = rewriter.
create<arith::ShLIOp>(loc, one32, shiftSubOne);
224 roundBit = rewriter.
create<arith::SelectOp>(loc, roundHighBits, roundBit,
226 high32 = rewriter.
create<arith::AddIOp>(loc, high32, roundBit);
230 high32 = rewriter.
create<arith::ShLIOp>(loc, high32, shiftHighL);
231 high32 = rewriter.
create<arith::ShRSIOp>(loc, high32, shiftHighR);
232 low32 = rewriter.
create<arith::ShRUIOp>(loc, low32, shift32);
233 low32 = rewriter.
create<arith::SelectOp>(loc, shiftOver32, zero32, low32);
236 Value result = rewriter.
create<arith::AddIOp>(loc, low32, high32);
240 result = rewriter.
create<arith::TruncIOp>(loc, resultTy, result);
IntegerAttr getIntegerAttr(Type type, int64_t value)
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
DynamicAPInt round(const Fraction &f)
void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns, bool include32Bit=false)
void populateTosaToArithConversionPatterns(RewritePatternSet *patterns)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...