26 const double mantissa =
std::frexp(scale, &shift);
27 auto shiftedM = std::round(mantissa * (int64_t(1) << 15));
30 assert(shiftedM <= (int64_t(1) << 15) &&
31 "Shifted mantissa exceeds 16 signed bits");
33 if (shiftedM == (int64_t(1) << 15)) {
40 shift = (-shift) + 15;
43 "Shifted mantissa exceeds 32-bit signed output type");
45 multiplier =
static_cast<int32_t
>(shiftedM);
50 multiplier = multiplier >> std::min<int32_t>(31, shift - 63);
62 const double mantissa =
std::frexp(scale, &shift);
63 auto shiftedM = std::round(mantissa * (int64_t(1) << 31));
66 assert(shiftedM <= (int64_t(1) << 31) &&
67 "Shifted mantissa exceeds 32 signed bits");
68 if (shiftedM == (int64_t(1) << 31)) {
75 shift = (-shift) + 31;
78 "Shifted mantissa exceeds 32-bit signed output type");
80 multiplier =
static_cast<int32_t
>(shiftedM);
85 multiplier = multiplier >> std::min<int32_t>(31, shift - 63);
92 int32_t &shift, int32_t scaleWidth) {
102 assert(0 &&
"Unsupported Tosa quantized_scale regime specified!");
106 #define GET_UQTYPE(input_type) \ 107 ((input_type).getElementType().dyn_cast<quant::UniformQuantizedType>()) 108 #define GET_QTYPE(input_type) \ 109 ((input_type).getElementType().dyn_cast<quant::QuantizedType>()) 115 ConvOpQuantizationAttr
122 if (!inputType || !weightType)
126 auto weightPerTensorQType =
GET_UQTYPE(weightType);
127 auto weightPerAxisQType = weightType.getElementType()
131 assert(!((
bool)weightPerTensorQType && (
bool)weightPerAxisQType) &&
132 "Weights must be either per-tensor or per-axis quantized");
135 assert(!((
bool)inputQType ^
136 ((
bool)weightPerTensorQType || (
bool)weightPerAxisQType)) &&
137 "Inputs and weights must be all quantized or all not quantized");
140 int64_t inputZp = inputQType.getZeroPoint();
141 int64_t weightZp = 0;
143 if (weightPerTensorQType) {
144 weightZp = weightPerTensorQType.getZeroPoint();
145 }
else if (weightPerAxisQType) {
146 weightZp = weightPerAxisQType.getZeroPoints().front();
149 return builder.
getAttr<tosa::ConvOpQuantizationAttr>(inputZp, weightZp);
159 MatMulOpQuantizationAttr
166 if (!aType || !bType)
173 assert(!((
bool)aQType ^ (
bool)bQType) &&
174 "Matmul operands must be all quantized or all not quantized");
177 return builder.
getAttr<tosa::MatMulOpQuantizationAttr>(
178 aQType.getZeroPoint(), bQType.getZeroPoint());
188 UnaryOpQuantizationAttr
190 Type outputRawType) {
193 auto outputType = outputRawType.
dyn_cast<ShapedType>();
195 if (!inputType || !outputType)
202 assert(!((
bool)inputQType ^ (
bool)outputQType) &&
203 "Unary inputs/outputs must be all quantized or all not quantized");
206 return builder.
getAttr<UnaryOpQuantizationAttr>(inputQType.getZeroPoint(),
207 outputQType.getZeroPoint());
226 return builder.
getAttr<tosa::PadOpQuantizationAttr>(
227 inputQType.getZeroPoint());
241 assert(inputType && weightType &&
242 "Could not extract input or weight tensors from Conv op");
245 auto weightQType =
GET_QTYPE(weightType);
247 assert(inputQType && weightQType &&
248 "Could not extract input or weight tensor types from Conv op");
250 unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
251 unsigned weightBits = weightQType.getStorageTypeIntegralWidth();
253 auto outputShapedType = outputType.
dyn_cast<ShapedType>();
254 assert(outputShapedType &&
255 "Could not extract output shape type from Conv op");
257 IntegerType accElementType;
258 if (inputBits == 16 && weightBits == 8)
262 auto accType = outputShapedType.clone(accElementType);
269 IntegerAttr quantBits,
int filterQuantDim,
270 bool isSigned,
BoolAttr narrowRange) {
283 if (minElems || maxElems) {
285 if (minElems.getNumElements() != maxElems.getNumElements())
287 min.reserve(minElems.getNumElements());
288 max.reserve(maxElems.getNumElements());
289 for (
auto i : minElems)
290 min.push_back(FloatAttr::getValueAsDouble(i));
291 for (
auto i : maxElems)
292 max.push_back(FloatAttr::getValueAsDouble(i));
294 auto minVal = minAttr.
dyn_cast<FloatAttr>();
296 min.push_back(minVal.getValueAsDouble());
299 auto maxVal = maxAttr.
dyn_cast<FloatAttr>();
301 max.push_back(maxVal.getValueAsDouble());
306 if (min.size() == max.size()) {
307 if (min.size() == 1) {
310 narrowRange.
getValue(), convfunc.expressedType, isSigned);
311 }
else if (min.size() > 1) {
312 auto shape = inputDType.
dyn_cast<ShapedType>();
315 if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) {
317 builder.
getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0],
318 max[0], narrowRange.
getValue(), convfunc.expressedType, isSigned);
330 return convfunc.convert(retType);
337 IntegerAttr quantBits,
int filterQuantDim,
338 bool isSigned,
BoolAttr narrowRange) {
341 maxAttr, quantBits, filterQuantDim,
342 isSigned, narrowRange));
Include the generated interface declarations.
#define GET_QTYPE(input_type)
An attribute that represents a reference to a dense float vector or tensor object.
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint; bZp: input b zeropoint.
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr, Attribute maxAttr, IntegerAttr quantBits, int filterQuantDim, bool isSigned, BoolAttr narrowRange)
Builds Tosa quantization attributes from min/max values.
bool getValue() const
Return the boolean value of this attribute.
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
#define GET_UQTYPE(input_type)
static void computeMultiplierAndShiftTosaScale32(double scale, int32_t &multiplier, int32_t &shift)
From a scale value, generates multiplier and shift values where mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that multiplier = mantissa*2^shift for 32-bit scaling.
IntegerType getIntegerType(unsigned width)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static std::pair< Value, Value > frexp(ImplicitLocOpBuilder &builder, Value arg, bool isPositive=false)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void computeMultiplierAndShift(double scale, int32_t &multiplier, int32_t &shift, int32_t scaleWidth)
From a scale value, computes multiplier and shift values for 16 or 32-bit scale widths.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder: inputZp: input zeropoint; weightZp: weight zeropoint.
static void computeMultiplierAndShiftTosaScale16(double scale, int32_t &multiplier, int32_t &shift)
From a scale value, generates multiplier and shift values where mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that multiplier = mantissa*2^shift for 16-bit scaling.
UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin, double rmax, bool narrowRange, Type expressedType, bool isSigned=false)
Converts per-layer FakeQuant attributes to the corresponding type.
Type getType() const
Return the type of this value.
Base class for all quantized types known to this dialect.
Type buildQTypeFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr, Attribute maxAttr, IntegerAttr quantBits, int filterQuantDim, bool isSigned, BoolAttr narrowRange)
Builds a Tosa quantized type from min/max values.
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr, called from UnaryOpQuantInfoBuilder: inputZp: input zeropoint; outputZp: output zeropoint.
static ExpressedToQuantizedConverter forInputType(Type inputType)
Creates a converter for the given input type.
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
Constructs the ConvOp output type with the correct accumulator bitwidth, chosen from the input and weight storage widths.
This class helps build Operations.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)