15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/MathExtras.h"
// Returns the smallest positive finite magnitude representable by the
// floating-point `expressedType`, widened to double. Per LLVM's APFloat docs,
// getSmallest() may be a denormal. This value is the lower bound used when
// validating quantization scales in this file.
// NOTE(review): cast<FloatType> asserts on non-float types; the visible callers
// check isa<FloatType> first. convertToDouble() may assert for semantics that
// double cannot represent losslessly (e.g. f80/f128) — confirm callers never
// pass such types.
26 double getMinScale(
Type expressedType) {
27 auto floatType = cast<FloatType>(expressedType);
28 return APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
// Returns the largest finite value representable by the floating-point
// `expressedType` (APFloat::getLargest), widened to double. This value is the
// upper bound used when validating quantization scales in this file.
// NOTE(review): same preconditions as getMinScale — expressedType must be a
// FloatType, and convertToDouble() may assert for semantics wider than double.
32 double getMaxScale(
Type expressedType) {
33 auto floatType = cast<FloatType>(expressedType);
34 return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
44 return llvm::isa<QuantDialect>(type.
getDialect());
49 unsigned flags,
Type storageType,
50 Type expressedType, int64_t storageTypeMin,
51 int64_t storageTypeMax) {
55 auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
57 return emitError() <<
"storage type must be integral";
58 unsigned integralWidth = intStorageType.getWidth();
61 if (integralWidth == 0 || integralWidth > MaxStorageBits)
62 return emitError() <<
"illegal storage type size: " << integralWidth;
67 int64_t defaultIntegerMin =
68 getDefaultMinimumForInteger(isSigned, integralWidth);
69 int64_t defaultIntegerMax =
70 getDefaultMaximumForInteger(isSigned, integralWidth);
71 if (storageTypeMax - storageTypeMin <= 0 ||
72 storageTypeMin < defaultIntegerMin ||
73 storageTypeMax > defaultIntegerMax) {
74 return emitError() <<
"illegal storage min and storage max: ("
75 << storageTypeMin <<
":" << storageTypeMax <<
")";
93 unsigned int integralWidth = getStorageTypeIntegralWidth();
94 bool isSignedInteger = isSigned();
95 int64_t defaultIntegerMin =
96 getDefaultMinimumForInteger(isSignedInteger, integralWidth);
97 int64_t defaultIntegerMax =
98 getDefaultMaximumForInteger(isSignedInteger, integralWidth);
99 return defaultIntegerMin != getStorageTypeMin() ||
100 defaultIntegerMax != getStorageTypeMax();
106 return static_cast<ImplType *
>(
impl)->storageType.getIntOrFloatBitWidth();
114 if (llvm::isa<ShapedType>(candidateExpressedType)) {
115 return llvm::cast<ShapedType>(candidateExpressedType).getElementType() ==
118 return candidateExpressedType == getExpressedType();
123 if (llvm::isa<ShapedType>(primitiveOrContainerType)) {
125 llvm::cast<ShapedType>(primitiveOrContainerType).getElementType();
126 return llvm::dyn_cast<QuantizedType>(elementType);
128 return llvm::dyn_cast<QuantizedType>(primitiveOrContainerType);
132 if (candidateType == getStorageType()) {
136 if (llvm::isa<RankedTensorType>(candidateType)) {
139 llvm::cast<RankedTensorType>(candidateType).
getShape(),
142 if (llvm::isa<UnrankedTensorType>(candidateType)) {
146 if (llvm::isa<VectorType>(candidateType)) {
156 if (llvm::isa<QuantizedType>(quantizedType)) {
158 return llvm::cast<QuantizedType>(quantizedType).getStorageType();
160 if (llvm::isa<ShapedType>(quantizedType)) {
162 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
163 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
167 llvm::cast<QuantizedType>(sType.getElementType()).getStorageType();
168 if (llvm::isa<RankedTensorType>(quantizedType)) {
171 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
174 if (llvm::isa<VectorType>(quantizedType)) {
183 if (candidateType == getExpressedType()) {
187 if (llvm::isa<ShapedType>(candidateType)) {
188 ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType);
189 if (candidateShapedType.getElementType() != getExpressedType()) {
193 if (llvm::isa<RankedTensorType>(candidateType)) {
197 if (llvm::isa<UnrankedTensorType>(candidateType)) {
201 if (llvm::isa<VectorType>(candidateType)) {
211 if (llvm::isa<QuantizedType>(quantizedType)) {
213 return llvm::cast<QuantizedType>(quantizedType).getExpressedType();
215 if (llvm::isa<ShapedType>(quantizedType)) {
217 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
218 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
222 llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType();
223 if (llvm::isa<RankedTensorType>(quantizedType)) {
226 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
229 if (llvm::isa<VectorType>(quantizedType)) {
238 Type expressedQuantizedType = castFromExpressedType(candidateType);
239 if (!expressedQuantizedType) {
247 int64_t storageTypeMin,
248 int64_t storageTypeMax) {
250 storageTypeMin, storageTypeMax);
255 unsigned flags,
Type storageType,
256 Type expressedType, int64_t storageTypeMin,
257 int64_t storageTypeMax) {
259 storageType, expressedType, storageTypeMin,
265 unsigned flags,
Type storageType,
266 Type expressedType, int64_t storageTypeMin,
267 int64_t storageTypeMax) {
269 expressedType, storageTypeMin,
277 if (expressedType && !llvm::isa<FloatType>(expressedType))
278 return emitError() <<
"expressed type must be floating point";
284 Type expressedType,
double scale,
286 int64_t storageTypeMin,
287 int64_t storageTypeMax) {
289 scale, zeroPoint, storageTypeMin, storageTypeMax);
294 Type storageType,
Type expressedType,
double scale, int64_t zeroPoint,
295 int64_t storageTypeMin, int64_t storageTypeMax) {
297 storageType, expressedType, scale, zeroPoint,
298 storageTypeMin, storageTypeMax);
303 Type storageType,
Type expressedType,
double scale, int64_t zeroPoint,
304 int64_t storageTypeMin, int64_t storageTypeMax) {
306 expressedType, storageTypeMin,
314 return emitError() <<
"uniform quantization requires expressed type";
319 if (!llvm::isa<FloatType>(expressedType))
320 return emitError() <<
"expressed type must be floating point";
323 double minScale = getMinScale(expressedType);
324 double maxScale = getMaxScale(expressedType);
325 if (scale < minScale || scale > maxScale)
326 return emitError() <<
"scale out of expressed type range [" << minScale
327 <<
", " << maxScale <<
"]";
335 return getImpl()->zeroPoint;
339 unsigned flags,
Type storageType,
Type expressedType,
341 int32_t quantizedDimension, int64_t storageTypeMin,
342 int64_t storageTypeMax) {
344 scales, zeroPoints, quantizedDimension, storageTypeMin,
352 int64_t storageTypeMin, int64_t storageTypeMax) {
354 storageType, expressedType, scales, zeroPoints,
355 quantizedDimension, storageTypeMin, storageTypeMax);
362 int64_t storageTypeMin, int64_t storageTypeMax) {
364 expressedType, storageTypeMin,
372 return emitError() <<
"uniform quantization requires expressed type";
377 if (!llvm::isa<FloatType>(expressedType))
378 return emitError() <<
"expressed type must be floating point";
381 if (scales.size() != zeroPoints.size())
382 return emitError() <<
"illegal number of scales and zeroPoints: "
383 << scales.size() <<
", " << zeroPoints.size();
386 double minScale = getMinScale(expressedType);
387 double maxScale = getMaxScale(expressedType);
388 for (
double scale : scales) {
389 if (scale < minScale || scale > maxScale)
390 return emitError() <<
"scale out of expressed type range [" << minScale
391 <<
", " << maxScale <<
"]";
395 if (quantizedDimension < 0)
396 return emitError() <<
"illegal quantized dimension: " << quantizedDimension;
402 return getImpl()->getScales();
406 return getImpl()->getZeroPoints();
410 return getImpl()->quantizedDimension;
414 unsigned flags,
Type storageType,
Type expressedType,
417 int64_t storageTypeMin, int64_t storageTypeMax) {
419 scales, zeroPoints, quantizedDimensions, blockSizes,
420 storageTypeMin, storageTypeMax);
428 int64_t storageTypeMax) {
430 storageType, expressedType, scales, zeroPoints,
431 quantizedDimensions, blockSizes, storageTypeMin,
440 int64_t storageTypeMax) {
442 expressedType, storageTypeMin,
450 return emitError() <<
"uniform quantization requires expressed type";
455 if (!llvm::isa<FloatType>(expressedType))
456 return emitError() <<
"expressed type must be floating point";
459 if (scales.
getType().getElementType() != expressedType) {
460 return emitError() <<
"type of scale values "
461 << scales.
getType().getElementType()
462 <<
" must match the expressed type " << expressedType;
466 if (zeroPoints.
getType().getElementType() != storageType) {
467 return emitError() <<
"type of zero point values "
468 << zeroPoints.
getType().getElementType()
469 <<
" must match the storage type " << storageType;
473 if (scales.
getType().getShape() != zeroPoints.
getType().getShape())
474 return emitError() <<
"shape of scales and zeroPoints ("
475 << scales.
getType().getShape() <<
" vs "
476 << zeroPoints.
getType().getShape() <<
") does not match";
479 if (quantizedDimensions.size() != blockSizes.size())
480 return emitError() <<
"number of quantized dimensions and block sizes ("
481 << scales.
size() <<
" vs " << zeroPoints.
size()
482 <<
") does not match";
485 for (
auto quantizedDimension : quantizedDimensions) {
486 if (quantizedDimension < 0)
487 return emitError() <<
"illegal quantized dimension: "
488 << quantizedDimension;
492 for (
auto blockSize : blockSizes) {
494 return emitError() <<
"illegal block size: " << blockSize;
501 return getImpl()->getScales();
505 return getImpl()->getZeroPoints();
510 return getImpl()->getQuantizedDimensions();
514 return getImpl()->getBlockSizes();
520 result.reserve(getQuantizedDimensions().size());
522 for (
auto [dim, size] :
523 llvm::zip(getQuantizedDimensions(), getBlockSizes())) {
524 result.push_back({dim, size});
548 if (!llvm::isa<FloatType>(expressedType))
549 return emitError() <<
"expressed type must be floating point";
551 return emitError() <<
"illegal min and max: (" <<
min <<
":" <<
max <<
")";
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
An attribute that represents a reference to a dense vector or tensor object.
int64_t size() const
Returns the number of elements held by this attribute.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class represents a diagnostic that is inflight and set to be reported.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
A quantized type that maps storage to/from expressed types in an unspecified way.
static AnyQuantizedType get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
static AnyQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
A quantized type that infers its range from given min/max values.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Verifies construction invariants and issues errors/warnings.
static CalibratedQuantizedType get(Type expressedType, double min, double max)
Gets an instance of the type with all parameters specified but not checked.
static CalibratedQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Gets an instance of the type with all specified parameters checked.
Base class for all quantized types known to this dialect.
Type getExpressedType() const
Gets the original expressed type that this quantized type approximates.
bool hasStorageTypeBounds() const
Returns whether the storage type has explicit min or max boundaries different from the default minimum and maximum values of its integral storage type.
static Type castToStorageType(Type quantizedType)
Casts from a type based on a QuantizedType to a corresponding type based on the storageType (returns ...
Type castExpressedToStorageType(Type candidateType)
Casts from a type based on the expressedType to the equivalent type based on storageType by way of th...
static Type castToExpressedType(Type quantizedType)
Casts from a type based on QuantizedType to a corresponding type based on the expressedType (returns ...
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)
Returns the element type as a QuantizedType or nullptr if it is not a quantized type.
unsigned getFlags() const
Gets the flags associated with this type.
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
static bool classof(Type type)
Support method to enable LLVM-style type casting.
Type castFromStorageType(Type candidateType)
Casts from a type based on the storageType to a corresponding type based on this type (returns nullpt...
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Type getStorageType() const
Gets the underlying type used to store values.
Type castFromExpressedType(Type candidateType)
Casts from a type based on the expressedType to a corresponding type based on this type (returns null...
bool isCompatibleExpressedType(Type candidateExpressedType)
Returns whether the candidateExpressedType is a match for this QuantizedType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...