23double getMinScale(
Type expressedType) {
24 auto floatType = cast<FloatType>(expressedType);
25 return APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
29double getMaxScale(
Type expressedType) {
30 auto floatType = cast<FloatType>(expressedType);
31 return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
41 return llvm::isa<QuantDialect>(type.
getDialect());
46 unsigned flags,
Type storageType,
52 auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
54 return emitError() <<
"storage type must be integral";
55 unsigned integralWidth = intStorageType.getWidth();
59 return emitError() <<
"illegal storage type size: " << integralWidth;
68 if (storageTypeMax - storageTypeMin <= 0 ||
69 storageTypeMin < defaultIntegerMin ||
70 storageTypeMax > defaultIntegerMax) {
71 return emitError() <<
"illegal storage min and storage max: ("
72 << storageTypeMin <<
":" << storageTypeMax <<
")";
103 return static_cast<ImplType *
>(
impl)->storageType.getIntOrFloatBitWidth();
111 if (llvm::isa<ShapedType>(candidateExpressedType)) {
112 return llvm::cast<ShapedType>(candidateExpressedType).getElementType() ==
120 if (llvm::isa<ShapedType>(primitiveOrContainerType)) {
122 llvm::cast<ShapedType>(primitiveOrContainerType).getElementType();
123 return llvm::dyn_cast<QuantizedType>(elementType);
125 return llvm::dyn_cast<QuantizedType>(primitiveOrContainerType);
133 if (llvm::isa<RankedTensorType>(candidateType)) {
135 return RankedTensorType::get(
136 llvm::cast<RankedTensorType>(candidateType).
getShape(),
139 if (llvm::isa<UnrankedTensorType>(candidateType)) {
143 if (llvm::isa<VectorType>(candidateType)) {
145 return VectorType::get(llvm::cast<VectorType>(candidateType).
getShape(),
153 if (llvm::isa<QuantizedType>(quantizedType)) {
155 return llvm::cast<QuantizedType>(quantizedType).getStorageType();
157 if (llvm::isa<ShapedType>(quantizedType)) {
159 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
160 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
164 llvm::cast<QuantizedType>(sType.getElementType()).getStorageType();
165 if (llvm::isa<RankedTensorType>(quantizedType)) {
166 return RankedTensorType::get(sType.getShape(), storageType);
168 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
169 return UnrankedTensorType::get(storageType);
171 if (llvm::isa<VectorType>(quantizedType)) {
172 return VectorType::get(sType.getShape(), storageType);
184 if (llvm::isa<ShapedType>(candidateType)) {
185 ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType);
190 if (llvm::isa<RankedTensorType>(candidateType)) {
192 return RankedTensorType::get(candidateShapedType.getShape(), *
this);
194 if (llvm::isa<UnrankedTensorType>(candidateType)) {
196 return UnrankedTensorType::get(*
this);
198 if (llvm::isa<VectorType>(candidateType)) {
200 return VectorType::get(candidateShapedType.getShape(), *
this);
208 if (llvm::isa<QuantizedType>(quantizedType)) {
210 return llvm::cast<QuantizedType>(quantizedType).getExpressedType();
212 if (llvm::isa<ShapedType>(quantizedType)) {
214 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
215 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
219 llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType();
220 if (llvm::isa<RankedTensorType>(quantizedType)) {
221 return RankedTensorType::get(sType.getShape(), expressedType);
223 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
224 return UnrankedTensorType::get(expressedType);
226 if (llvm::isa<VectorType>(quantizedType)) {
227 return VectorType::get(sType.getShape(), expressedType);
236 if (!expressedQuantizedType) {
247 storageTypeMin, storageTypeMax);
252 unsigned flags,
Type storageType,
256 storageType, expressedType, storageTypeMin,
262 unsigned flags,
Type storageType,
266 expressedType, storageTypeMin,
274 if (expressedType && !llvm::isa<FloatType>(expressedType))
275 return emitError() <<
"expressed type must be floating point";
281 Type expressedType,
double scale,
286 scale, zeroPoint, storageTypeMin, storageTypeMax);
294 storageType, expressedType, scale, zeroPoint,
295 storageTypeMin, storageTypeMax);
303 expressedType, storageTypeMin,
311 return emitError() <<
"uniform quantization requires expressed type";
316 if (!llvm::isa<FloatType>(expressedType))
317 return emitError() <<
"expressed type must be floating point";
320 double minScale = getMinScale(expressedType);
321 double maxScale = getMaxScale(expressedType);
322 if (scale < minScale || scale > maxScale)
323 return emitError() <<
"scale out of expressed type range [" << minScale
324 <<
", " << maxScale <<
"]";
336 unsigned flags,
Type storageType,
Type expressedType,
338 int32_t quantizedDimension,
int64_t storageTypeMin,
341 scales, zeroPoints, quantizedDimension, storageTypeMin,
351 storageType, expressedType, scales, zeroPoints,
352 quantizedDimension, storageTypeMin, storageTypeMax);
361 expressedType, storageTypeMin,
369 return emitError() <<
"uniform quantization requires expressed type";
374 if (!llvm::isa<FloatType>(expressedType))
375 return emitError() <<
"expressed type must be floating point";
378 if (scales.size() != zeroPoints.size())
379 return emitError() <<
"illegal number of scales and zeroPoints: "
380 << scales.size() <<
", " << zeroPoints.size();
383 double minScale = getMinScale(expressedType);
384 double maxScale = getMaxScale(expressedType);
385 for (
double scale : scales) {
386 if (scale < minScale || scale > maxScale)
387 return emitError() <<
"scale out of expressed type range [" << minScale
388 <<
", " << maxScale <<
"]";
392 if (quantizedDimension < 0)
393 return emitError() <<
"illegal quantized dimension: " << quantizedDimension;
403 return getImpl()->getZeroPoints();
407 return getImpl()->quantizedDimension;
411 unsigned flags,
Type storageType,
Type expressedType,
416 scales, zeroPoints, quantizedDimensions, blockSizes,
417 storageTypeMin, storageTypeMax);
427 storageType, expressedType, scales, zeroPoints,
428 quantizedDimensions, blockSizes, storageTypeMin,
439 expressedType, storageTypeMin,
447 return emitError() <<
"uniform quantization requires expressed type";
452 if (!llvm::isa<FloatType>(expressedType))
453 return emitError() <<
"expressed type must be floating point";
456 if (scales.
getType().getElementType() != expressedType) {
457 return emitError() <<
"type of scale values "
458 << scales.
getType().getElementType()
459 <<
" must match the expressed type " << expressedType;
463 if (zeroPoints.
getType().getElementType() != storageType) {
464 return emitError() <<
"type of zero point values "
465 << zeroPoints.
getType().getElementType()
466 <<
" must match the storage type " << storageType;
470 if (scales.
getType().getShape() != zeroPoints.
getType().getShape())
471 return emitError() <<
"shape of scales and zeroPoints ("
472 << scales.
getType().getShape() <<
" vs "
473 << zeroPoints.
getType().getShape() <<
") does not match";
476 if (quantizedDimensions.size() != blockSizes.size())
477 return emitError() <<
"number of quantized dimensions and block sizes ("
478 << scales.
size() <<
" vs " << zeroPoints.
size()
479 <<
") does not match";
482 for (
auto quantizedDimension : quantizedDimensions) {
483 if (quantizedDimension < 0)
484 return emitError() <<
"illegal quantized dimension: "
485 << quantizedDimension;
489 for (
auto blockSize : blockSizes) {
491 return emitError() <<
"illegal block size: " << blockSize;
502 return getImpl()->getZeroPoints();
507 return getImpl()->getQuantizedDimensions();
511 return getImpl()->getBlockSizes();
519 for (
auto [dim, size] :
521 result.push_back({dim, size});
545 if (!llvm::isa<FloatType>(expressedType))
546 return emitError() <<
"expressed type must be floating point";
548 return emitError() <<
"illegal min and max: (" <<
min <<
":" <<
max <<
")";
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
An attribute that represents a reference to a dense vector or tensor object.
int64_t size() const
Returns the number of elements held by this attribute.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class represents a diagnostic that is inflight and set to be reported.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
static ConcreteType get(MLIRContext *ctx, Args &&...args)
static ConcreteType getChecked(const Location &loc, Args &&...args)
ImplType * getImpl() const
A quantized type that maps storage to/from expressed types in an unspecified way.
static AnyQuantizedType get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
static AnyQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
A quantized type that infers its range from given min/max values.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Verifies construction invariants and issues errors/warnings.
static CalibratedQuantizedType get(Type expressedType, double min, double max)
Gets an instance of the type with all parameters specified but not checked.
static CalibratedQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Gets an instance of the type with all specified parameters checked.
Base class for all quantized types known to this dialect.
Type getExpressedType() const
Gets the original expressed type that this quantized type approximates.
static constexpr unsigned MaxStorageBits
The maximum number of bits supported for storage types.
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
static Type castToStorageType(Type quantizedType)
Casts from a type based on a QuantizedType to a corresponding type based on the storageType (returns ...
Type castExpressedToStorageType(Type candidateType)
Casts from a type based on the expressedType to the equivalent type based on storageType by way of th...
detail::QuantizedTypeStorage ImplType
static Type castToExpressedType(Type quantizedType)
Casts from a type based on QuantizedType to a corresponding type based on the expressedType (returns ...
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)
Returns the element type as a QuantizedType or nullptr if it is not a quantized type.
unsigned getFlags() const
Gets the flags associated with this type.
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
static int64_t getDefaultMaximumForInteger(bool isSigned, unsigned integralWidth)
Gets the maximum possible value stored by a storageType.
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
static bool classof(Type type)
Support method to enable LLVM-style type casting.
Type castFromStorageType(Type candidateType)
Casts from a type based on the storageType to a corresponding type based on this type (returns nullpt...
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
static int64_t getDefaultMinimumForInteger(bool isSigned, unsigned integralWidth)
Gets the minimum possible value stored by a storageType.
Type getStorageType() const
Gets the underlying type used to store values.
Type castFromExpressedType(Type candidateType)
Casts from a type based on the expressedType to a corresponding type based on this type (returns null...
bool isCompatibleExpressedType(Type candidateExpressedType)
Returns whether the candidateExpressedType is a match for this QuantizedType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref