MLIR  20.0.0git
QuantTypes.cpp
Go to the documentation of this file.
1 //===- QuantTypes.cpp - Quantization Type and Ops Implementation -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TypeDetail.h"
12 
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/MathExtras.h"
18 
19 using namespace mlir;
20 using namespace mlir::quant;
21 using namespace mlir::quant::detail;
22 
// File-local helpers that compute the scale bounds used by the scale range
// checks later in this file.
23 namespace {
24 
25 // Return the minimum scale representable in a given float type
// `expressedType` is cast to FloatType unconditionally, so callers must only
// pass a floating-point type. The result is whatever APFloat::getSmallest
// yields for that type's float semantics, converted to double.
26 double getMinScale(Type expressedType) {
27  auto floatType = cast<FloatType>(expressedType);
28  return APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
29 }
30 
31 // Return the maximum scale representable in a given float type
// `expressedType` is cast to FloatType unconditionally, so callers must only
// pass a floating-point type. The result is whatever APFloat::getLargest
// yields for that type's float semantics, converted to double.
32 double getMaxScale(Type expressedType) {
33  auto floatType = cast<FloatType>(expressedType);
34  return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
35 }
36 
37 } // namespace
38 
// Returns the `flags` value stored in this type's implementation object.
39 unsigned QuantizedType::getFlags() const {
40  return static_cast<ImplType *>(impl)->flags;
41 }
42 
44  return llvm::isa<QuantDialect>(type.getDialect());
45 }
46 
47 LogicalResult
49  unsigned flags, Type storageType,
50  Type expressedType, int64_t storageTypeMin,
51  int64_t storageTypeMax) {
52  // Verify that the storage type is integral.
53  // This restriction may be lifted at some point in favor of using bf16
54  // or f16 as exact representations on hardware where that is advantageous.
55  auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
56  if (!intStorageType)
57  return emitError() << "storage type must be integral";
58  unsigned integralWidth = intStorageType.getWidth();
59 
60  // Verify storage width.
61  if (integralWidth == 0 || integralWidth > MaxStorageBits)
62  return emitError() << "illegal storage type size: " << integralWidth;
63 
64  // Verify storageTypeMin and storageTypeMax.
65  bool isSigned =
67  int64_t defaultIntegerMin =
68  getDefaultMinimumForInteger(isSigned, integralWidth);
69  int64_t defaultIntegerMax =
70  getDefaultMaximumForInteger(isSigned, integralWidth);
71  if (storageTypeMax - storageTypeMin <= 0 ||
72  storageTypeMin < defaultIntegerMin ||
73  storageTypeMax > defaultIntegerMax) {
74  return emitError() << "illegal storage min and storage max: ("
75  << storageTypeMin << ":" << storageTypeMax << ")";
76  }
77  return success();
78 }
79 
81  return static_cast<ImplType *>(impl)->storageType;
82 }
83 
85  return static_cast<ImplType *>(impl)->storageTypeMin;
86 }
87 
89  return static_cast<ImplType *>(impl)->storageTypeMax;
90 }
91 
93  unsigned int integralWidth = getStorageTypeIntegralWidth();
94  bool isSignedInteger = isSigned();
95  int64_t defaultIntegerMin =
96  getDefaultMinimumForInteger(isSignedInteger, integralWidth);
97  int64_t defaultIntegerMax =
98  getDefaultMaximumForInteger(isSignedInteger, integralWidth);
99  return defaultIntegerMin != getStorageTypeMin() ||
100  defaultIntegerMax != getStorageTypeMax();
101 }
102 
104  // NOTE: If ever supporting non-integral storage types, some other scheme
105  // for determining the width will be needed.
106  return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
107 }
108 
110  return static_cast<ImplType *>(impl)->expressedType;
111 }
112 
// Returns true when `candidateExpressedType` matches this type's expressed
// type. For a ShapedType candidate the comparison is made against its element
// type; for any other candidate the type itself is compared.
113 bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
114  if (llvm::isa<ShapedType>(candidateExpressedType)) {
115  return llvm::cast<ShapedType>(candidateExpressedType).getElementType() ==
116  getExpressedType();
117  }
118  return candidateExpressedType == getExpressedType();
119 }
120 
122 QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
123  if (llvm::isa<ShapedType>(primitiveOrContainerType)) {
124  Type elementType =
125  llvm::cast<ShapedType>(primitiveOrContainerType).getElementType();
126  return llvm::dyn_cast<QuantizedType>(elementType);
127  }
128  return llvm::dyn_cast<QuantizedType>(primitiveOrContainerType);
129 }
130 
132  if (candidateType == getStorageType()) {
133  // i.e. i8 -> quant<"uniform[i8:f32]{1.0}">
134  return *this;
135  }
136  if (llvm::isa<RankedTensorType>(candidateType)) {
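137  // intended: tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
  // NOTE(review): the element type passed below is getStorageType(), not
  // *this, so the result keeps the storage element type - confirm intent.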
138  return RankedTensorType::get(
139  llvm::cast<RankedTensorType>(candidateType).getShape(),
140  getStorageType());
141  }
142  if (llvm::isa<UnrankedTensorType>(candidateType)) {
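143  // intended: tensor<*xi8> -> tensor<*x!quant<"uniform[i8:f32]{1.0}">>
  // NOTE(review): the element type passed below is getStorageType(), not
  // *this, so the result keeps the storage element type - confirm intent.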
144  return UnrankedTensorType::get(getStorageType());
145  }
146  if (llvm::isa<VectorType>(candidateType)) {
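147  // intended: vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
  // NOTE(review): the element type passed below is getStorageType(), not
  // *this, so the result keeps the storage element type - confirm intent.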
148  return VectorType::get(llvm::cast<VectorType>(candidateType).getShape(),
149  getStorageType());
150  }
151 
152  return nullptr;
153 }
154 
156  if (llvm::isa<QuantizedType>(quantizedType)) {
157  // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
158  return llvm::cast<QuantizedType>(quantizedType).getStorageType();
159  }
160  if (llvm::isa<ShapedType>(quantizedType)) {
161  // i.e. tensor<4x!quant<"uniform[i8:f32]{1.0}">> -> tensor<4xi8>
162  ShapedType sType = llvm::cast<ShapedType>(quantizedType);
163  if (!llvm::isa<QuantizedType>(sType.getElementType())) {
164  return nullptr;
165  }
166  Type storageType =
167  llvm::cast<QuantizedType>(sType.getElementType()).getStorageType();
168  if (llvm::isa<RankedTensorType>(quantizedType)) {
169  return RankedTensorType::get(sType.getShape(), storageType);
170  }
171  if (llvm::isa<UnrankedTensorType>(quantizedType)) {
172  return UnrankedTensorType::get(storageType);
173  }
174  if (llvm::isa<VectorType>(quantizedType)) {
175  return VectorType::get(sType.getShape(), storageType);
176  }
177  }
178 
179  return nullptr;
180 }
181 
183  if (candidateType == getExpressedType()) {
184  // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
185  return *this;
186  }
187  if (llvm::isa<ShapedType>(candidateType)) {
188  ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType);
189  if (candidateShapedType.getElementType() != getExpressedType()) {
190  return nullptr;
191  }
192 
193  if (llvm::isa<RankedTensorType>(candidateType)) {
194  // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
195  return RankedTensorType::get(candidateShapedType.getShape(), *this);
196  }
197  if (llvm::isa<UnrankedTensorType>(candidateType)) {
198  // i.e. tensor<*xf32> -> tensor<*x!quant<"uniform[i8:f32]{1.0}">>
199  return UnrankedTensorType::get(*this);
200  }
201  if (llvm::isa<VectorType>(candidateType)) {
202  // i.e. vector<4xf32> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
203  return VectorType::get(candidateShapedType.getShape(), *this);
204  }
205  }
206 
207  return nullptr;
208 }
209 
211  if (llvm::isa<QuantizedType>(quantizedType)) {
212  // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
213  return llvm::cast<QuantizedType>(quantizedType).getExpressedType();
214  }
215  if (llvm::isa<ShapedType>(quantizedType)) {
216  // i.e. tensor<4x!quant<"uniform[i8:f32]{1.0}">> -> tensor<4xf32>
217  ShapedType sType = llvm::cast<ShapedType>(quantizedType);
218  if (!llvm::isa<QuantizedType>(sType.getElementType())) {
219  return nullptr;
220  }
221  Type expressedType =
222  llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType();
223  if (llvm::isa<RankedTensorType>(quantizedType)) {
224  return RankedTensorType::get(sType.getShape(), expressedType);
225  }
226  if (llvm::isa<UnrankedTensorType>(quantizedType)) {
227  return UnrankedTensorType::get(expressedType);
228  }
229  if (llvm::isa<VectorType>(quantizedType)) {
230  return VectorType::get(sType.getShape(), expressedType);
231  }
232  }
233 
234  return nullptr;
235 }
236 
238  Type expressedQuantizedType = castFromExpressedType(candidateType);
239  if (!expressedQuantizedType) {
240  return nullptr;
241  }
242  return QuantizedType::castToStorageType(expressedQuantizedType);
243 }
244 
// Builds an AnyQuantizedType by forwarding all parameters to Base::get, using
// the context obtained from `storageType`. This function performs no checks
// of its own before forwarding. `storageType.getContext()` is called
// unconditionally, so `storageType` presumably must be non-null - verify
// against callers.
245 AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
246  Type expressedType,
247  int64_t storageTypeMin,
248  int64_t storageTypeMax) {
249  return Base::get(storageType.getContext(), flags, storageType, expressedType,
250  storageTypeMin, storageTypeMax);
251 }
252 
255  unsigned flags, Type storageType,
256  Type expressedType, int64_t storageTypeMin,
257  int64_t storageTypeMax) {
258  return Base::getChecked(emitError, storageType.getContext(), flags,
259  storageType, expressedType, storageTypeMin,
260  storageTypeMax);
261 }
262 
263 LogicalResult
265  unsigned flags, Type storageType,
266  Type expressedType, int64_t storageTypeMin,
267  int64_t storageTypeMax) {
268  if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
269  expressedType, storageTypeMin,
270  storageTypeMax))) {
271  return failure();
272  }
273 
274  // Verify that the expressed type is floating point.
275  // If this restriction is ever eliminated, the parser/printer must be
276  // extended.
277  if (expressedType && !llvm::isa<FloatType>(expressedType))
278  return emitError() << "expressed type must be floating point";
279 
280  return success();
281 }
282 
284  Type expressedType, double scale,
285  int64_t zeroPoint,
286  int64_t storageTypeMin,
287  int64_t storageTypeMax) {
288  return Base::get(storageType.getContext(), flags, storageType, expressedType,
289  scale, zeroPoint, storageTypeMin, storageTypeMax);
290 }
291 
293  function_ref<InFlightDiagnostic()> emitError, unsigned flags,
294  Type storageType, Type expressedType, double scale, int64_t zeroPoint,
295  int64_t storageTypeMin, int64_t storageTypeMax) {
296  return Base::getChecked(emitError, storageType.getContext(), flags,
297  storageType, expressedType, scale, zeroPoint,
298  storageTypeMin, storageTypeMax);
299 }
300 
302  function_ref<InFlightDiagnostic()> emitError, unsigned flags,
303  Type storageType, Type expressedType, double scale, int64_t zeroPoint,
304  int64_t storageTypeMin, int64_t storageTypeMax) {
305  if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
306  expressedType, storageTypeMin,
307  storageTypeMax))) {
308  return failure();
309  }
310 
311  // Uniform quantization requires fully expressed parameters, including
312  // expressed type.
313  if (!expressedType)
314  return emitError() << "uniform quantization requires expressed type";
315 
316  // Verify that the expressed type is floating point.
317  // If this restriction is ever eliminated, the parser/printer must be
318  // extended.
319  if (!llvm::isa<FloatType>(expressedType))
320  return emitError() << "expressed type must be floating point";
321 
322  // Verify scale.
323  double minScale = getMinScale(expressedType);
324  double maxScale = getMaxScale(expressedType);
325  if (scale < minScale || scale > maxScale)
326  return emitError() << "scale out of expressed type range [" << minScale
327  << ", " << maxScale << "]";
328 
329  return success();
330 }
331 
// Returns the `scale` value held by this type's implementation object.
332 double UniformQuantizedType::getScale() const { return getImpl()->scale; }
333 
335  return getImpl()->zeroPoint;
336 }
337 
339  unsigned flags, Type storageType, Type expressedType,
340  ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
341  int32_t quantizedDimension, int64_t storageTypeMin,
342  int64_t storageTypeMax) {
343  return Base::get(storageType.getContext(), flags, storageType, expressedType,
344  scales, zeroPoints, quantizedDimension, storageTypeMin,
345  storageTypeMax);
346 }
347 
349  function_ref<InFlightDiagnostic()> emitError, unsigned flags,
350  Type storageType, Type expressedType, ArrayRef<double> scales,
351  ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
352  int64_t storageTypeMin, int64_t storageTypeMax) {
353  return Base::getChecked(emitError, storageType.getContext(), flags,
354  storageType, expressedType, scales, zeroPoints,
355  quantizedDimension, storageTypeMin, storageTypeMax);
356 }
357 
359  function_ref<InFlightDiagnostic()> emitError, unsigned flags,
360  Type storageType, Type expressedType, ArrayRef<double> scales,
361  ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
362  int64_t storageTypeMin, int64_t storageTypeMax) {
363  if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
364  expressedType, storageTypeMin,
365  storageTypeMax))) {
366  return failure();
367  }
368 
369  // Uniform quantization requires fully expressed parameters, including
370  // expressed type.
371  if (!expressedType)
372  return emitError() << "uniform quantization requires expressed type";
373 
374  // Verify that the expressed type is floating point.
375  // If this restriction is ever eliminated, the parser/printer must be
376  // extended.
377  if (!llvm::isa<FloatType>(expressedType))
378  return emitError() << "expressed type must be floating point";
379 
380  // Ensure that the number of scales and zeroPoints match.
381  if (scales.size() != zeroPoints.size())
382  return emitError() << "illegal number of scales and zeroPoints: "
383  << scales.size() << ", " << zeroPoints.size();
384 
385  // Verify scale.
386  double minScale = getMinScale(expressedType);
387  double maxScale = getMaxScale(expressedType);
388  for (double scale : scales) {
389  if (scale < minScale || scale > maxScale)
390  return emitError() << "scale out of expressed type range [" << minScale
391  << ", " << maxScale << "]";
392  }
393 
394  // Verify quantized dimension.
395  if (quantizedDimension < 0)
396  return emitError() << "illegal quantized dimension: " << quantizedDimension;
397 
398  return success();
399 }
400 
402  return getImpl()->getScales();
403 }
404 
406  return getImpl()->getZeroPoints();
407 }
408 
410  return getImpl()->quantizedDimension;
411 }
412 
414  double min, double max) {
415  return Base::get(expressedType.getContext(), expressedType, min, max);
416 }
417 
419  function_ref<InFlightDiagnostic()> emitError, Type expressedType,
420  double min, double max) {
421  return Base::getChecked(emitError, expressedType.getContext(), expressedType,
422  min, max);
423 }
424 
426  function_ref<InFlightDiagnostic()> emitError, Type expressedType,
427  double min, double max) {
428  // Verify that the expressed type is floating point.
429  // If this restriction is ever eliminated, the parser/printer must be
430  // extended.
431  if (!llvm::isa<FloatType>(expressedType))
432  return emitError() << "expressed type must be floating point";
433  if (max <= min)
434  return emitError() << "illegal min and max: (" << min << ":" << max << ")";
435 
436  return success();
437 }
438 
// Returns the `min` value held by this type's implementation object.
439 double CalibratedQuantizedType::getMin() const { return getImpl()->min; }
440 
// Returns the `max` value held by this type's implementation object.
441 double CalibratedQuantizedType::getMax() const { return getImpl()->max; }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:123
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition: QuantTypes.h:200
static AnyQuantizedType get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:245
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:264
static AnyQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:254
A quantized type that infers its range from given min/max values.
Definition: QuantTypes.h:391
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:425
static CalibratedQuantizedType get(Type expressedType, double min, double max)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:413
static CalibratedQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:418
Base class for all quantized types known to this dialect.
Definition: QuantTypes.h:49
Type getExpressedType() const
Gets the original expressed type that this quantized type approximates.
Definition: QuantTypes.cpp:109
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
Definition: QuantTypes.cpp:92
static Type castToStorageType(Type quantizedType)
Casts from a type based on a QuantizedType to a corresponding type based on the storageType (returns ...
Definition: QuantTypes.cpp:155
Type castExpressedToStorageType(Type candidateType)
Casts from a type based on the expressedType to the equivalent type based on storageType by way of th...
Definition: QuantTypes.cpp:237
static Type castToExpressedType(Type quantizedType)
Casts from a type based on QuantizedType to a corresponding type based on the expressedType (returns ...
Definition: QuantTypes.cpp:210
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)
Returns the element type as a QuantizedType or nullptr if it is not a quantized type.
Definition: QuantTypes.cpp:122
unsigned getFlags() const
Gets the flags associated with this type.
Definition: QuantTypes.cpp:39
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
Definition: QuantTypes.cpp:88
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
Definition: QuantTypes.cpp:103
static bool classof(Type type)
Support method to enable LLVM-style type casting.
Definition: QuantTypes.cpp:43
Type castFromStorageType(Type candidateType)
Casts from a type based on the storageType to a corresponding type based on this type (returns nullpt...
Definition: QuantTypes.cpp:131
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Definition: QuantTypes.cpp:84
Type getStorageType() const
Gets the underlying type used to store values.
Definition: QuantTypes.cpp:80
Type castFromExpressedType(Type candidateType)
Casts from a type based on the expressedType to a corresponding type based on this type (returns null...
Definition: QuantTypes.cpp:182
bool isCompatibleExpressedType(Type candidateExpressedType)
Returns whether the candidateExpressedType is a match for this QuantizedType.
Definition: QuantTypes.cpp:113
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Definition: QuantTypes.cpp:48
Represents per-axis (also known as per-channel quantization).
Definition: QuantTypes.h:321
static UniformQuantizedPerAxisType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:348
static UniformQuantizedPerAxisType get(unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:338
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to.
Definition: QuantTypes.cpp:409
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:405
ArrayRef< double > getScales() const
Gets the quantization scales.
Definition: QuantTypes.cpp:401
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:358
Represents a family of uniform, quantized types.
Definition: QuantTypes.h:261
double getScale() const
Gets the scale term.
Definition: QuantTypes.cpp:332
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:334
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:301
static UniformQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:292
static UniformQuantizedType get(unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:283
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...