MLIR  16.0.0git
QuantTypes.cpp
Go to the documentation of this file.
1 //===- QuantTypes.cpp - Quantization Type Implementation -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "TypeDetail.h"
12 
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/MathExtras.h"
18 
19 using namespace mlir;
20 using namespace mlir::quant;
21 using namespace mlir::quant::detail;
22 
23 unsigned QuantizedType::getFlags() const {
24  return static_cast<ImplType *>(impl)->flags;
25 }
26 
28  return llvm::isa<QuantizationDialect>(type.getDialect());
29 }
30 
33  unsigned flags, Type storageType, Type expressedType,
34  int64_t storageTypeMin, int64_t storageTypeMax) {
35  // Verify that the storage type is integral.
36  // This restriction may be lifted at some point in favor of using bf16
37  // or f16 as exact representations on hardware where that is advantageous.
38  auto intStorageType = storageType.dyn_cast<IntegerType>();
39  if (!intStorageType)
40  return emitError() << "storage type must be integral";
41  unsigned integralWidth = intStorageType.getWidth();
42 
43  // Verify storage width.
44  if (integralWidth == 0 || integralWidth > MaxStorageBits)
45  return emitError() << "illegal storage type size: " << integralWidth;
46 
47  // Verify storageTypeMin and storageTypeMax.
48  bool isSigned =
50  int64_t defaultIntegerMin =
51  getDefaultMinimumForInteger(isSigned, integralWidth);
52  int64_t defaultIntegerMax =
53  getDefaultMaximumForInteger(isSigned, integralWidth);
54  if (storageTypeMax - storageTypeMin <= 0 ||
55  storageTypeMin < defaultIntegerMin ||
56  storageTypeMax > defaultIntegerMax) {
57  return emitError() << "illegal storage min and storage max: ("
58  << storageTypeMin << ":" << storageTypeMax << ")";
59  }
60  return success();
61 }
62 
64  return static_cast<ImplType *>(impl)->storageType;
65 }
66 
68  return static_cast<ImplType *>(impl)->storageTypeMin;
69 }
70 
72  return static_cast<ImplType *>(impl)->storageTypeMax;
73 }
74 
76  // NOTE: If ever supporting non-integral storage types, some other scheme
77  // for determining the width will be needed.
78  return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
79 }
80 
82  return static_cast<ImplType *>(impl)->expressedType;
83 }
84 
85 bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
86  if (candidateExpressedType.isa<ShapedType>()) {
87  return candidateExpressedType.cast<ShapedType>().getElementType() ==
88  getExpressedType();
89  }
90  return candidateExpressedType == getExpressedType();
91 }
92 
94 QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
95  if (primitiveOrContainerType.isa<ShapedType>()) {
96  Type elementType =
97  primitiveOrContainerType.cast<ShapedType>().getElementType();
98  return elementType.dyn_cast<QuantizedType>();
99  }
100  return primitiveOrContainerType.dyn_cast<QuantizedType>();
101 }
102 
104  if (candidateType == getStorageType()) {
105  // i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
106  return *this;
107  }
108  if (candidateType.isa<RankedTensorType>()) {
109  // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
110  return RankedTensorType::get(
111  candidateType.cast<RankedTensorType>().getShape(), getStorageType());
112  }
113  if (candidateType.isa<UnrankedTensorType>()) {
114  // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
115  return UnrankedTensorType::get(getStorageType());
116  }
117  if (candidateType.isa<VectorType>()) {
118  // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
119  return VectorType::get(candidateType.cast<VectorType>().getShape(),
120  getStorageType());
121  }
122 
123  return nullptr;
124 }
125 
127  if (quantizedType.isa<QuantizedType>()) {
128  // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
129  return quantizedType.cast<QuantizedType>().getStorageType();
130  }
131  if (quantizedType.isa<ShapedType>()) {
132  // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
133  ShapedType sType = quantizedType.cast<ShapedType>();
134  if (!sType.getElementType().isa<QuantizedType>()) {
135  return nullptr;
136  }
137  Type storageType =
138  sType.getElementType().cast<QuantizedType>().getStorageType();
139  if (quantizedType.isa<RankedTensorType>()) {
140  return RankedTensorType::get(sType.getShape(), storageType);
141  }
142  if (quantizedType.isa<UnrankedTensorType>()) {
143  return UnrankedTensorType::get(storageType);
144  }
145  if (quantizedType.isa<VectorType>()) {
146  return VectorType::get(sType.getShape(), storageType);
147  }
148  }
149 
150  return nullptr;
151 }
152 
154  if (candidateType == getExpressedType()) {
155  // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
156  return *this;
157  }
158  if (candidateType.isa<ShapedType>()) {
159  ShapedType candidateShapedType = candidateType.cast<ShapedType>();
160  if (candidateShapedType.getElementType() != getExpressedType()) {
161  return nullptr;
162  }
163 
164  if (candidateType.isa<RankedTensorType>()) {
165  // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
166  return RankedTensorType::get(candidateShapedType.getShape(), *this);
167  }
168  if (candidateType.isa<UnrankedTensorType>()) {
169  // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
170  return UnrankedTensorType::get(*this);
171  }
172  if (candidateType.isa<VectorType>()) {
173  // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
174  return VectorType::get(candidateShapedType.getShape(), *this);
175  }
176  }
177 
178  return nullptr;
179 }
180 
182  if (quantizedType.isa<QuantizedType>()) {
183  // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
184  return quantizedType.cast<QuantizedType>().getExpressedType();
185  }
186  if (quantizedType.isa<ShapedType>()) {
187  // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
188  ShapedType sType = quantizedType.cast<ShapedType>();
189  if (!sType.getElementType().isa<QuantizedType>()) {
190  return nullptr;
191  }
192  Type expressedType =
193  sType.getElementType().cast<QuantizedType>().getExpressedType();
194  if (quantizedType.isa<RankedTensorType>()) {
195  return RankedTensorType::get(sType.getShape(), expressedType);
196  }
197  if (quantizedType.isa<UnrankedTensorType>()) {
198  return UnrankedTensorType::get(expressedType);
199  }
200  if (quantizedType.isa<VectorType>()) {
201  return VectorType::get(sType.getShape(), expressedType);
202  }
203  }
204 
205  return nullptr;
206 }
207 
209  Type expressedQuantizedType = castFromExpressedType(candidateType);
210  if (!expressedQuantizedType) {
211  return nullptr;
212  }
213  return QuantizedType::castToStorageType(expressedQuantizedType);
214 }
215 
216 AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
217  Type expressedType,
218  int64_t storageTypeMin,
219  int64_t storageTypeMax) {
220  return Base::get(storageType.getContext(), flags, storageType, expressedType,
221  storageTypeMin, storageTypeMax);
222 }
223 
226  unsigned flags, Type storageType,
227  Type expressedType, int64_t storageTypeMin,
228  int64_t storageTypeMax) {
229  return Base::getChecked(emitError, storageType.getContext(), flags,
230  storageType, expressedType, storageTypeMin,
231  storageTypeMax);
232 }
233 
236  unsigned flags, Type storageType, Type expressedType,
237  int64_t storageTypeMin, int64_t storageTypeMax) {
238  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
239  storageTypeMin, storageTypeMax))) {
240  return failure();
241  }
242 
243  // Verify that the expressed type is floating point.
244  // If this restriction is ever eliminated, the parser/printer must be
245  // extended.
246  if (expressedType && !expressedType.isa<FloatType>())
247  return emitError() << "expressed type must be floating point";
248 
249  return success();
250 }
251 
253  Type expressedType, double scale,
254  int64_t zeroPoint,
255  int64_t storageTypeMin,
256  int64_t storageTypeMax) {
257  return Base::get(storageType.getContext(), flags, storageType, expressedType,
258  scale, zeroPoint, storageTypeMin, storageTypeMax);
259 }
260 
262  function_ref<InFlightDiagnostic()> emitError, unsigned flags,
263  Type storageType, Type expressedType, double scale, int64_t zeroPoint,
264  int64_t storageTypeMin, int64_t storageTypeMax) {
265  return Base::getChecked(emitError, storageType.getContext(), flags,
266  storageType, expressedType, scale, zeroPoint,
267  storageTypeMin, storageTypeMax);
268 }
269 
271  function_ref<InFlightDiagnostic()> emitError, unsigned flags,
272  Type storageType, Type expressedType, double scale, int64_t zeroPoint,
273  int64_t storageTypeMin, int64_t storageTypeMax) {
274  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
275  storageTypeMin, storageTypeMax))) {
276  return failure();
277  }
278 
279  // Uniform quantization requires fully expressed parameters, including
280  // expressed type.
281  if (!expressedType)
282  return emitError() << "uniform quantization requires expressed type";
283 
284  // Verify that the expressed type is floating point.
285  // If this restriction is ever eliminated, the parser/printer must be
286  // extended.
287  if (!expressedType.isa<FloatType>())
288  return emitError() << "expressed type must be floating point";
289 
290  // Verify scale.
291  if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
292  return emitError() << "illegal scale: " << scale;
293 
294  return success();
295 }
296 
297 double UniformQuantizedType::getScale() const { return getImpl()->scale; }
298 
300  return getImpl()->zeroPoint;
301 }
302 
304  unsigned flags, Type storageType, Type expressedType,
305  ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
306  int32_t quantizedDimension, int64_t storageTypeMin,
307  int64_t storageTypeMax) {
308  return Base::get(storageType.getContext(), flags, storageType, expressedType,
309  scales, zeroPoints, quantizedDimension, storageTypeMin,
310  storageTypeMax);
311 }
312 
314  function_ref<InFlightDiagnostic()> emitError, unsigned flags,
315  Type storageType, Type expressedType, ArrayRef<double> scales,
316  ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
317  int64_t storageTypeMin, int64_t storageTypeMax) {
318  return Base::getChecked(emitError, storageType.getContext(), flags,
319  storageType, expressedType, scales, zeroPoints,
320  quantizedDimension, storageTypeMin, storageTypeMax);
321 }
322 
324  function_ref<InFlightDiagnostic()> emitError, unsigned flags,
325  Type storageType, Type expressedType, ArrayRef<double> scales,
326  ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
327  int64_t storageTypeMin, int64_t storageTypeMax) {
328  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
329  storageTypeMin, storageTypeMax))) {
330  return failure();
331  }
332 
333  // Uniform quantization requires fully expressed parameters, including
334  // expressed type.
335  if (!expressedType)
336  return emitError() << "uniform quantization requires expressed type";
337 
338  // Verify that the expressed type is floating point.
339  // If this restriction is ever eliminated, the parser/printer must be
340  // extended.
341  if (!expressedType.isa<FloatType>())
342  return emitError() << "expressed type must be floating point";
343 
344  // Ensure that the number of scales and zeroPoints match.
345  if (scales.size() != zeroPoints.size())
346  return emitError() << "illegal number of scales and zeroPoints: "
347  << scales.size() << ", " << zeroPoints.size();
348 
349  // Verify scale.
350  for (double scale : scales) {
351  if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
352  return emitError() << "illegal scale: " << scale;
353  }
354 
355  return success();
356 }
357 
359  return getImpl()->getScales();
360 }
361 
363  return getImpl()->getZeroPoints();
364 }
365 
367  return getImpl()->quantizedDimension;
368 }
369 
371  double min, double max) {
372  return Base::get(expressedType.getContext(), expressedType, min, max);
373 }
374 
376  function_ref<InFlightDiagnostic()> emitError, Type expressedType,
377  double min, double max) {
378  return Base::getChecked(emitError, expressedType.getContext(), expressedType,
379  min, max);
380 }
381 
384  Type expressedType, double min, double max) {
385  // Verify that the expressed type is floating point.
386  // If this restriction is ever eliminated, the parser/printer must be
387  // extended.
388  if (!expressedType.isa<FloatType>())
389  return emitError() << "expressed type must be floating point";
390  if (max <= min)
391  return emitError() << "illegal min and max: (" << min << ":" << max << ")";
392 
393  return success();
394 }
395 
396 double CalibratedQuantizedType::getMin() const { return getImpl()->min; }
397 
398 double CalibratedQuantizedType::getMax() const { return getImpl()->max; }
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:270
Include the generated interface declarations.
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Definition: QuantTypes.cpp:67
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:121
Type getStorageType() const
Gets the underlying type used to store values.
Definition: QuantTypes.cpp:63
static CalibratedQuantizedType get(Type expressedType, double min, double max)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:370
static CalibratedQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:375
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:310
Type getExpressedType() const
Gets the original expressed type that this quantized type approximates.
Definition: QuantTypes.cpp:81
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:323
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)
Returns the element type as a QuantizedType or nullptr if it is not a quantized type.
Definition: QuantTypes.cpp:94
Represents a family of uniform, quantized types.
Definition: QuantTypes.h:256
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
A quantized type that infers its range from given min/max values.
Definition: QuantTypes.h:383
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Definition: QuantTypes.cpp:32
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:685
static UniformQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:261
static UniformQuantizedPerAxisType get(unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:303
bool isCompatibleExpressedType(Type candidateExpressedType)
Returns whether the candidateExpressedType is a match for this QuantizedType.
Definition: QuantTypes.cpp:85
static UniformQuantizedPerAxisType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:313
Type castFromExpressedType(Type candidateType)
Casts from a type based on the expressedType to a corresponding type based on this type (returns null...
Definition: QuantTypes.cpp:153
Type castExpressedToStorageType(Type candidateType)
Casts from a type based on the expressedType to the equivalent type based on storageType by way of th...
Definition: QuantTypes.cpp:208
static AnyQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:225
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
Definition: QuantTypes.cpp:75
unsigned getFlags() const
Gets the flags associated with this type.
Definition: QuantTypes.cpp:23
A quantized type that maps storage to/from expressed types in an unspecified way. ...
Definition: QuantTypes.h:197
U dyn_cast() const
Definition: Types.h:270
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:299
Type castFromStorageType(Type candidateType)
Casts from a type based on the storageType to a corresponding type based on this type (returns nullpt...
Definition: QuantTypes.cpp:103
Represents per-axis (also known as per-channel quantization).
Definition: QuantTypes.h:314
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
Definition: QuantTypes.cpp:71
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:235
static UniformQuantizedType get(unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:252
double getScale() const
Gets the scale term.
Definition: QuantTypes.cpp:297
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:383
static bool classof(Type type)
Support method to enable LLVM-style type casting.
Definition: QuantTypes.cpp:27
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static AnyQuantizedType get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:216
Base class for all quantized types known to this dialect.
Definition: QuantTypes.h:52
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to...
Definition: QuantTypes.cpp:366
static Type castToExpressedType(Type quantizedType)
Casts from a type based on QuantizedType to a corresponding type based on the expressedType (returns ...
Definition: QuantTypes.cpp:181
ArrayRef< double > getScales() const
Gets the quantization scales.
Definition: QuantTypes.cpp:358
static Type castToStorageType(Type quantizedType)
Casts from a type based on a QuantizedType to a corresponding type based on the storageType (returns ...
Definition: QuantTypes.cpp:126
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation. ...
Definition: QuantTypes.cpp:362
bool isa() const
Definition: Types.h:254
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
U cast() const
Definition: Types.h:278