15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/MathExtras.h"
26 double getMinScale(
Type expressedType) {
27 auto floatType = cast<FloatType>(expressedType);
28 return APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
32 double getMaxScale(
Type expressedType) {
33 auto floatType = cast<FloatType>(expressedType);
34 return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
44 return llvm::isa<QuantDialect>(type.
getDialect());
49 unsigned flags,
Type storageType,
50 Type expressedType, int64_t storageTypeMin,
51 int64_t storageTypeMax) {
55 auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
57 return emitError() <<
"storage type must be integral";
58 unsigned integralWidth = intStorageType.getWidth();
61 if (integralWidth == 0 || integralWidth > MaxStorageBits)
62 return emitError() <<
"illegal storage type size: " << integralWidth;
67 int64_t defaultIntegerMin =
68 getDefaultMinimumForInteger(isSigned, integralWidth);
69 int64_t defaultIntegerMax =
70 getDefaultMaximumForInteger(isSigned, integralWidth);
71 if (storageTypeMax - storageTypeMin <= 0 ||
72 storageTypeMin < defaultIntegerMin ||
73 storageTypeMax > defaultIntegerMax) {
74 return emitError() <<
"illegal storage min and storage max: ("
75 << storageTypeMin <<
":" << storageTypeMax <<
")";
93 unsigned int integralWidth = getStorageTypeIntegralWidth();
94 bool isSignedInteger = isSigned();
95 int64_t defaultIntegerMin =
96 getDefaultMinimumForInteger(isSignedInteger, integralWidth);
97 int64_t defaultIntegerMax =
98 getDefaultMaximumForInteger(isSignedInteger, integralWidth);
99 return defaultIntegerMin != getStorageTypeMin() ||
100 defaultIntegerMax != getStorageTypeMax();
106 return static_cast<ImplType *
>(
impl)->storageType.getIntOrFloatBitWidth();
114 if (llvm::isa<ShapedType>(candidateExpressedType)) {
115 return llvm::cast<ShapedType>(candidateExpressedType).getElementType() ==
118 return candidateExpressedType == getExpressedType();
123 if (llvm::isa<ShapedType>(primitiveOrContainerType)) {
125 llvm::cast<ShapedType>(primitiveOrContainerType).getElementType();
126 return llvm::dyn_cast<QuantizedType>(elementType);
128 return llvm::dyn_cast<QuantizedType>(primitiveOrContainerType);
132 if (candidateType == getStorageType()) {
136 if (llvm::isa<RankedTensorType>(candidateType)) {
139 llvm::cast<RankedTensorType>(candidateType).
getShape(),
142 if (llvm::isa<UnrankedTensorType>(candidateType)) {
146 if (llvm::isa<VectorType>(candidateType)) {
156 if (llvm::isa<QuantizedType>(quantizedType)) {
158 return llvm::cast<QuantizedType>(quantizedType).getStorageType();
160 if (llvm::isa<ShapedType>(quantizedType)) {
162 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
163 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
167 llvm::cast<QuantizedType>(sType.getElementType()).getStorageType();
168 if (llvm::isa<RankedTensorType>(quantizedType)) {
171 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
174 if (llvm::isa<VectorType>(quantizedType)) {
183 if (candidateType == getExpressedType()) {
187 if (llvm::isa<ShapedType>(candidateType)) {
188 ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType);
189 if (candidateShapedType.getElementType() != getExpressedType()) {
193 if (llvm::isa<RankedTensorType>(candidateType)) {
197 if (llvm::isa<UnrankedTensorType>(candidateType)) {
201 if (llvm::isa<VectorType>(candidateType)) {
211 if (llvm::isa<QuantizedType>(quantizedType)) {
213 return llvm::cast<QuantizedType>(quantizedType).getExpressedType();
215 if (llvm::isa<ShapedType>(quantizedType)) {
217 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
218 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
222 llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType();
223 if (llvm::isa<RankedTensorType>(quantizedType)) {
226 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
229 if (llvm::isa<VectorType>(quantizedType)) {
238 Type expressedQuantizedType = castFromExpressedType(candidateType);
239 if (!expressedQuantizedType) {
247 int64_t storageTypeMin,
248 int64_t storageTypeMax) {
250 storageTypeMin, storageTypeMax);
255 unsigned flags,
Type storageType,
256 Type expressedType, int64_t storageTypeMin,
257 int64_t storageTypeMax) {
259 storageType, expressedType, storageTypeMin,
265 unsigned flags,
Type storageType,
266 Type expressedType, int64_t storageTypeMin,
267 int64_t storageTypeMax) {
269 expressedType, storageTypeMin,
277 if (expressedType && !llvm::isa<FloatType>(expressedType))
278 return emitError() <<
"expressed type must be floating point";
284 Type expressedType,
double scale,
286 int64_t storageTypeMin,
287 int64_t storageTypeMax) {
289 scale, zeroPoint, storageTypeMin, storageTypeMax);
294 Type storageType,
Type expressedType,
double scale, int64_t zeroPoint,
295 int64_t storageTypeMin, int64_t storageTypeMax) {
297 storageType, expressedType, scale, zeroPoint,
298 storageTypeMin, storageTypeMax);
303 Type storageType,
Type expressedType,
double scale, int64_t zeroPoint,
304 int64_t storageTypeMin, int64_t storageTypeMax) {
306 expressedType, storageTypeMin,
314 return emitError() <<
"uniform quantization requires expressed type";
319 if (!llvm::isa<FloatType>(expressedType))
320 return emitError() <<
"expressed type must be floating point";
323 double minScale = getMinScale(expressedType);
324 double maxScale = getMaxScale(expressedType);
325 if (scale < minScale || scale > maxScale)
326 return emitError() <<
"scale out of expressed type range [" << minScale
327 <<
", " << maxScale <<
"]";
335 return getImpl()->zeroPoint;
339 unsigned flags,
Type storageType,
Type expressedType,
341 int32_t quantizedDimension, int64_t storageTypeMin,
342 int64_t storageTypeMax) {
344 scales, zeroPoints, quantizedDimension, storageTypeMin,
352 int64_t storageTypeMin, int64_t storageTypeMax) {
354 storageType, expressedType, scales, zeroPoints,
355 quantizedDimension, storageTypeMin, storageTypeMax);
362 int64_t storageTypeMin, int64_t storageTypeMax) {
364 expressedType, storageTypeMin,
372 return emitError() <<
"uniform quantization requires expressed type";
377 if (!llvm::isa<FloatType>(expressedType))
378 return emitError() <<
"expressed type must be floating point";
381 if (scales.size() != zeroPoints.size())
382 return emitError() <<
"illegal number of scales and zeroPoints: "
383 << scales.size() <<
", " << zeroPoints.size();
386 double minScale = getMinScale(expressedType);
387 double maxScale = getMaxScale(expressedType);
388 for (
double scale : scales) {
389 if (scale < minScale || scale > maxScale)
390 return emitError() <<
"scale out of expressed type range [" << minScale
391 <<
", " << maxScale <<
"]";
395 if (quantizedDimension < 0)
396 return emitError() <<
"illegal quantized dimension: " << quantizedDimension;
402 return getImpl()->getScales();
406 return getImpl()->getZeroPoints();
410 return getImpl()->quantizedDimension;
431 if (!llvm::isa<FloatType>(expressedType))
432 return emitError() <<
"expressed type must be floating point";
434 return emitError() <<
"illegal min and max: (" <<
min <<
":" <<
max <<
")";
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
This class represents a diagnostic that is inflight and set to be reported.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
A quantized type that maps storage to/from expressed types in an unspecified way.
static AnyQuantizedType get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
static AnyQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
A quantized type that infers its range from given min/max values.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Verifies construction invariants and issues errors/warnings.
static CalibratedQuantizedType get(Type expressedType, double min, double max)
Gets an instance of the type with all parameters specified but not checked.
static CalibratedQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Gets an instance of the type with all specified parameters checked.
Base class for all quantized types known to this dialect.
Type getExpressedType() const
Gets the original expressed type that this quantized type approximates.
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and maximum representable values of its integral type.
static Type castToStorageType(Type quantizedType)
Casts from a type based on a QuantizedType to a corresponding type based on the storageType (returns a null type if the cast is not valid).
Type castExpressedToStorageType(Type candidateType)
Casts from a type based on the expressedType to the equivalent type based on storageType by way of this QuantizedType (returns a null type if the cast is not valid).
static Type castToExpressedType(Type quantizedType)
Casts from a type based on QuantizedType to a corresponding type based on the expressedType (returns a null type if the cast is not valid).
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)
Returns the element type as a QuantizedType or nullptr if it is not a quantized type.
unsigned getFlags() const
Gets the flags associated with this type.
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
static bool classof(Type type)
Support method to enable LLVM-style type casting.
Type castFromStorageType(Type candidateType)
Casts from a type based on the storageType to a corresponding type based on this type (returns nullptr if the cast is not valid).
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Type getStorageType() const
Gets the underlying type used to store values.
Type castFromExpressedType(Type candidateType)
Casts from a type based on the expressedType to a corresponding type based on this type (returns nullptr if the cast is not valid).
bool isCompatibleExpressedType(Type candidateExpressedType)
Returns whether the candidateExpressedType is a match for this QuantizedType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.