31 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValue());
37 if (
auto shapedTy = container.
dyn_cast<ShapedType>())
38 return shapedTy.clone(element);
44 if (
auto shapedTy = type.
dyn_cast<ShapedType>()) {
45 Type eTy = shapedTy.getElementType();
55 return rewriter.
create<arith::ConstantOp>(
56 loc, getConstantAttr(type, value, rewriter));
61 class ApplyScaleGenericOpConverter
69 Value value = op.getValue();
70 Value multiplier32 = op.getMultiplier();
72 Type resultTy = op.getType();
77 Value zero = getConstantValue(loc, valueTy, 0, rewriter);
78 Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
79 Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
81 Value shift32 = rewriter.
create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
84 Value value64 = rewriter.
create<arith::ExtSIOp>(loc, i64Ty, value);
86 rewriter.
create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
88 rewriter.
create<arith::MulIOp>(loc, value64, multiplier64);
91 Value shift64 = rewriter.
create<arith::ExtUIOp>(loc, i64Ty, shift32);
92 Value round = rewriter.
create<arith::ShLIOp>(loc, one64, shift64);
93 round = rewriter.
create<arith::ShRUIOp>(loc, round, one64);
94 multiply64 = rewriter.
create<arith::AddIOp>(loc, multiply64, round);
97 if (op.getDoubleRound()) {
98 int64_t roundInt = 1 << 30;
99 Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
100 Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
102 loc, arith::CmpIPredicate::sge, value, zero);
104 rewriter.
create<arith::SelectOp>(loc, positive, roundUp, roundDown);
105 Value val = rewriter.
create<arith::AddIOp>(loc, dir, multiply64);
107 loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
109 rewriter.
create<arith::SelectOp>(loc, valid, val, multiply64);
112 Value result64 = rewriter.
create<arith::ShRSIOp>(loc, multiply64, shift64);
113 Value result32 = rewriter.
create<arith::TruncIOp>(loc, i32Ty, result64);
120 class ApplyScale32BitOpConverter :
public OpRewritePattern<tosa::ApplyScaleOp> {
128 Type resultTy = op.getType();
131 Value value = op.getValue();
136 Value value32 = op.getValue();
137 Value multiplier32 = op.getMultiplier();
138 Value shift32 = rewriter.
create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
141 Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
142 Value one32 = getConstantValue(loc, i32Ty, 1, rewriter);
143 Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
144 Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
145 Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
150 rewriter.
create<arith::MulSIExtendedOp>(loc, value32, multiplier32);
151 Value low32 = value64.getLow();
152 Value high32 = value64.getHigh();
155 Value shiftOver32 = rewriter.
create<arith::CmpIOp>(
156 loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
157 Value roundHighBits = rewriter.
create<arith::CmpIOp>(
158 loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
161 rewriter.
create<arith::SubIOp>(loc, thirtyTwo32, shift32);
163 rewriter.
create<arith::SubIOp>(loc, shift32, thirtyTwo32);
166 rewriter.
create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL);
168 rewriter.
create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
171 if (op.getDoubleRound()) {
172 Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
173 Value valuePositive = rewriter.
create<arith::CmpIOp>(
174 loc, arith::CmpIPredicate::sge, value32, zero32);
177 rewriter.
create<arith::SelectOp>(loc, valuePositive, one32, negOne32);
179 rewriter.
create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32);
181 Value shiftLow = rewriter.
create<arith::ShRUIOp>(loc, low32, thirty32);
182 Value rounded = rewriter.
create<arith::AddIOp>(loc, shiftLow, roundDir);
183 Value carry = rewriter.
create<arith::ShRSIOp>(loc, rounded, two32);
186 rewriter.
create<arith::ShLIOp>(loc, roundDir, thirty32);
188 low32 = rewriter.
create<arith::AddIOp>(loc, low32, shiftRound);
189 high32 = rewriter.
create<arith::AddIOp>(loc, high32, carry);
194 Value shiftSubOne = rewriter.
create<arith::SubIOp>(loc, shift32, one32);
195 Value roundBit = rewriter.
create<arith::ShLIOp>(loc, one32, shiftSubOne);
196 roundBit = rewriter.
create<arith::SelectOp>(loc, roundHighBits, zero32,
199 Value newLow32 = rewriter.
create<arith::AddIOp>(loc, low32, roundBit);
201 loc, arith::CmpIPredicate::ugt, low32, newLow32);
204 Value rounded32 = rewriter.
create<arith::ExtUIOp>(loc, i32Ty, wasRounded);
205 high32 = rewriter.
create<arith::AddIOp>(loc, high32, rounded32);
211 rewriter.
create<arith::SubIOp>(loc, shiftHighR, one32);
212 Value roundBit = rewriter.
create<arith::ShLIOp>(loc, one32, shiftSubOne);
213 roundBit = rewriter.
create<arith::SelectOp>(loc, roundHighBits, roundBit,
215 high32 = rewriter.
create<arith::AddIOp>(loc, high32, roundBit);
219 high32 = rewriter.
create<arith::ShLIOp>(loc, high32, shiftHighL);
220 high32 = rewriter.
create<arith::ShRSIOp>(loc, high32, shiftHighR);
221 low32 = rewriter.
create<arith::ShRUIOp>(loc, low32, shift32);
222 low32 = rewriter.
create<arith::SelectOp>(loc, shiftOver32, zero32, low32);
225 Value result = rewriter.
create<arith::AddIOp>(loc, low32, high32);
229 result = rewriter.
create<arith::TruncIOp>(loc, resultTy, result);
246 patterns->
add<ApplyScaleGenericOpConverter>(patterns->
getContext(), 100);
248 patterns->
add<ApplyScale32BitOpConverter>(patterns->
getContext(), 200);
Attributes are known-constant values of operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns, bool include32Bit=false)
void populateTosaToArithConversionPatterns(RewritePatternSet *patterns)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...