23 double getMinScale(
Type expressedType) {
24 auto floatType = cast<FloatType>(expressedType);
25 return APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
29 double getMaxScale(
Type expressedType) {
30 auto floatType = cast<FloatType>(expressedType);
31 return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
41 return llvm::isa<QuantDialect>(type.
getDialect());
46 unsigned flags,
Type storageType,
47 Type expressedType, int64_t storageTypeMin,
48 int64_t storageTypeMax) {
52 auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
54 return emitError() <<
"storage type must be integral";
55 unsigned integralWidth = intStorageType.getWidth();
58 if (integralWidth == 0 || integralWidth > MaxStorageBits)
59 return emitError() <<
"illegal storage type size: " << integralWidth;
64 int64_t defaultIntegerMin =
65 getDefaultMinimumForInteger(isSigned, integralWidth);
66 int64_t defaultIntegerMax =
67 getDefaultMaximumForInteger(isSigned, integralWidth);
68 if (storageTypeMax - storageTypeMin <= 0 ||
69 storageTypeMin < defaultIntegerMin ||
70 storageTypeMax > defaultIntegerMax) {
71 return emitError() <<
"illegal storage min and storage max: ("
72 << storageTypeMin <<
":" << storageTypeMax <<
")";
90 unsigned int integralWidth = getStorageTypeIntegralWidth();
91 bool isSignedInteger = isSigned();
92 int64_t defaultIntegerMin =
93 getDefaultMinimumForInteger(isSignedInteger, integralWidth);
94 int64_t defaultIntegerMax =
95 getDefaultMaximumForInteger(isSignedInteger, integralWidth);
96 return defaultIntegerMin != getStorageTypeMin() ||
97 defaultIntegerMax != getStorageTypeMax();
103 return static_cast<ImplType *
>(
impl)->storageType.getIntOrFloatBitWidth();
111 if (llvm::isa<ShapedType>(candidateExpressedType)) {
112 return llvm::cast<ShapedType>(candidateExpressedType).getElementType() ==
115 return candidateExpressedType == getExpressedType();
120 if (llvm::isa<ShapedType>(primitiveOrContainerType)) {
122 llvm::cast<ShapedType>(primitiveOrContainerType).getElementType();
123 return llvm::dyn_cast<QuantizedType>(elementType);
125 return llvm::dyn_cast<QuantizedType>(primitiveOrContainerType);
129 if (candidateType == getStorageType()) {
133 if (llvm::isa<RankedTensorType>(candidateType)) {
136 llvm::cast<RankedTensorType>(candidateType).
getShape(),
139 if (llvm::isa<UnrankedTensorType>(candidateType)) {
143 if (llvm::isa<VectorType>(candidateType)) {
153 if (llvm::isa<QuantizedType>(quantizedType)) {
155 return llvm::cast<QuantizedType>(quantizedType).getStorageType();
157 if (llvm::isa<ShapedType>(quantizedType)) {
159 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
160 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
164 llvm::cast<QuantizedType>(sType.getElementType()).getStorageType();
165 if (llvm::isa<RankedTensorType>(quantizedType)) {
168 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
171 if (llvm::isa<VectorType>(quantizedType)) {
180 if (candidateType == getExpressedType()) {
184 if (llvm::isa<ShapedType>(candidateType)) {
185 ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType);
186 if (candidateShapedType.getElementType() != getExpressedType()) {
190 if (llvm::isa<RankedTensorType>(candidateType)) {
194 if (llvm::isa<UnrankedTensorType>(candidateType)) {
198 if (llvm::isa<VectorType>(candidateType)) {
208 if (llvm::isa<QuantizedType>(quantizedType)) {
210 return llvm::cast<QuantizedType>(quantizedType).getExpressedType();
212 if (llvm::isa<ShapedType>(quantizedType)) {
214 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
215 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
219 llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType();
220 if (llvm::isa<RankedTensorType>(quantizedType)) {
223 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
226 if (llvm::isa<VectorType>(quantizedType)) {
235 Type expressedQuantizedType = castFromExpressedType(candidateType);
236 if (!expressedQuantizedType) {
244 int64_t storageTypeMin,
245 int64_t storageTypeMax) {
247 storageTypeMin, storageTypeMax);
252 unsigned flags,
Type storageType,
253 Type expressedType, int64_t storageTypeMin,
254 int64_t storageTypeMax) {
256 storageType, expressedType, storageTypeMin,
262 unsigned flags,
Type storageType,
263 Type expressedType, int64_t storageTypeMin,
264 int64_t storageTypeMax) {
266 expressedType, storageTypeMin,
274 if (expressedType && !llvm::isa<FloatType>(expressedType))
275 return emitError() <<
"expressed type must be floating point";
281 Type expressedType,
double scale,
283 int64_t storageTypeMin,
284 int64_t storageTypeMax) {
286 scale, zeroPoint, storageTypeMin, storageTypeMax);
291 Type storageType,
Type expressedType,
double scale, int64_t zeroPoint,
292 int64_t storageTypeMin, int64_t storageTypeMax) {
294 storageType, expressedType, scale, zeroPoint,
295 storageTypeMin, storageTypeMax);
300 Type storageType,
Type expressedType,
double scale, int64_t zeroPoint,
301 int64_t storageTypeMin, int64_t storageTypeMax) {
303 expressedType, storageTypeMin,
311 return emitError() <<
"uniform quantization requires expressed type";
316 if (!llvm::isa<FloatType>(expressedType))
317 return emitError() <<
"expressed type must be floating point";
320 double minScale = getMinScale(expressedType);
321 double maxScale = getMaxScale(expressedType);
322 if (scale < minScale || scale > maxScale)
323 return emitError() <<
"scale out of expressed type range [" << minScale
324 <<
", " << maxScale <<
"]";
332 return getImpl()->zeroPoint;
336 unsigned flags,
Type storageType,
Type expressedType,
338 int32_t quantizedDimension, int64_t storageTypeMin,
339 int64_t storageTypeMax) {
341 scales, zeroPoints, quantizedDimension, storageTypeMin,
349 int64_t storageTypeMin, int64_t storageTypeMax) {
351 storageType, expressedType, scales, zeroPoints,
352 quantizedDimension, storageTypeMin, storageTypeMax);
359 int64_t storageTypeMin, int64_t storageTypeMax) {
361 expressedType, storageTypeMin,
369 return emitError() <<
"uniform quantization requires expressed type";
374 if (!llvm::isa<FloatType>(expressedType))
375 return emitError() <<
"expressed type must be floating point";
378 if (scales.size() != zeroPoints.size())
379 return emitError() <<
"illegal number of scales and zeroPoints: "
380 << scales.size() <<
", " << zeroPoints.size();
383 double minScale = getMinScale(expressedType);
384 double maxScale = getMaxScale(expressedType);
385 for (
double scale : scales) {
386 if (scale < minScale || scale > maxScale)
387 return emitError() <<
"scale out of expressed type range [" << minScale
388 <<
", " << maxScale <<
"]";
392 if (quantizedDimension < 0)
393 return emitError() <<
"illegal quantized dimension: " << quantizedDimension;
399 return getImpl()->getScales();
403 return getImpl()->getZeroPoints();
407 return getImpl()->quantizedDimension;
411 unsigned flags,
Type storageType,
Type expressedType,
414 int64_t storageTypeMin, int64_t storageTypeMax) {
416 scales, zeroPoints, quantizedDimensions, blockSizes,
417 storageTypeMin, storageTypeMax);
425 int64_t storageTypeMax) {
427 storageType, expressedType, scales, zeroPoints,
428 quantizedDimensions, blockSizes, storageTypeMin,
437 int64_t storageTypeMax) {
439 expressedType, storageTypeMin,
447 return emitError() <<
"uniform quantization requires expressed type";
452 if (!llvm::isa<FloatType>(expressedType))
453 return emitError() <<
"expressed type must be floating point";
456 if (scales.
getType().getElementType() != expressedType) {
457 return emitError() <<
"type of scale values "
458 << scales.
getType().getElementType()
459 <<
" must match the expressed type " << expressedType;
463 if (zeroPoints.
getType().getElementType() != storageType) {
464 return emitError() <<
"type of zero point values "
465 << zeroPoints.
getType().getElementType()
466 <<
" must match the storage type " << storageType;
470 if (scales.
getType().getShape() != zeroPoints.
getType().getShape())
471 return emitError() <<
"shape of scales and zeroPoints ("
472 << scales.
getType().getShape() <<
" vs "
473 << zeroPoints.
getType().getShape() <<
") does not match";
476 if (quantizedDimensions.size() != blockSizes.size())
477 return emitError() <<
"number of quantized dimensions and block sizes ("
478 << scales.
size() <<
" vs " << zeroPoints.
size()
479 <<
") does not match";
482 for (
auto quantizedDimension : quantizedDimensions) {
483 if (quantizedDimension < 0)
484 return emitError() <<
"illegal quantized dimension: "
485 << quantizedDimension;
489 for (
auto blockSize : blockSizes) {
491 return emitError() <<
"illegal block size: " << blockSize;
498 return getImpl()->getScales();
502 return getImpl()->getZeroPoints();
507 return getImpl()->getQuantizedDimensions();
511 return getImpl()->getBlockSizes();
517 result.reserve(getQuantizedDimensions().size());
519 for (
auto [dim, size] :
520 llvm::zip(getQuantizedDimensions(), getBlockSizes())) {
521 result.push_back({dim, size});
545 if (!llvm::isa<FloatType>(expressedType))
546 return emitError() <<
"expressed type must be floating point";
548 return emitError() <<
"illegal min and max: (" <<
min <<
":" <<
max <<
")";
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
An attribute that represents a reference to a dense vector or tensor object.
int64_t size() const
Returns the number of elements held by this attribute.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class represents a diagnostic that is inflight and set to be reported.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
A quantized type that maps storage to/from expressed types in an unspecified way.
static AnyQuantizedType get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
static AnyQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
A quantized type that infers its range from given min/max values.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Verifies construction invariants and issues errors/warnings.
static CalibratedQuantizedType get(Type expressedType, double min, double max)
Gets an instance of the type with all parameters specified but not checked.
static CalibratedQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Gets an instance of the type with all specified parameters checked.
Base class for all quantized types known to this dialect.
Type getExpressedType() const
Gets the original expressed type that this quantized type approximates.
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
static Type castToStorageType(Type quantizedType)
Casts from a type based on a QuantizedType to a corresponding type based on the storageType (returns ...
Type castExpressedToStorageType(Type candidateType)
Casts from a type based on the expressedType to the equivalent type based on storageType by way of th...
static Type castToExpressedType(Type quantizedType)
Casts from a type based on QuantizedType to a corresponding type based on the expressedType (returns ...
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)
Returns the element type as a QuantizedType or nullptr if it is not a quantized type.
unsigned getFlags() const
Gets the flags associated with this type.
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
static bool classof(Type type)
Support method to enable LLVM-style type casting.
Type castFromStorageType(Type candidateType)
Casts from a type based on the storageType to a corresponding type based on this type (returns nullpt...
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Type getStorageType() const
Gets the underlying type used to store values.
Type castFromExpressedType(Type candidateType)
Casts from a type based on the expressedType to a corresponding type based on this type (returns null...
bool isCompatibleExpressedType(Type candidateExpressedType)
Returns whether the candidateExpressedType is a match for this QuantizedType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...