15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/MathExtras.h"
28 return llvm::isa<QuantizationDialect>(type.
getDialect());
33 unsigned flags,
Type storageType,
Type expressedType,
34 int64_t storageTypeMin, int64_t storageTypeMax) {
38 auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
40 return emitError() <<
"storage type must be integral";
41 unsigned integralWidth = intStorageType.getWidth();
44 if (integralWidth == 0 || integralWidth > MaxStorageBits)
45 return emitError() <<
"illegal storage type size: " << integralWidth;
50 int64_t defaultIntegerMin =
51 getDefaultMinimumForInteger(isSigned, integralWidth);
52 int64_t defaultIntegerMax =
53 getDefaultMaximumForInteger(isSigned, integralWidth);
54 if (storageTypeMax - storageTypeMin <= 0 ||
55 storageTypeMin < defaultIntegerMin ||
56 storageTypeMax > defaultIntegerMax) {
57 return emitError() <<
"illegal storage min and storage max: ("
58 << storageTypeMin <<
":" << storageTypeMax <<
")";
78 return static_cast<ImplType *
>(
impl)->storageType.getIntOrFloatBitWidth();
86 if (llvm::isa<ShapedType>(candidateExpressedType)) {
87 return llvm::cast<ShapedType>(candidateExpressedType).getElementType() ==
90 return candidateExpressedType == getExpressedType();
95 if (llvm::isa<ShapedType>(primitiveOrContainerType)) {
97 llvm::cast<ShapedType>(primitiveOrContainerType).getElementType();
98 return llvm::dyn_cast<QuantizedType>(elementType);
100 return llvm::dyn_cast<QuantizedType>(primitiveOrContainerType);
104 if (candidateType == getStorageType()) {
108 if (llvm::isa<RankedTensorType>(candidateType)) {
111 llvm::cast<RankedTensorType>(candidateType).
getShape(),
114 if (llvm::isa<UnrankedTensorType>(candidateType)) {
118 if (llvm::isa<VectorType>(candidateType)) {
128 if (llvm::isa<QuantizedType>(quantizedType)) {
130 return llvm::cast<QuantizedType>(quantizedType).getStorageType();
132 if (llvm::isa<ShapedType>(quantizedType)) {
134 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
135 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
139 llvm::cast<QuantizedType>(sType.getElementType()).getStorageType();
140 if (llvm::isa<RankedTensorType>(quantizedType)) {
143 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
146 if (llvm::isa<VectorType>(quantizedType)) {
155 if (candidateType == getExpressedType()) {
159 if (llvm::isa<ShapedType>(candidateType)) {
160 ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType);
161 if (candidateShapedType.getElementType() != getExpressedType()) {
165 if (llvm::isa<RankedTensorType>(candidateType)) {
169 if (llvm::isa<UnrankedTensorType>(candidateType)) {
173 if (llvm::isa<VectorType>(candidateType)) {
183 if (llvm::isa<QuantizedType>(quantizedType)) {
185 return llvm::cast<QuantizedType>(quantizedType).getExpressedType();
187 if (llvm::isa<ShapedType>(quantizedType)) {
189 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
190 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
194 llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType();
195 if (llvm::isa<RankedTensorType>(quantizedType)) {
198 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
201 if (llvm::isa<VectorType>(quantizedType)) {
210 Type expressedQuantizedType = castFromExpressedType(candidateType);
211 if (!expressedQuantizedType) {
219 int64_t storageTypeMin,
220 int64_t storageTypeMax) {
222 storageTypeMin, storageTypeMax);
227 unsigned flags,
Type storageType,
228 Type expressedType, int64_t storageTypeMin,
229 int64_t storageTypeMax) {
231 storageType, expressedType, storageTypeMin,
237 unsigned flags,
Type storageType,
Type expressedType,
238 int64_t storageTypeMin, int64_t storageTypeMax) {
240 storageTypeMin, storageTypeMax))) {
247 if (expressedType && !llvm::isa<FloatType>(expressedType))
248 return emitError() <<
"expressed type must be floating point";
254 Type expressedType,
double scale,
256 int64_t storageTypeMin,
257 int64_t storageTypeMax) {
259 scale, zeroPoint, storageTypeMin, storageTypeMax);
264 Type storageType,
Type expressedType,
double scale, int64_t zeroPoint,
265 int64_t storageTypeMin, int64_t storageTypeMax) {
267 storageType, expressedType, scale, zeroPoint,
268 storageTypeMin, storageTypeMax);
273 Type storageType,
Type expressedType,
double scale, int64_t zeroPoint,
274 int64_t storageTypeMin, int64_t storageTypeMax) {
276 storageTypeMin, storageTypeMax))) {
283 return emitError() <<
"uniform quantization requires expressed type";
288 if (!llvm::isa<FloatType>(expressedType))
289 return emitError() <<
"expressed type must be floating point";
292 if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
293 return emitError() <<
"illegal scale: " << scale;
301 return getImpl()->zeroPoint;
305 unsigned flags,
Type storageType,
Type expressedType,
307 int32_t quantizedDimension, int64_t storageTypeMin,
308 int64_t storageTypeMax) {
310 scales, zeroPoints, quantizedDimension, storageTypeMin,
318 int64_t storageTypeMin, int64_t storageTypeMax) {
320 storageType, expressedType, scales, zeroPoints,
321 quantizedDimension, storageTypeMin, storageTypeMax);
328 int64_t storageTypeMin, int64_t storageTypeMax) {
330 storageTypeMin, storageTypeMax))) {
337 return emitError() <<
"uniform quantization requires expressed type";
342 if (!llvm::isa<FloatType>(expressedType))
343 return emitError() <<
"expressed type must be floating point";
346 if (scales.size() != zeroPoints.size())
347 return emitError() <<
"illegal number of scales and zeroPoints: "
348 << scales.size() <<
", " << zeroPoints.size();
351 for (
double scale : scales) {
352 if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
353 return emitError() <<
"illegal scale: " << scale;
360 return getImpl()->getScales();
364 return getImpl()->getZeroPoints();
368 return getImpl()->quantizedDimension;
389 if (!llvm::isa<FloatType>(expressedType))
390 return emitError() <<
"expressed type must be floating point";
392 return emitError() <<
"illegal min and max: (" <<
min <<
":" <<
max <<
")";
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
This class represents a diagnostic that is inflight and set to be reported.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
A quantized type that maps storage to/from expressed types in an unspecified way.
static AnyQuantizedType get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
static AnyQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
A quantized type that infers its range from given min/max values.
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Verifies construction invariants and issues errors/warnings.
static CalibratedQuantizedType get(Type expressedType, double min, double max)
Gets an instance of the type with all parameters specified but not checked.
static CalibratedQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Gets an instance of the type with all specified parameters checked.
Base class for all quantized types known to this dialect.
Type getExpressedType() const
Gets the original expressed type that this quantized type approximates.
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
static Type castToStorageType(Type quantizedType)
Casts from a type based on a QuantizedType to a corresponding type based on the storageType (returns nullptr if the cast is not valid).
Type castExpressedToStorageType(Type candidateType)
Casts from a type based on the expressedType to the equivalent type based on storageType by way of this QuantizedType.
static Type castToExpressedType(Type quantizedType)
Casts from a type based on QuantizedType to a corresponding type based on the expressedType (returns nullptr if the cast is not valid).
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)
Returns the element type as a QuantizedType or nullptr if it is not a quantized type.
unsigned getFlags() const
Gets the flags associated with this type.
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
static bool classof(Type type)
Support method to enable LLVM-style type casting.
Type castFromStorageType(Type candidateType)
Casts from a type based on the storageType to a corresponding type based on this type (returns nullptr if the cast is not valid).
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Type getStorageType() const
Gets the underlying type used to store values.
Type castFromExpressedType(Type candidateType)
Casts from a type based on the expressedType to a corresponding type based on this type (returns nullptr if the cast is not valid).
bool isCompatibleExpressedType(Type candidateExpressedType)
Returns whether the candidateExpressedType is a match for this QuantizedType.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.