MLIR  19.0.0git
QuantTypes.cpp
Go to the documentation of this file.
1 //===- QuantTypes.cpp - Quantization Type Implementation --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "TypeDetail.h"
12 
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/MathExtras.h"
18 
19 using namespace mlir;
20 using namespace mlir::quant;
21 using namespace mlir::quant::detail;
22 
23 unsigned QuantizedType::getFlags() const {
24  return static_cast<ImplType *>(impl)->flags;
25 }
26 
28  return llvm::isa<QuantizationDialect>(type.getDialect());
29 }
30 
31 LogicalResult
33  unsigned flags, Type storageType, Type expressedType,
34  int64_t storageTypeMin, int64_t storageTypeMax) {
35  // Verify that the storage type is integral.
36  // This restriction may be lifted at some point in favor of using bf16
37  // or f16 as exact representations on hardware where that is advantageous.
38  auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
39  if (!intStorageType)
40  return emitError() << "storage type must be integral";
41  unsigned integralWidth = intStorageType.getWidth();
42 
43  // Verify storage width.
44  if (integralWidth == 0 || integralWidth > MaxStorageBits)
45  return emitError() << "illegal storage type size: " << integralWidth;
46 
47  // Verify storageTypeMin and storageTypeMax.
48  bool isSigned =
50  int64_t defaultIntegerMin =
51  getDefaultMinimumForInteger(isSigned, integralWidth);
52  int64_t defaultIntegerMax =
53  getDefaultMaximumForInteger(isSigned, integralWidth);
54  if (storageTypeMax - storageTypeMin <= 0 ||
55  storageTypeMin < defaultIntegerMin ||
56  storageTypeMax > defaultIntegerMax) {
57  return emitError() << "illegal storage min and storage max: ("
58  << storageTypeMin << ":" << storageTypeMax << ")";
59  }
60  return success();
61 }
62 
64  return static_cast<ImplType *>(impl)->storageType;
65 }
66 
68  return static_cast<ImplType *>(impl)->storageTypeMin;
69 }
70 
72  return static_cast<ImplType *>(impl)->storageTypeMax;
73 }
74 
76  // NOTE: If ever supporting non-integral storage types, some other scheme
77  // for determining the width will be needed.
78  return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
79 }
80 
82  return static_cast<ImplType *>(impl)->expressedType;
83 }
84 
85 bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
86  if (llvm::isa<ShapedType>(candidateExpressedType)) {
87  return llvm::cast<ShapedType>(candidateExpressedType).getElementType() ==
88  getExpressedType();
89  }
90  return candidateExpressedType == getExpressedType();
91 }
92 
94 QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
95  if (llvm::isa<ShapedType>(primitiveOrContainerType)) {
96  Type elementType =
97  llvm::cast<ShapedType>(primitiveOrContainerType).getElementType();
98  return llvm::dyn_cast<QuantizedType>(elementType);
99  }
100  return llvm::dyn_cast<QuantizedType>(primitiveOrContainerType);
101 }
102 
104  if (candidateType == getStorageType()) {
105  // i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
106  return *this;
107  }
108  if (llvm::isa<RankedTensorType>(candidateType)) {
109  // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
110  return RankedTensorType::get(
111  llvm::cast<RankedTensorType>(candidateType).getShape(),
112  getStorageType());
113  }
114  if (llvm::isa<UnrankedTensorType>(candidateType)) {
115  // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
116  return UnrankedTensorType::get(getStorageType());
117  }
118  if (llvm::isa<VectorType>(candidateType)) {
119  // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
120  return VectorType::get(llvm::cast<VectorType>(candidateType).getShape(),
121  getStorageType());
122  }
123 
124  return nullptr;
125 }
126 
128  if (llvm::isa<QuantizedType>(quantizedType)) {
129  // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
130  return llvm::cast<QuantizedType>(quantizedType).getStorageType();
131  }
132  if (llvm::isa<ShapedType>(quantizedType)) {
133  // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
134  ShapedType sType = llvm::cast<ShapedType>(quantizedType);
135  if (!llvm::isa<QuantizedType>(sType.getElementType())) {
136  return nullptr;
137  }
138  Type storageType =
139  llvm::cast<QuantizedType>(sType.getElementType()).getStorageType();
140  if (llvm::isa<RankedTensorType>(quantizedType)) {
141  return RankedTensorType::get(sType.getShape(), storageType);
142  }
143  if (llvm::isa<UnrankedTensorType>(quantizedType)) {
144  return UnrankedTensorType::get(storageType);
145  }
146  if (llvm::isa<VectorType>(quantizedType)) {
147  return VectorType::get(sType.getShape(), storageType);
148  }
149  }
150 
151  return nullptr;
152 }
153 
155  if (candidateType == getExpressedType()) {
156  // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
157  return *this;
158  }
159  if (llvm::isa<ShapedType>(candidateType)) {
160  ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType);
161  if (candidateShapedType.getElementType() != getExpressedType()) {
162  return nullptr;
163  }
164 
165  if (llvm::isa<RankedTensorType>(candidateType)) {
166  // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
167  return RankedTensorType::get(candidateShapedType.getShape(), *this);
168  }
169  if (llvm::isa<UnrankedTensorType>(candidateType)) {
170  // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
171  return UnrankedTensorType::get(*this);
172  }
173  if (llvm::isa<VectorType>(candidateType)) {
174  // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
175  return VectorType::get(candidateShapedType.getShape(), *this);
176  }
177  }
178 
179  return nullptr;
180 }
181 
183  if (llvm::isa<QuantizedType>(quantizedType)) {
184  // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
185  return llvm::cast<QuantizedType>(quantizedType).getExpressedType();
186  }
187  if (llvm::isa<ShapedType>(quantizedType)) {
188  // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
189  ShapedType sType = llvm::cast<ShapedType>(quantizedType);
190  if (!llvm::isa<QuantizedType>(sType.getElementType())) {
191  return nullptr;
192  }
193  Type expressedType =
194  llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType();
195  if (llvm::isa<RankedTensorType>(quantizedType)) {
196  return RankedTensorType::get(sType.getShape(), expressedType);
197  }
198  if (llvm::isa<UnrankedTensorType>(quantizedType)) {
199  return UnrankedTensorType::get(expressedType);
200  }
201  if (llvm::isa<VectorType>(quantizedType)) {
202  return VectorType::get(sType.getShape(), expressedType);
203  }
204  }
205 
206  return nullptr;
207 }
208 
210  Type expressedQuantizedType = castFromExpressedType(candidateType);
211  if (!expressedQuantizedType) {
212  return nullptr;
213  }
214  return QuantizedType::castToStorageType(expressedQuantizedType);
215 }
216 
217 AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
218  Type expressedType,
219  int64_t storageTypeMin,
220  int64_t storageTypeMax) {
221  return Base::get(storageType.getContext(), flags, storageType, expressedType,
222  storageTypeMin, storageTypeMax);
223 }
224 
227  unsigned flags, Type storageType,
228  Type expressedType, int64_t storageTypeMin,
229  int64_t storageTypeMax) {
230  return Base::getChecked(emitError, storageType.getContext(), flags,
231  storageType, expressedType, storageTypeMin,
232  storageTypeMax);
233 }
234 
235 LogicalResult
237  unsigned flags, Type storageType, Type expressedType,
238  int64_t storageTypeMin, int64_t storageTypeMax) {
239  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
240  storageTypeMin, storageTypeMax))) {
241  return failure();
242  }
243 
244  // Verify that the expressed type is floating point.
245  // If this restriction is ever eliminated, the parser/printer must be
246  // extended.
247  if (expressedType && !llvm::isa<FloatType>(expressedType))
248  return emitError() << "expressed type must be floating point";
249 
250  return success();
251 }
252 
254  Type expressedType, double scale,
255  int64_t zeroPoint,
256  int64_t storageTypeMin,
257  int64_t storageTypeMax) {
258  return Base::get(storageType.getContext(), flags, storageType, expressedType,
259  scale, zeroPoint, storageTypeMin, storageTypeMax);
260 }
261 
263  function_ref<InFlightDiagnostic()> emitError, unsigned flags,
264  Type storageType, Type expressedType, double scale, int64_t zeroPoint,
265  int64_t storageTypeMin, int64_t storageTypeMax) {
266  return Base::getChecked(emitError, storageType.getContext(), flags,
267  storageType, expressedType, scale, zeroPoint,
268  storageTypeMin, storageTypeMax);
269 }
270 
272  function_ref<InFlightDiagnostic()> emitError, unsigned flags,
273  Type storageType, Type expressedType, double scale, int64_t zeroPoint,
274  int64_t storageTypeMin, int64_t storageTypeMax) {
275  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
276  storageTypeMin, storageTypeMax))) {
277  return failure();
278  }
279 
280  // Uniform quantization requires fully expressed parameters, including
281  // expressed type.
282  if (!expressedType)
283  return emitError() << "uniform quantization requires expressed type";
284 
285  // Verify that the expressed type is floating point.
286  // If this restriction is ever eliminated, the parser/printer must be
287  // extended.
288  if (!llvm::isa<FloatType>(expressedType))
289  return emitError() << "expressed type must be floating point";
290 
291  // Verify scale.
292  if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
293  return emitError() << "illegal scale: " << scale;
294 
295  return success();
296 }
297 
298 double UniformQuantizedType::getScale() const { return getImpl()->scale; }
299 
301  return getImpl()->zeroPoint;
302 }
303 
305  unsigned flags, Type storageType, Type expressedType,
306  ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
307  int32_t quantizedDimension, int64_t storageTypeMin,
308  int64_t storageTypeMax) {
309  return Base::get(storageType.getContext(), flags, storageType, expressedType,
310  scales, zeroPoints, quantizedDimension, storageTypeMin,
311  storageTypeMax);
312 }
313 
315  function_ref<InFlightDiagnostic()> emitError, unsigned flags,
316  Type storageType, Type expressedType, ArrayRef<double> scales,
317  ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
318  int64_t storageTypeMin, int64_t storageTypeMax) {
319  return Base::getChecked(emitError, storageType.getContext(), flags,
320  storageType, expressedType, scales, zeroPoints,
321  quantizedDimension, storageTypeMin, storageTypeMax);
322 }
323 
325  function_ref<InFlightDiagnostic()> emitError, unsigned flags,
326  Type storageType, Type expressedType, ArrayRef<double> scales,
327  ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
328  int64_t storageTypeMin, int64_t storageTypeMax) {
329  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
330  storageTypeMin, storageTypeMax))) {
331  return failure();
332  }
333 
334  // Uniform quantization requires fully expressed parameters, including
335  // expressed type.
336  if (!expressedType)
337  return emitError() << "uniform quantization requires expressed type";
338 
339  // Verify that the expressed type is floating point.
340  // If this restriction is ever eliminated, the parser/printer must be
341  // extended.
342  if (!llvm::isa<FloatType>(expressedType))
343  return emitError() << "expressed type must be floating point";
344 
345  // Ensure that the number of scales and zeroPoints match.
346  if (scales.size() != zeroPoints.size())
347  return emitError() << "illegal number of scales and zeroPoints: "
348  << scales.size() << ", " << zeroPoints.size();
349 
350  // Verify scale.
351  for (double scale : scales) {
352  if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
353  return emitError() << "illegal scale: " << scale;
354  }
355 
356  return success();
357 }
358 
360  return getImpl()->getScales();
361 }
362 
364  return getImpl()->getZeroPoints();
365 }
366 
368  return getImpl()->quantizedDimension;
369 }
370 
372  double min, double max) {
373  return Base::get(expressedType.getContext(), expressedType, min, max);
374 }
375 
377  function_ref<InFlightDiagnostic()> emitError, Type expressedType,
378  double min, double max) {
379  return Base::getChecked(emitError, expressedType.getContext(), expressedType,
380  min, max);
381 }
382 
383 LogicalResult
385  Type expressedType, double min, double max) {
386  // Verify that the expressed type is floating point.
387  // If this restriction is ever eliminated, the parser/printer must be
388  // extended.
389  if (!llvm::isa<FloatType>(expressedType))
390  return emitError() << "expressed type must be floating point";
391  if (max <= min)
392  return emitError() << "illegal min and max: (" << min << ":" << max << ")";
393 
394  return success();
395 }
396 
397 double CalibratedQuantizedType::getMin() const { return getImpl()->min; }
398 
399 double CalibratedQuantizedType::getMax() const { return getImpl()->max; }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:307
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:123
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition: QuantTypes.h:196
static AnyQuantizedType get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:217
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:236
static AnyQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:226
A quantized type that infers its range from given min/max values.
Definition: QuantTypes.h:387
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:384
static CalibratedQuantizedType get(Type expressedType, double min, double max)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:371
static CalibratedQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:376
Base class for all quantized types known to this dialect.
Definition: QuantTypes.h:49
Type getExpressedType() const
Gets the original expressed type that this quantized type approximates.
Definition: QuantTypes.cpp:81
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Definition: QuantTypes.cpp:32
static Type castToStorageType(Type quantizedType)
Casts from a type based on a QuantizedType to a corresponding type based on the storageType (returns nullptr if the cast is invalid).
Definition: QuantTypes.cpp:127
Type castExpressedToStorageType(Type candidateType)
Casts from a type based on the expressedType to the equivalent type based on storageType by way of this QuantizedType.
Definition: QuantTypes.cpp:209
static Type castToExpressedType(Type quantizedType)
Casts from a type based on QuantizedType to a corresponding type based on the expressedType (returns nullptr if the cast is invalid).
Definition: QuantTypes.cpp:182
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)
Returns the element type as a QuantizedType or nullptr if it is not a quantized type.
Definition: QuantTypes.cpp:94
unsigned getFlags() const
Gets the flags associated with this type.
Definition: QuantTypes.cpp:23
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
Definition: QuantTypes.cpp:71
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
Definition: QuantTypes.cpp:75
static bool classof(Type type)
Support method to enable LLVM-style type casting.
Definition: QuantTypes.cpp:27
Type castFromStorageType(Type candidateType)
Casts from a type based on the storageType to a corresponding type based on this type (returns nullptr if the cast is invalid).
Definition: QuantTypes.cpp:103
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Definition: QuantTypes.cpp:67
Type getStorageType() const
Gets the underlying type used to store values.
Definition: QuantTypes.cpp:63
Type castFromExpressedType(Type candidateType)
Casts from a type based on the expressedType to a corresponding type based on this type (returns nullptr if the cast is invalid).
Definition: QuantTypes.cpp:154
bool isCompatibleExpressedType(Type candidateExpressedType)
Returns whether the candidateExpressedType is a match for this QuantizedType.
Definition: QuantTypes.cpp:85
Represents per-axis (also known as per-channel quantization).
Definition: QuantTypes.h:317
static UniformQuantizedPerAxisType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:314
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:324
static UniformQuantizedPerAxisType get(unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:304
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to.
Definition: QuantTypes.cpp:367
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:363
ArrayRef< double > getScales() const
Gets the quantization scales.
Definition: QuantTypes.cpp:359
Represents a family of uniform, quantized types.
Definition: QuantTypes.h:257
double getScale() const
Gets the scale term.
Definition: QuantTypes.cpp:298
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:300
static UniformQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:262
static UniformQuantizedType get(unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:253
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:271
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.