24double getMinScale(
Type expressedType) {
25 auto floatType = cast<FloatType>(expressedType);
26 return APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
30double getMaxScale(
Type expressedType) {
31 auto floatType = cast<FloatType>(expressedType);
32 return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
42 return llvm::isa<QuantDialect>(type.
getDialect());
47 unsigned flags,
Type storageType,
50 if (
auto quantStorageTypeInterface =
51 llvm::dyn_cast<QuantStorageTypeInterface>(storageType)) {
52 unsigned integralWidth = quantStorageTypeInterface.getStorageWidth();
56 return emitError() <<
"illegal storage type size: " << integralWidth;
59 int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum(
isSigned);
60 int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum(
isSigned);
62 if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
63 storageTypeMax > defaultMax) {
64 return emitError() <<
"illegal storage min and storage max: ("
65 << storageTypeMin <<
":" << storageTypeMax <<
")";
71 return emitError() <<
"storage type must implement QuantStorageTypeInterface";
88 auto quantStorageTypeInterface =
89 llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
91 int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum(
isSigned());
92 int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum(
isSigned());
99 auto quantStorageTypeInterface =
100 llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
102 return quantStorageTypeInterface.getStorageWidth();
110 if (llvm::isa<ShapedType>(candidateExpressedType)) {
111 return llvm::cast<ShapedType>(candidateExpressedType).getElementType() ==
119 if (llvm::isa<ShapedType>(primitiveOrContainerType)) {
121 llvm::cast<ShapedType>(primitiveOrContainerType).getElementType();
122 return llvm::dyn_cast<QuantizedType>(elementType);
124 return llvm::dyn_cast<QuantizedType>(primitiveOrContainerType);
132 if (llvm::isa<RankedTensorType>(candidateType)) {
134 return RankedTensorType::get(
135 llvm::cast<RankedTensorType>(candidateType).
getShape(),
138 if (llvm::isa<UnrankedTensorType>(candidateType)) {
142 if (llvm::isa<VectorType>(candidateType)) {
144 return VectorType::get(llvm::cast<VectorType>(candidateType).
getShape(),
152 if (llvm::isa<QuantizedType>(quantizedType)) {
154 return llvm::cast<QuantizedType>(quantizedType).getStorageType();
156 if (llvm::isa<ShapedType>(quantizedType)) {
158 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
159 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
163 llvm::cast<QuantizedType>(sType.getElementType()).getStorageType();
164 if (llvm::isa<RankedTensorType>(quantizedType)) {
165 return RankedTensorType::get(sType.getShape(), storageType);
167 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
168 return UnrankedTensorType::get(storageType);
170 if (llvm::isa<VectorType>(quantizedType)) {
171 return VectorType::get(sType.getShape(), storageType);
183 if (llvm::isa<ShapedType>(candidateType)) {
184 ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType);
189 if (llvm::isa<RankedTensorType>(candidateType)) {
191 return RankedTensorType::get(candidateShapedType.getShape(), *
this);
193 if (llvm::isa<UnrankedTensorType>(candidateType)) {
195 return UnrankedTensorType::get(*
this);
197 if (llvm::isa<VectorType>(candidateType)) {
199 return VectorType::get(candidateShapedType.getShape(), *
this);
207 if (llvm::isa<QuantizedType>(quantizedType)) {
209 return llvm::cast<QuantizedType>(quantizedType).getExpressedType();
211 if (llvm::isa<ShapedType>(quantizedType)) {
213 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
214 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
218 llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType();
219 if (llvm::isa<RankedTensorType>(quantizedType)) {
220 return RankedTensorType::get(sType.getShape(), expressedType);
222 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
223 return UnrankedTensorType::get(expressedType);
225 if (llvm::isa<VectorType>(quantizedType)) {
226 return VectorType::get(sType.getShape(), expressedType);
235 if (!expressedQuantizedType) {
246 storageTypeMin, storageTypeMax);
251 unsigned flags,
Type storageType,
255 storageType, expressedType, storageTypeMin,
261 unsigned flags,
Type storageType,
265 expressedType, storageTypeMin,
273 if (expressedType && !llvm::isa<FloatType>(expressedType))
274 return emitError() <<
"expressed type must be floating point";
280 Type expressedType,
double scale,
285 scale, zeroPoint, storageTypeMin, storageTypeMax);
293 storageType, expressedType, scale, zeroPoint,
294 storageTypeMin, storageTypeMax);
302 expressedType, storageTypeMin,
310 return emitError() <<
"uniform quantization requires expressed type";
315 if (!llvm::isa<FloatType>(expressedType))
316 return emitError() <<
"expressed type must be floating point";
319 double minScale = getMinScale(expressedType);
320 double maxScale = getMaxScale(expressedType);
321 if (scale < minScale || scale > maxScale)
322 return emitError() <<
"scale out of expressed type range [" << minScale
323 <<
", " << maxScale <<
"]";
335 unsigned flags,
Type storageType,
Type expressedType,
337 int32_t quantizedDimension,
int64_t storageTypeMin,
340 scales, zeroPoints, quantizedDimension, storageTypeMin,
350 storageType, expressedType, scales, zeroPoints,
351 quantizedDimension, storageTypeMin, storageTypeMax);
360 expressedType, storageTypeMin,
368 return emitError() <<
"uniform quantization requires expressed type";
373 if (!llvm::isa<FloatType>(expressedType))
374 return emitError() <<
"expressed type must be floating point";
377 if (scales.size() != zeroPoints.size())
378 return emitError() <<
"illegal number of scales and zeroPoints: "
379 << scales.size() <<
", " << zeroPoints.size();
382 double minScale = getMinScale(expressedType);
383 double maxScale = getMaxScale(expressedType);
384 for (
double scale : scales) {
385 if (scale < minScale || scale > maxScale)
386 return emitError() <<
"scale out of expressed type range [" << minScale
387 <<
", " << maxScale <<
"]";
391 if (quantizedDimension < 0)
392 return emitError() <<
"illegal quantized dimension: " << quantizedDimension;
402 return getImpl()->getZeroPoints();
406 return getImpl()->quantizedDimension;
410 unsigned flags,
Type storageType,
Type expressedType,
415 scales, zeroPoints, quantizedDimensions, blockSizes,
416 storageTypeMin, storageTypeMax);
426 storageType, expressedType, scales, zeroPoints,
427 quantizedDimensions, blockSizes, storageTypeMin,
438 expressedType, storageTypeMin,
446 return emitError() <<
"uniform quantization requires expressed type";
451 if (!llvm::isa<FloatType>(expressedType))
452 return emitError() <<
"expressed type must be floating point";
455 if (scales.
getType().getElementType() != expressedType) {
456 return emitError() <<
"type of scale values "
457 << scales.
getType().getElementType()
458 <<
" must match the expressed type " << expressedType;
462 if (zeroPoints.
getType().getElementType() != storageType) {
463 return emitError() <<
"type of zero point values "
464 << zeroPoints.
getType().getElementType()
465 <<
" must match the storage type " << storageType;
469 if (scales.
getType().getShape() != zeroPoints.
getType().getShape())
470 return emitError() <<
"shape of scales and zeroPoints ("
471 << scales.
getType().getShape() <<
" vs "
472 << zeroPoints.
getType().getShape() <<
") does not match";
475 if (quantizedDimensions.size() != blockSizes.size())
476 return emitError() <<
"number of quantized dimensions and block sizes ("
477 << scales.
size() <<
" vs " << zeroPoints.
size()
478 <<
") does not match";
481 for (
auto quantizedDimension : quantizedDimensions) {
482 if (quantizedDimension < 0)
483 return emitError() <<
"illegal quantized dimension: "
484 << quantizedDimension;
488 for (
auto blockSize : blockSizes) {
490 return emitError() <<
"illegal block size: " << blockSize;
501 return getImpl()->getZeroPoints();
506 return getImpl()->getQuantizedDimensions();
510 return getImpl()->getBlockSizes();
518 for (
auto [dim, size] :
520 result.push_back({dim, size});
544 if (!llvm::isa<FloatType>(expressedType))
545 return emitError() <<
"expressed type must be floating point";
547 return emitError() <<
"illegal min and max: (" <<
min <<
":" <<
max <<
")";
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
An attribute that represents a reference to a dense vector or tensor object.
int64_t size() const
Returns the number of elements held by this attribute.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class represents a diagnostic that is inflight and set to be reported.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
static ConcreteType get(MLIRContext *ctx, Args &&...args)
static ConcreteType getChecked(const Location &loc, Args &&...args)
ImplType * getImpl() const
A quantized type that maps storage to/from expressed types in an unspecified way.
static AnyQuantizedType get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
static AnyQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
A quantized type that infers its range from given min/max values.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Verifies construction invariants and issues errors/warnings.
static CalibratedQuantizedType get(Type expressedType, double min, double max)
Gets an instance of the type with all parameters specified but not checked.
static CalibratedQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Gets an instance of the type with all specified parameters checked.
Base class for all quantized types known to this dialect.
Type getExpressedType() const
Gets the original expressed type that this quantized type approximates.
static constexpr unsigned MaxStorageBits
The maximum number of bits supported for storage types.
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
static Type castToStorageType(Type quantizedType)
Casts from a type based on a QuantizedType to a corresponding type based on the storageType (returns ...
Type castExpressedToStorageType(Type candidateType)
Casts from a type based on the expressedType to the equivalent type based on storageType by way of th...
detail::QuantizedTypeStorage ImplType
static Type castToExpressedType(Type quantizedType)
Casts from a type based on QuantizedType to a corresponding type based on the expressedType (returns ...
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)
Returns the element type as a QuantizedType or nullptr if it is not a quantized type.
unsigned getFlags() const
Gets the flags associated with this type.
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
static bool classof(Type type)
Support method to enable LLVM-style type casting.
Type castFromStorageType(Type candidateType)
Casts from a type based on the storageType to a corresponding type based on this type (returns nullpt...
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Type getStorageType() const
Gets the underlying type used for to store values.
Type castFromExpressedType(Type candidateType)
Casts from a type based on the expressedType to a corresponding type based on this type (returns null...
bool isCompatibleExpressedType(Type candidateExpressedType)
Returns whether the candidateExpressedType is a match for this QuantizedType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref