MLIR  17.0.0git
QuantTypes.h
Go to the documentation of this file.
1 //===- QuantTypes.h - Quantization Ops and Types ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_QUANT_QUANTTYPES_H
10 #define MLIR_DIALECT_QUANT_QUANTTYPES_H
11 
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/Dialect.h"
16 #include "mlir/IR/OpDefinition.h"
17 #include "mlir/IR/Types.h"
18 #include "llvm/Support/MathExtras.h"
19 
20 namespace mlir {
21 namespace quant {
22 
23 class QuantizedIntegerType;
24 
25 namespace detail {
26 
32 
33 } // namespace detail
34 
/// Bit flags that describe how the storage of a quantized type is to be
/// interpreted. Each enumerator occupies its own bit of the `flags` word that
/// a quantized type carries (see QuantizedType::getFlags).
namespace QuantizationFlags {
enum FlagValue {
  // When set, the storage type holds signed integers. When clear (the
  // default), the storage type is interpreted as unsigned.
  Signed = 1 << 0,
};
} // namespace QuantizationFlags
43 
44 /// Base class for all quantized types known to this dialect.
45 /// All quantized types have:
46 /// - storageType: The (narrower) numeric type that is being used to
47 /// approximate some expressed type.
48 /// - expressedType: The type that is being approximated.
49 ///
50 /// The base class provides generic support for manipulating the types based
51 /// on these fields.
52 class QuantizedType : public Type {
53 public:
55  using Type::Type;
56 
57  /// The maximum number of bits supported for storage types.
58  static constexpr unsigned MaxStorageBits = 32;
59 
61  unsigned flags, Type storageType,
62  Type expressedType, int64_t storageTypeMin,
63  int64_t storageTypeMax);
64 
65  /// Support method to enable LLVM-style type casting.
66  static bool classof(Type type);
67 
68  /// Gets the minimum possible value stored by a storageType. storageTypeMin
69  /// must be greater than or equal to this value.
71  unsigned integralWidth) {
72  if (isSigned) {
73  return llvm::minIntN(integralWidth);
74  }
75  return 0;
76  }
77 
78  /// Gets the maximum possible value stored by a storageType. storageTypeMax
79  /// must be less than or equal to this value.
81  unsigned integralWidth) {
82  if (isSigned) {
83  return llvm::maxIntN(integralWidth);
84  }
85  return llvm::maxUIntN(integralWidth);
86  }
87 
88  /// Gets the original expressed type that this quantized type approximates.
89  /// Note that this presumes that the quantized type was always derived from
90  /// a floating point type, which in the broadest definition, is not true (i.e.
91  /// it could be some form of integral, fixed type or affine type in its own
92  /// right); however, at the high level, no examples of such usage are
93  /// presently known and the restriction serves some useful purposes (such as
94  /// always being able to reverse a transformation or measure error). In most
95  /// cases, this will be f32.
96  Type getExpressedType() const;
97 
98  /// Gets the flags associated with this type. Typically a more specific
99  /// accessor is appropriate.
100  unsigned getFlags() const;
101 
102  // Convenience helpers.
103  /// Whether the storage type should be interpreted as a signed quantity
104  /// (true) or an unsigned value (false).
105  bool isSigned() const {
106  return (getFlags() & QuantizationFlags::Signed) ==
108  }
109 
110  /// Gets the underlying type used for to store values. Note that this may
111  /// be signed or unsigned. Use the isSigned() accessor to differentiate.
112  Type getStorageType() const;
113 
114  /// The minimum value that storageType can take.
115  int64_t getStorageTypeMin() const;
116 
117  /// The maximum value that storageType can take.
118  int64_t getStorageTypeMax() const;
119 
120  /// Gets the integral bit width that the underlying storage type can exactly
121  /// represent. For integral storage types, this will just be their width.
122  unsigned getStorageTypeIntegralWidth() const;
123 
124  /// Returns whether the candidateExpressedType is a match for this
125  /// QuantizedType. This will be true if the candidate type is either a
126  /// primitive type or a container type whose element type equals this
127  /// QuantizedType's expressed type.
128  /// Examples of compatible candidateExpressedType:
129  /// !quant.uniform<i8:f32, 1.0> =~ f32
130  /// !quant.uniform<i8:f32, 1.0> =~ tensor<4xf32>
131  bool isCompatibleExpressedType(Type candidateExpressedType);
132 
133  /// Returns the element type as a QuantizedType or nullptr if it is not
134  /// a quantized type. If the type is primitive, returns that. If it is a
135  /// container (vector/tensor), return the element type.
136  /// Examples:
137  /// !quant.uniform<i8:f32, 1.0> -> !quant.uniform<i8:f32, 1.0>
138  /// tensor<4x!quant.uniform<i8:f32, 1.0> -> quant.uniform<i8:f32, 1.0>
139  static QuantizedType getQuantizedElementType(Type primitiveOrContainerType);
140 
141  /// Casts from a type based on the storageType to a corresponding type based
142  /// on this type (returns nullptr if the cast is not valid).
143  /// Examples:
144  /// i8 -> !quant.uniform<i8:f32, 1.0>
145  /// tensor<4xi8> -> tensor<4x!quant.uniform<i8:f32, 1.0}>>
146  /// vector<4xi8> -> vector<4x!quant.uniform<i8:f32, 1.0>>
147  Type castFromStorageType(Type candidateType);
148 
149  /// Casts from a type based on a QuantizedType to a corresponding type based
150  /// on the storageType (returns nullptr if the cast is not valid).
151  /// This is the inverse of castFromStorageType().
152  static Type castToStorageType(Type quantizedType);
153 
154  /// Casts from a type based on the expressedType to a corresponding type based
155  /// on this type (returns nullptr if the cast is not valid).
156  /// Examples:
157  /// f32 -> !quant.uniform<i8:f32, 1.0>
158  /// tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
159  /// vector<4xf32> -> vector<4x!quant.uniform<i8:f32, 1.0>>
160  Type castFromExpressedType(Type candidateType);
161 
162  /// Casts from a type based on QuantizedType to a corresponding type based
163  /// on the expressedType (returns nullptr if the cast is not valid).
164  /// This is the inverse of castFromExpressedType.
165  static Type castToExpressedType(Type quantizedType);
166 
167  /// Casts from a type based on the expressedType to the equivalent type
168  /// based on storageType by way of this QuantizedType. Equivalent to:
169  /// QuantizedType::castToStorageType(castFromExpressedType(candidateType))
170  /// (but with validity checks).
171  /// Example (for this = !quant.uniform<i8:f32, 1.0>):
172  /// tensor<4xf32> -> tensor<4xi8>
173  Type castExpressedToStorageType(Type candidateType);
174 
175 private:
176  /// Hide the following methods inherited from `Type`. It is almost certainly
177  /// a bug to call them from a `QuantizedType` object. Users should call
178  /// `getStorageType` or `getExpressedType` to get the underlying types
179  /// they want to inspect.
180  using Type::isBF16;
181  using Type::isF16;
182  using Type::isF32;
183  using Type::isF64;
184  using Type::isIndex;
185  using Type::isInteger;
186 };
187 
188 /// A quantized type that maps storage to/from expressed types in an
189 /// unspecified way.
190 ///
191 /// Typical syntax:
192 /// quant.any<i8:f32>
193 /// quant.any<i8>
194 /// quant.any<i8<-16,15>>
195 ///
196 /// Note that for the any type, the expressed type is optional.
198  : public Type::TypeBase<AnyQuantizedType, QuantizedType,
199  detail::AnyQuantizedTypeStorage> {
200 public:
201  using Base::Base;
202  using Base::getChecked;
203 
204  /// Gets an instance of the type with all parameters specified but not
205  /// checked.
206  static AnyQuantizedType get(unsigned flags, Type storageType,
207  Type expressedType, int64_t storageTypeMin,
208  int64_t storageTypeMax);
209 
210  /// Gets an instance of the type with all specified parameters checked.
211  /// Returns a nullptr convertible type on failure.
212  static AnyQuantizedType
214  Type storageType, Type expressedType, int64_t storageTypeMin,
215  int64_t storageTypeMax);
216 
217  /// Verifies construction invariants and issues errors/warnings.
219  unsigned flags, Type storageType,
220  Type expressedType, int64_t storageTypeMin,
221  int64_t storageTypeMax);
222 };
223 
224 /// Represents a family of uniform, quantized types.
225 ///
226 /// Each instance of this type expresses a mapping between real values (most
227 /// often expressed in floating point f32) and quantized values (either fixed
228 /// point or affine).
229 ///
230 /// The relationship is:
231 /// real_value = scale * (quantized_value - zero_point)
232 ///
233 /// It is used as part of high level graph transformations that have the goal
234 /// of re-expressing parts of a computation in terms of this common form for
235 /// more efficient execution at runtime. In addition, it is designed to be
236 /// expressive enough to facilitate lowering to precise types and operations
237 /// in target hardware.
238 ///
239 /// As a high-level type, focused on intermediate passes, this type holds
240 /// opinions consistent with high-level usage. If lowering math kernels below
241 /// the high level arithmetic ops (i.e. to LLVM IR or hardware specific
242 /// instruction sets), it is expected that the information expressed here
243 /// will be used to drive low level codegen and target specific type selection,
244 /// but this type will likely be erased in the process.
245 ///
246 /// Syntax synopsis:
247 /// Per-layer, all parameters expressed:
248 /// !quant<uniform[StorageType:ExpressedType]{Scale:ZeroPoint}>
249 /// Per-layer, optional parameters omitted:
250 /// !quant<uniform[StorageType]{Scale}>
251 ///
252 /// StorageType: 'i'|'u' NumBits
253 /// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
254 /// Scale: A legal double value
255 /// ZeroPoint: An integer value
257  : public Type::TypeBase<UniformQuantizedType, QuantizedType,
258  detail::UniformQuantizedTypeStorage> {
259 public:
260  using Base::Base;
261  using Base::getChecked;
262 
263  /// Gets an instance of the type with all parameters specified but not
264  /// checked.
265  static UniformQuantizedType get(unsigned flags, Type storageType,
266  Type expressedType, double scale,
267  int64_t zeroPoint, int64_t storageTypeMin,
268  int64_t storageTypeMax);
269 
270  /// Gets an instance of the type with all specified parameters checked.
271  /// Returns a nullptr convertible type on failure.
272  static UniformQuantizedType
274  Type storageType, Type expressedType, double scale,
275  int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
276 
277  /// Verifies construction invariants and issues errors/warnings.
279  unsigned flags, Type storageType,
280  Type expressedType, double scale,
281  int64_t zeroPoint, int64_t storageTypeMin,
282  int64_t storageTypeMax);
283 
284  /// Gets the scale term. The scale designates the difference between the real
285  /// values corresponding to consecutive quantized values differing by 1.
286  double getScale() const;
287 
288  /// Gets the storage value corresponding to the real value 0 in the affine
289  /// equation.
290  int64_t getZeroPoint() const;
291 
292  // Fixed point values are real numbers divided by a scale.
293  // Currently, only signed storage types are treated as fixed point.
294  // A fixed point value can be obtained from an affine value by subtracting
295  // the zeroPoint.
296  // In the future, this may be explicit versus implied by type and zeroPoint.
297  bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
298 };
299 
300 /// Represents per-axis (also known as per-channel) quantization.
301 ///
302 /// Syntax synopsis:
303 /// Per-axis, all parameters expressed:
304 /// !quant<uniform[StorageType:ExpressedType:QuantizedDim]{QuantParams}>
305 /// Per-axis, optional parameters omitted:
306 /// !quant<uniform[StorageType]{Scale}>
307 ///
308 /// StorageType: 'i'|'u' NumBits
309 /// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
310 /// QuantizedDim: An integer value
311 /// QuantParams: (Scale ':' ZeroPoint)+
312 /// Scale: A legal double value
313 /// ZeroPoint: An integer value
315  : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType,
316  detail::UniformQuantizedPerAxisTypeStorage> {
317 public:
318  using Base::Base;
319  using Base::getChecked;
320 
321  /// Gets an instance of the type with all parameters specified but not
322  /// checked.
324  get(unsigned flags, Type storageType, Type expressedType,
325  ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
326  int32_t quantizedDimension, int64_t storageTypeMin,
327  int64_t storageTypeMax);
328 
329  /// Gets an instance of the type with all specified parameters checked.
330  /// Returns a nullptr convertible type on failure.
333  Type storageType, Type expressedType, ArrayRef<double> scales,
334  ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
335  int64_t storageTypeMin, int64_t storageTypeMax);
336 
337  /// Verifies construction invariants and issues errors/warnings.
339  unsigned flags, Type storageType,
340  Type expressedType, ArrayRef<double> scales,
341  ArrayRef<int64_t> zeroPoints,
342  int32_t quantizedDimension,
343  int64_t storageTypeMin, int64_t storageTypeMax);
344 
345  /// Gets the quantization scales. The scales designate the difference between
346  /// the real values corresponding to consecutive quantized values differing
347  /// by 1. The ith scale corresponds to the ith slice in the
348  /// quantized_dimension.
349  ArrayRef<double> getScales() const;
350 
351  /// Gets the storage values corresponding to the real value 0 in the affine
352  /// equation. The ith zero point corresponds to the ith slice in the
353  /// quantized_dimension.
355 
356  /// Specifies the dimension of the Tensor's shape that the scales and
357  /// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
358  /// with quantization params:
359  /// scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1
360  /// will be quantized across the second dimension of t.
361  /// t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
362  /// t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2
363  /// t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3
364  int32_t getQuantizedDimension() const;
365 
366  /// Fixed point values are real numbers divided by a scale.
367  /// Currently, only signed storage types are treated as fixed point.
368  /// A fixed point value can be obtained from an affine value by subtracting
369  /// the zeroPoint.
370  /// In the future, this may be explicit versus implied by type and zeroPoint.
371  bool isFixedPoint() const {
372  if (!isSigned())
373  return false;
374  return !llvm::is_contained(getZeroPoints(), 0);
375  }
376 };
377 
378 /// A quantized type that infers its range from given min/max values.
379 ///
380 /// Typical syntax:
381 /// quant.calibrated<f32<-0.922,0.981>>
383  : public Type::TypeBase<CalibratedQuantizedType, QuantizedType,
384  detail::CalibratedQuantizedTypeStorage> {
385 public:
386  using Base::Base;
387  using Base::getChecked;
388 
389  /// Gets an instance of the type with all parameters specified but not
390  /// checked.
391  static CalibratedQuantizedType get(Type expressedType, double min,
392  double max);
393 
394  /// Gets an instance of the type with all specified parameters checked.
395  /// Returns a nullptr convertible type on failure.
398  double min, double max);
399 
400  /// Verifies construction invariants and issues errors/warnings.
402  Type expressedType, double min, double max);
403  double getMin() const;
404  double getMax() const;
405 };
406 
407 } // namespace quant
408 } // namespace mlir
409 
410 #endif // MLIR_DIALECT_QUANT_QUANTTYPES_H
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:45
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
Definition: Types.cpp:52
bool isIndex() const
Definition: Types.cpp:49
constexpr Type()=default
bool isF32() const
Definition: Types.cpp:44
bool isF16() const
Definition: Types.cpp:43
bool isBF16() const
Definition: Types.cpp:42
Utility class for implementing users of storage classes uniqued by a StorageUniquer.
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition: QuantTypes.h:199
static AnyQuantizedType get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:216
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:235
static AnyQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:225
A quantized type that infers its range from given min/max values.
Definition: QuantTypes.h:384
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:383
static CalibratedQuantizedType get(Type expressedType, double min, double max)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:370
static CalibratedQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:375
Base class for all quantized types known to this dialect.
Definition: QuantTypes.h:52
Type getExpressedType() const
Gets the original expressed type that this quantized type approximates.
Definition: QuantTypes.cpp:81
static constexpr unsigned MaxStorageBits
The maximum number of bits supported for storage types.
Definition: QuantTypes.h:58
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Definition: QuantTypes.cpp:32
static Type castToStorageType(Type quantizedType)
Casts from a type based on a QuantizedType to a corresponding type based on the storageType (returns ...
Definition: QuantTypes.cpp:126
Type castExpressedToStorageType(Type candidateType)
Casts from a type based on the expressedType to the equivalent type based on storageType by way of th...
Definition: QuantTypes.cpp:208
static Type castToExpressedType(Type quantizedType)
Casts from a type based on QuantizedType to a corresponding type based on the expressedType (returns ...
Definition: QuantTypes.cpp:181
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
Definition: QuantTypes.h:105
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)
Returns the element type as a QuantizedType or nullptr if it is not a quantized type.
Definition: QuantTypes.cpp:94
unsigned getFlags() const
Gets the flags associated with this type.
Definition: QuantTypes.cpp:23
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
Definition: QuantTypes.cpp:71
static int64_t getDefaultMaximumForInteger(bool isSigned, unsigned integralWidth)
Gets the maximum possible value stored by a storageType.
Definition: QuantTypes.h:80
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
Definition: QuantTypes.cpp:75
static bool classof(Type type)
Support method to enable LLVM-style type casting.
Definition: QuantTypes.cpp:27
Type castFromStorageType(Type candidateType)
Casts from a type based on the storageType to a corresponding type based on this type (returns nullpt...
Definition: QuantTypes.cpp:103
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Definition: QuantTypes.cpp:67
static int64_t getDefaultMinimumForInteger(bool isSigned, unsigned integralWidth)
Gets the minimum possible value stored by a storageType.
Definition: QuantTypes.h:70
Type getStorageType() const
Gets the underlying type used to store values.
Definition: QuantTypes.cpp:63
Type castFromExpressedType(Type candidateType)
Casts from a type based on the expressedType to a corresponding type based on this type (returns null...
Definition: QuantTypes.cpp:153
bool isCompatibleExpressedType(Type candidateExpressedType)
Returns whether the candidateExpressedType is a match for this QuantizedType.
Definition: QuantTypes.cpp:85
Represents per-axis (also known as per-channel) quantization.
Definition: QuantTypes.h:316
static UniformQuantizedPerAxisType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:313
bool isFixedPoint() const
Fixed point values are real numbers divided by a scale.
Definition: QuantTypes.h:371
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:323
static UniformQuantizedPerAxisType get(unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:303
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to.
Definition: QuantTypes.cpp:366
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:362
ArrayRef< double > getScales() const
Gets the quantization scales.
Definition: QuantTypes.cpp:358
Represents a family of uniform, quantized types.
Definition: QuantTypes.h:258
double getScale() const
Gets the scale term.
Definition: QuantTypes.cpp:297
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:299
static UniformQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:261
static UniformQuantizedType get(unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:252
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:270
This header declares functions that assist transformations in the MemRef dialect.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26