/// Returns, as a double, the value produced by APFloat::getSmallest for the
/// float semantics of `expressedType`. The callers visible in this listing use
/// it as the lower bound when range-checking a quantization scale.
/// `expressedType` is converted with cast<FloatType>, so it is expected to
/// already be a FloatType (callers check isa<FloatType> first).
26double getMinScale(
Type expressedType) {
27 auto floatType = cast<FloatType>(expressedType);
28 return APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
/// Returns, as a double, the value produced by APFloat::getLargest for the
/// float semantics of `expressedType`. The callers visible in this listing use
/// it as the upper bound when range-checking a quantization scale.
/// `expressedType` is converted with cast<FloatType>, so it is expected to
/// already be a FloatType (callers check isa<FloatType> first).
32double getMaxScale(
Type expressedType) {
33 auto floatType = cast<FloatType>(expressedType);
34 return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
44 return llvm::isa<QuantDialect>(type.
getDialect());
49 unsigned flags,
Type storageType,
52 if (
auto quantStorageTypeInterface =
53 llvm::dyn_cast<QuantStorageTypeInterface>(storageType)) {
54 unsigned integralWidth = quantStorageTypeInterface.getStorageWidth();
58 return emitError() <<
"illegal storage type size: " << integralWidth;
61 int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum(
isSigned);
62 int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum(
isSigned);
64 if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
65 storageTypeMax > defaultMax) {
66 return emitError() <<
"illegal storage min and storage max: ("
67 << storageTypeMin <<
":" << storageTypeMax <<
")";
73 return emitError() <<
"storage type must implement QuantStorageTypeInterface";
90 auto quantStorageTypeInterface =
91 llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
93 int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum(
isSigned());
94 int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum(
isSigned());
101 auto quantStorageTypeInterface =
102 llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
104 return quantStorageTypeInterface.getStorageWidth();
112 if (llvm::isa<ShapedType>(candidateExpressedType)) {
113 return llvm::cast<ShapedType>(candidateExpressedType).getElementType() ==
121 if (llvm::isa<ShapedType>(primitiveOrContainerType)) {
123 llvm::cast<ShapedType>(primitiveOrContainerType).getElementType();
124 return llvm::dyn_cast<QuantizedType>(elementType);
126 return llvm::dyn_cast<QuantizedType>(primitiveOrContainerType);
134 if (llvm::isa<RankedTensorType>(candidateType)) {
136 return RankedTensorType::get(
137 llvm::cast<RankedTensorType>(candidateType).
getShape(),
140 if (llvm::isa<UnrankedTensorType>(candidateType)) {
144 if (llvm::isa<VectorType>(candidateType)) {
146 return VectorType::get(llvm::cast<VectorType>(candidateType).
getShape(),
154 if (llvm::isa<QuantizedType>(quantizedType)) {
156 return llvm::cast<QuantizedType>(quantizedType).getStorageType();
158 if (llvm::isa<ShapedType>(quantizedType)) {
160 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
161 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
165 llvm::cast<QuantizedType>(sType.getElementType()).getStorageType();
166 if (llvm::isa<RankedTensorType>(quantizedType)) {
167 return RankedTensorType::get(sType.getShape(), storageType);
169 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
170 return UnrankedTensorType::get(storageType);
172 if (llvm::isa<VectorType>(quantizedType)) {
173 return VectorType::get(sType.getShape(), storageType);
185 if (llvm::isa<ShapedType>(candidateType)) {
186 ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType);
191 if (llvm::isa<RankedTensorType>(candidateType)) {
193 return RankedTensorType::get(candidateShapedType.getShape(), *
this);
195 if (llvm::isa<UnrankedTensorType>(candidateType)) {
197 return UnrankedTensorType::get(*
this);
199 if (llvm::isa<VectorType>(candidateType)) {
201 return VectorType::get(candidateShapedType.getShape(), *
this);
209 if (llvm::isa<QuantizedType>(quantizedType)) {
211 return llvm::cast<QuantizedType>(quantizedType).getExpressedType();
213 if (llvm::isa<ShapedType>(quantizedType)) {
215 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
216 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
220 llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType();
221 if (llvm::isa<RankedTensorType>(quantizedType)) {
222 return RankedTensorType::get(sType.getShape(), expressedType);
224 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
225 return UnrankedTensorType::get(expressedType);
227 if (llvm::isa<VectorType>(quantizedType)) {
228 return VectorType::get(sType.getShape(), expressedType);
237 if (!expressedQuantizedType) {
248 storageTypeMin, storageTypeMax);
253 unsigned flags,
Type storageType,
257 storageType, expressedType, storageTypeMin,
263 unsigned flags,
Type storageType,
267 expressedType, storageTypeMin,
275 if (expressedType && !llvm::isa<FloatType>(expressedType))
276 return emitError() <<
"expressed type must be floating point";
282 Type expressedType,
double scale,
287 scale, zeroPoint, storageTypeMin, storageTypeMax);
295 storageType, expressedType, scale, zeroPoint,
296 storageTypeMin, storageTypeMax);
304 expressedType, storageTypeMin,
312 return emitError() <<
"uniform quantization requires expressed type";
317 if (!llvm::isa<FloatType>(expressedType))
318 return emitError() <<
"expressed type must be floating point";
321 double minScale = getMinScale(expressedType);
322 double maxScale = getMaxScale(expressedType);
323 if (scale < minScale || scale > maxScale)
324 return emitError() <<
"scale out of expressed type range [" << minScale
325 <<
", " << maxScale <<
"]";
337 unsigned flags,
Type storageType,
Type expressedType,
339 int32_t quantizedDimension,
int64_t storageTypeMin,
342 scales, zeroPoints, quantizedDimension, storageTypeMin,
352 storageType, expressedType, scales, zeroPoints,
353 quantizedDimension, storageTypeMin, storageTypeMax);
362 expressedType, storageTypeMin,
370 return emitError() <<
"uniform quantization requires expressed type";
375 if (!llvm::isa<FloatType>(expressedType))
376 return emitError() <<
"expressed type must be floating point";
379 if (scales.size() != zeroPoints.size())
380 return emitError() <<
"illegal number of scales and zeroPoints: "
381 << scales.size() <<
", " << zeroPoints.size();
384 double minScale = getMinScale(expressedType);
385 double maxScale = getMaxScale(expressedType);
386 for (
double scale : scales) {
387 if (scale < minScale || scale > maxScale)
388 return emitError() <<
"scale out of expressed type range [" << minScale
389 <<
", " << maxScale <<
"]";
393 if (quantizedDimension < 0)
394 return emitError() <<
"illegal quantized dimension: " << quantizedDimension;
404 return getImpl()->getZeroPoints();
408 return getImpl()->quantizedDimension;
412 unsigned flags,
Type storageType,
Type expressedType,
417 scales, zeroPoints, quantizedDimensions, blockSizes,
418 storageTypeMin, storageTypeMax);
428 storageType, expressedType, scales, zeroPoints,
429 quantizedDimensions, blockSizes, storageTypeMin,
440 expressedType, storageTypeMin,
448 return emitError() <<
"uniform quantization requires expressed type";
453 if (!llvm::isa<FloatType>(expressedType))
454 return emitError() <<
"expressed type must be floating point";
457 if (scales.
getType().getElementType() != expressedType) {
458 return emitError() <<
"type of scale values "
459 << scales.
getType().getElementType()
460 <<
" must match the expressed type " << expressedType;
464 if (zeroPoints.
getType().getElementType() != storageType) {
465 return emitError() <<
"type of zero point values "
466 << zeroPoints.
getType().getElementType()
467 <<
" must match the storage type " << storageType;
471 if (scales.
getType().getShape() != zeroPoints.
getType().getShape())
472 return emitError() <<
"shape of scales and zeroPoints ("
473 << scales.
getType().getShape() <<
" vs "
474 << zeroPoints.
getType().getShape() <<
") does not match";
477 if (quantizedDimensions.size() != blockSizes.size())
478 return emitError() <<
"number of quantized dimensions and block sizes ("
479 << scales.
size() <<
" vs " << zeroPoints.
size()
480 <<
") does not match";
483 for (
auto quantizedDimension : quantizedDimensions) {
484 if (quantizedDimension < 0)
485 return emitError() <<
"illegal quantized dimension: "
486 << quantizedDimension;
490 for (
auto blockSize : blockSizes) {
492 return emitError() <<
"illegal block size: " << blockSize;
503 return getImpl()->getZeroPoints();
508 return getImpl()->getQuantizedDimensions();
512 return getImpl()->getBlockSizes();
520 for (
auto [dim, size] :
522 result.push_back({dim, size});
546 if (!llvm::isa<FloatType>(expressedType))
547 return emitError() <<
"expressed type must be floating point";
549 return emitError() <<
"illegal min and max: (" <<
min <<
":" <<
max <<
")";
561 std::optional<int64_t> storageMin,
562 std::optional<int64_t> storageMax) {
563 return Base::get(ctx, storageType, quantileType, quantiles, storageMin,
570 std::optional<int64_t> storageMin, std::optional<int64_t> storageMax) {
572 storageMin, storageMax);
578 std::optional<int64_t> storageMin, std::optional<int64_t> storageMax) {
580 return emitError() <<
"storage type must be an integer or float type";
581 if (!llvm::isa<mlir::FloatType>(quantileType))
582 return emitError() <<
"quantile type must be a float type";
583 if (quantiles.empty())
584 return emitError() <<
"quantile values must not be empty";
585 if (storageMin.has_value() != storageMax.has_value())
587 <<
"storage min and max must both be specified or both omitted";
588 if (storageMin && storageMax && *storageMin >= *storageMax)
589 return emitError() <<
"storage min must be less than storage max";
592 bool isSigned = !llvm::isa<mlir::IntegerType>(storageType) ||
593 llvm::cast<mlir::IntegerType>(storageType).isSigned();
595 storageMin.value_or(isSigned ? -(1LL << (width - 1)) : 0LL);
596 auto effectiveMax = storageMax.value_or(isSigned ? (1LL << (width - 1)) - 1
597 : (1LL << width) - 1);
598 auto expectedSize = effectiveMax - effectiveMin + 1;
599 if (
static_cast<decltype(expectedSize)
>(quantiles.size()) != expectedSize)
600 return emitError() <<
"quantile LUT size (" << quantiles.size()
601 <<
") must equal the number of representable storage "
603 << expectedSize <<
")";
605 for (
double v : quantiles)
606 if (std::isnan(v) || std::isinf(v))
608 <<
"quantile values must be finite (no NaN or infinity)";
638 if (
auto intType = mlir::dyn_cast<mlir::IntegerType>(
getStorageType()))
639 return intType.isSigned();
665 std::string
result =
"!quant.quantile<";
666 llvm::raw_string_ostream os(
result);
670 llvm::seq<size_t>(0, quantiles.size()), os,
671 [&](
size_t index) { os << quantiles[index]; },
",");
675 os <<
", <" << *minVal <<
":" << *maxVal <<
">";
687 return width > 0 ? 8 / width : 0;
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
An attribute that represents a reference to a dense vector or tensor object.
int64_t size() const
Returns the number of elements held by this attribute.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
static TypeID get()
Construct a type info object for the given type T.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
TypeID getTypeID()
Return a unique identifier for the concrete type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
static ConcreteType get(MLIRContext *ctx, Args &&...args)
static ConcreteType getChecked(const Location &loc, Args &&...args)
ImplType * getImpl() const
A quantized type that maps storage to/from expressed types in an unspecified way.
static AnyQuantizedType get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
static AnyQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
A quantized type that infers its range from given min/max values.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Verifies construction invariants and issues errors/warnings.
static CalibratedQuantizedType get(Type expressedType, double min, double max)
Gets an instance of the type with all parameters specified but not checked.
static CalibratedQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Gets an instance of the type with all specified parameters checked.
unsigned getLogicalBitWidth() const
static QuantileType getChecked(function_ref< InFlightDiagnostic()> emitError, mlir::MLIRContext *ctx, Type storageType, Type quantileType, ArrayRef< double > quantiles, std::optional< int64_t > storageMin=std::nullopt, std::optional< int64_t > storageMax=std::nullopt)
static bool classof(mlir::Type type)
Methods for support type inquiry through isa, cast, and dyn_cast.
unsigned getStorageWidth() const
std::optional< int64_t > getStorageMin() const
Return the explicit storage minimum, if set.
int64_t getDefaultMinimum(bool isSigned) const
unsigned getElementsPerByte() const
std::string getStorageTypeName(bool isSigned) const
detail::QuantileTypeStorage ImplType
bool shouldDefaultToSigned() const
Type getQuantileType() const
int64_t getDefaultMaximum(bool isSigned) const
static QuantileType get(mlir::MLIRContext *ctx, Type storageType, Type quantileType, ArrayRef< double > quantiles={}, std::optional< int64_t > storageMin=std::nullopt, std::optional< int64_t > storageMax=std::nullopt)
ArrayRef< double > getQuantiles() const
Return the quantile table of this float type.
std::optional< unsigned > getPreferredAlignmentBytes() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type storageType, Type quantileType, ArrayRef< double > quantiles, std::optional< int64_t > storageMin, std::optional< int64_t > storageMax)
std::optional< int64_t > getStorageMax() const
Return the explicit storage maximum, if set.
Type getStorageType() const
Base class for all quantized types known to this dialect.
Type getExpressedType() const
Gets the original expressed type that this quantized type approximates.
static constexpr unsigned MaxStorageBits
The maximum number of bits supported for storage types.
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and maximum representable values.
static Type castToStorageType(Type quantizedType)
Casts from a type based on a QuantizedType to a corresponding type based on the storageType (returns nullptr if the cast is not valid).
Type castExpressedToStorageType(Type candidateType)
Casts from a type based on the expressedType to the equivalent type based on storageType by way of this QuantizedType.
detail::QuantizedTypeStorage ImplType
static Type castToExpressedType(Type quantizedType)
Casts from a type based on QuantizedType to a corresponding type based on the expressedType (returns nullptr if the cast is not valid).
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (false).
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)
Returns the element type as a QuantizedType or nullptr if it is not a quantized type.
unsigned getFlags() const
Gets the flags associated with this type.
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
static bool classof(Type type)
Support method to enable LLVM-style type casting.
Type castFromStorageType(Type candidateType)
Casts from a type based on the storageType to a corresponding type based on this type (returns nullptr if the cast is not valid).
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Type getStorageType() const
Gets the underlying type used to store values.
Type castFromExpressedType(Type candidateType)
Casts from a type based on the expressedType to a corresponding type based on this type (returns nullptr if the cast is not valid).
bool isCompatibleExpressedType(Type candidateExpressedType)
Returns whether the candidateExpressedType is a match for this QuantizedType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref