29 LogicalResult matchAndRewrite(tosa::ConstOp op,
31 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValue());
37 if (
auto shapedTy = dyn_cast<ShapedType>(container))
38 return shapedTy.clone(element);
44 if (
auto shapedTy = dyn_cast<ShapedType>(type)) {
45 Type eTy = shapedTy.getElementType();
55 return rewriter.
create<arith::ConstantOp>(
56 loc, getConstantAttr(type, value, rewriter));
61 class ApplyScaleGenericOpConverter
66 LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
69 Value value = op.getValue();
70 Value multiplier32 = op.getMultiplier();
72 Type resultTy = op.getType();
77 Value zero = getConstantValue(loc, valueTy, 0, rewriter);
78 Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
79 Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
81 Value shift32 = rewriter.
create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
84 Value value64 = value;
86 value64 = rewriter.
create<arith::ExtSIOp>(loc, i64Ty, value);
88 rewriter.
create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
90 rewriter.
create<arith::MulIOp>(loc, value64, multiplier64);
93 Value shift64 = rewriter.
create<arith::ExtUIOp>(loc, i64Ty, shift32);
96 multiply64 = rewriter.
create<arith::AddIOp>(loc, multiply64,
round);
99 if (op.getDoubleRound()) {
100 int64_t roundInt = 1 << 30;
101 Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
102 Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
104 loc, arith::CmpIPredicate::sge, value, zero);
106 rewriter.
create<arith::SelectOp>(loc, positive, roundUp, roundDown);
107 Value val = rewriter.
create<arith::AddIOp>(loc, dir, multiply64);
109 loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
111 rewriter.
create<arith::SelectOp>(loc, valid, val, multiply64);
114 Value result64 = rewriter.
create<arith::ShRSIOp>(loc, multiply64, shift64);
115 Value result32 = rewriter.
create<arith::TruncIOp>(loc, i32Ty, result64);
122 class ApplyScale32BitOpConverter :
public OpRewritePattern<tosa::ApplyScaleOp> {
126 LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
130 Type resultTy = op.getType();
133 Value value = op.getValue();
138 Value value32 = op.getValue();
139 Value multiplier32 = op.getMultiplier();
140 Value shift32 = rewriter.
create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
143 Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
144 Value one32 = getConstantValue(loc, i32Ty, 1, rewriter);
145 Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
146 Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
147 Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
152 rewriter.
create<arith::MulSIExtendedOp>(loc, value32, multiplier32);
153 Value low32 = value64.getLow();
154 Value high32 = value64.getHigh();
157 Value shiftOver32 = rewriter.
create<arith::CmpIOp>(
158 loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
159 Value roundHighBits = rewriter.
create<arith::CmpIOp>(
160 loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
163 rewriter.
create<arith::SubIOp>(loc, thirtyTwo32, shift32);
165 rewriter.
create<arith::SubIOp>(loc, shift32, thirtyTwo32);
168 rewriter.
create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL);
170 rewriter.
create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
173 if (op.getDoubleRound()) {
174 Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
175 Value valuePositive = rewriter.
create<arith::CmpIOp>(
176 loc, arith::CmpIPredicate::sge, value32, zero32);
179 rewriter.
create<arith::SelectOp>(loc, valuePositive, one32, negOne32);
181 rewriter.
create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32);
183 Value shiftLow = rewriter.
create<arith::ShRUIOp>(loc, low32, thirty32);
184 Value rounded = rewriter.
create<arith::AddIOp>(loc, shiftLow, roundDir);
185 Value carry = rewriter.
create<arith::ShRSIOp>(loc, rounded, two32);
188 rewriter.
create<arith::ShLIOp>(loc, roundDir, thirty32);
190 low32 = rewriter.
create<arith::AddIOp>(loc, low32, shiftRound);
191 high32 = rewriter.
create<arith::AddIOp>(loc, high32, carry);
196 Value shiftSubOne = rewriter.
create<arith::SubIOp>(loc, shift32, one32);
197 Value roundBit = rewriter.
create<arith::ShLIOp>(loc, one32, shiftSubOne);
198 roundBit = rewriter.
create<arith::SelectOp>(loc, roundHighBits, zero32,
201 Value newLow32 = rewriter.
create<arith::AddIOp>(loc, low32, roundBit);
203 loc, arith::CmpIPredicate::ugt, low32, newLow32);
206 Value rounded32 = rewriter.
create<arith::ExtUIOp>(loc, i32Ty, wasRounded);
207 high32 = rewriter.
create<arith::AddIOp>(loc, high32, rounded32);
213 rewriter.
create<arith::SubIOp>(loc, shiftHighR, one32);
214 Value roundBit = rewriter.
create<arith::ShLIOp>(loc, one32, shiftSubOne);
215 roundBit = rewriter.
create<arith::SelectOp>(loc, roundHighBits, roundBit,
217 high32 = rewriter.
create<arith::AddIOp>(loc, high32, roundBit);
221 high32 = rewriter.
create<arith::ShLIOp>(loc, high32, shiftHighL);
222 high32 = rewriter.
create<arith::ShRSIOp>(loc, high32, shiftHighR);
223 low32 = rewriter.
create<arith::ShRUIOp>(loc, low32, shift32);
224 low32 = rewriter.
create<arith::SelectOp>(loc, shiftOver32, zero32, low32);
227 Value result = rewriter.
create<arith::AddIOp>(loc, low32, high32);
231 result = rewriter.
create<arith::TruncIOp>(loc, resultTy, result);
248 patterns->
add<ApplyScaleGenericOpConverter>(patterns->
getContext(), 100);
250 patterns->
add<ApplyScale32BitOpConverter>(patterns->
getContext(), 200);
IntegerAttr getIntegerAttr(Type type, int64_t value)
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
DynamicAPInt round(const Fraction &f)
void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns, bool include32Bit=false)
void populateTosaToArithConversionPatterns(RewritePatternSet *patterns)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...