MLIR  18.0.0git
QuantTypes.h
Go to the documentation of this file.
1 //===- QuantTypes.h - Quantization Ops and Types ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_QUANT_QUANTTYPES_H
10 #define MLIR_DIALECT_QUANT_QUANTTYPES_H
11 
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/Dialect.h"
16 #include "mlir/IR/OpDefinition.h"
17 #include "mlir/IR/Types.h"
18 #include "llvm/Support/MathExtras.h"
19 
20 namespace mlir {
21 namespace quant {
22 namespace detail {
23 
29 
30 } // namespace detail
31 
/// Bit-mapped flags that qualify how a quantized type's storage is
/// interpreted. They are tested against the word returned by
/// QuantizedType::getFlags().
namespace QuantizationFlags {
enum FlagValue {
  /// Set: the storage type holds a signed integer.
  /// Clear (default): the storage type is interpreted as unsigned.
  Signed = 1 << 0,
};
} // namespace QuantizationFlags
40 
41 /// Base class for all quantized types known to this dialect.
42 /// All quantized types have:
43 /// - storageType: The (narrower) numeric type that is being used to
44 /// approximate some expressed type.
45 /// - expressedType: The type that is being approximated.
46 ///
47 /// The base class provides generic support for manipulating the types based
48 /// on these fields.
49 class QuantizedType : public Type {
50 public:
52  using Type::Type;
53 
54  /// The maximum number of bits supported for storage types.
55  static constexpr unsigned MaxStorageBits = 32;
56 
58  unsigned flags, Type storageType,
59  Type expressedType, int64_t storageTypeMin,
60  int64_t storageTypeMax);
61 
62  /// Support method to enable LLVM-style type casting.
63  static bool classof(Type type);
64 
65  /// Gets the minimum possible value stored by a storageType. storageTypeMin must
66  /// be greater than or equal to this value.
68  unsigned integralWidth) {
69  if (isSigned) {
70  return llvm::minIntN(integralWidth);
71  }
72  return 0;
73  }
74 
75  /// Gets the maximum possible value stored by a storageType. storageTypeMax must
76  /// be less than or equal to this value.
78  unsigned integralWidth) {
79  if (isSigned) {
80  return llvm::maxIntN(integralWidth);
81  }
82  return llvm::maxUIntN(integralWidth);
83  }
84 
85  /// Gets the original expressed type that this quantized type approximates.
86  /// Note that this presumes that the quantized type was always derived from
87  /// a floating point type, which in the broadest definition, is not true (i.e.
88  /// it could be some form of integral, fixed type or affine type in its own
89  /// right); however, at the high level, no examples of such usage are
90  /// presently known and the restriction serves some useful purposes (such as
91  /// always being able to reverse a transformation or measure error). In most
92  /// cases, this will be f32.
93  Type getExpressedType() const;
94 
95  /// Gets the flags associated with this type. Typically a more specific
96  /// accessor is appropriate.
97  unsigned getFlags() const;
98 
99  // Convenience helpers.
100  /// Whether the storage type should be interpreted as a signed quantity
101  /// (true) or an unsigned value (false).
102  bool isSigned() const {
103  return (getFlags() & QuantizationFlags::Signed) ==
105  }
106 
107  /// Gets the underlying type used to store values. Note that this may
108  /// be signed or unsigned. Use the isSigned() accessor to differentiate.
109  Type getStorageType() const;
110 
111  /// The minimum value that storageType can take.
112  int64_t getStorageTypeMin() const;
113 
114  /// The maximum value that storageType can take.
115  int64_t getStorageTypeMax() const;
116 
117  /// Gets the integral bit width that the underlying storage type can exactly
118  /// represent. For integral storage types, this will just be their width.
119  unsigned getStorageTypeIntegralWidth() const;
120 
121  /// Returns whether the candidateExpressedType is a match for this
122  /// QuantizedType. This will be true if the candidate type is either a
123  /// primitive type or a container type whose element type equals this
124  /// QuantizedType's expressed type.
125  /// Examples of compatible candidateExpressedType:
126  /// !quant.uniform<i8:f32, 1.0> =~ f32
127  /// !quant.uniform<i8:f32, 1.0> =~ tensor<4xf32>
128  bool isCompatibleExpressedType(Type candidateExpressedType);
129 
130  /// Returns the element type as a QuantizedType or nullptr if it is not
131  /// a quantized type. If the type is primitive, returns that. If it is a
132  /// container (vector/tensor), return the element type.
133  /// Examples:
134  /// !quant.uniform<i8:f32, 1.0> -> !quant.uniform<i8:f32, 1.0>
135  /// tensor<4x!quant.uniform<i8:f32, 1.0>> -> !quant.uniform<i8:f32, 1.0>
136  static QuantizedType getQuantizedElementType(Type primitiveOrContainerType);
137 
138  /// Casts from a type based on the storageType to a corresponding type based
139  /// on this type (returns nullptr if the cast is not valid).
140  /// Examples:
141  /// i8 -> !quant.uniform<i8:f32, 1.0>
142  /// tensor<4xi8> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
143  /// vector<4xi8> -> vector<4x!quant.uniform<i8:f32, 1.0>>
144  Type castFromStorageType(Type candidateType);
145 
146  /// Casts from a type based on a QuantizedType to a corresponding type based
147  /// on the storageType (returns nullptr if the cast is not valid).
148  /// This is the inverse of castFromStorageType().
149  static Type castToStorageType(Type quantizedType);
150 
151  /// Casts from a type based on the expressedType to a corresponding type based
152  /// on this type (returns nullptr if the cast is not valid).
153  /// Examples:
154  /// f32 -> !quant.uniform<i8:f32, 1.0>
155  /// tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
156  /// vector<4xf32> -> vector<4x!quant.uniform<i8:f32, 1.0>>
157  Type castFromExpressedType(Type candidateType);
158 
159  /// Casts from a type based on QuantizedType to a corresponding type based
160  /// on the expressedType (returns nullptr if the cast is not valid).
161  /// This is the inverse of castFromExpressedType.
162  static Type castToExpressedType(Type quantizedType);
163 
164  /// Casts from a type based on the expressedType to the equivalent type
165  /// based on storageType by way of this QuantizedType. Equivalent to:
166  /// QuantizedType::castToStorageType(castFromExpressedType(candidateType))
167  /// (but with validity checks).
168  /// Example (for this = !quant.uniform<i8:f32, 1.0>):
169  /// tensor<4xf32> -> tensor<4xi8>
170  Type castExpressedToStorageType(Type candidateType);
171 
172 private:
173  /// Hide the following methods inherited from `Type`. It is almost certainly
174  /// a bug to call them from a `QuantizedType` object. Users should call
175  /// `getStorageType` or `getExpressedType` to get the underlying types
176  /// they want to inspect.
177  using Type::isBF16;
178  using Type::isF16;
179  using Type::isF32;
180  using Type::isF64;
181  using Type::isIndex;
182  using Type::isInteger;
183 };
184 
185 /// A quantized type that maps storage to/from expressed types in an
186 /// unspecified way.
187 ///
188 /// Typical syntax:
189 /// quant.any<i8:f32>
190 /// quant.any<i8>
191 /// quant.any<i8<-16,15>>
192 ///
193 /// Note that for the any type, the expressed type is optional.
195  : public Type::TypeBase<AnyQuantizedType, QuantizedType,
196  detail::AnyQuantizedTypeStorage> {
197 public:
198  using Base::Base;
199  using Base::getChecked;
200 
201  /// Gets an instance of the type with all parameters specified but not
202  /// checked.
203  static AnyQuantizedType get(unsigned flags, Type storageType,
204  Type expressedType, int64_t storageTypeMin,
205  int64_t storageTypeMax);
206 
207  /// Gets an instance of the type with all specified parameters checked.
208  /// Returns a nullptr convertible type on failure.
209  static AnyQuantizedType
211  Type storageType, Type expressedType, int64_t storageTypeMin,
212  int64_t storageTypeMax);
213 
214  /// Verifies construction invariants and issues errors/warnings.
216  unsigned flags, Type storageType,
217  Type expressedType, int64_t storageTypeMin,
218  int64_t storageTypeMax);
219 };
220 
221 /// Represents a family of uniform, quantized types.
222 ///
223 /// Each instance of this type expresses a mapping between real values (most
224 /// often expressed in floating point f32) and quantized values (either fixed
225 /// point or affine).
226 ///
227 /// The relationship is:
228 /// real_value = scale * (quantized_value - zero_point)
229 ///
230 /// It is used as part of high level graph transformations that have the goal
231 /// of re-expressing parts of a computation in terms of this common form for
232 /// more efficient execution at runtime. In addition, it is designed to be
233 /// expressive enough to facilitate lowering to precise types and operations
234 /// in target hardware.
235 ///
236 /// As a high-level type, focused on intermediate passes, this type holds
237 /// opinions consistent with high-level usage. If lowering math kernels below
238 /// the high level arithmetic ops (i.e. to LLVM IR or hardware specific
239 /// instruction sets), it is expected that the information expressed here
240 /// will be used to drive low level codegen and target specific type selection,
241 /// but this type will likely be erased in the process.
242 ///
243 /// Syntax synopsis:
244 /// Per-layer, all parameters expressed:
245 /// !quant<uniform[StorageType:ExpressedType]{Scale:ZeroPoint}>
246 /// Per-layer, optional parameters omitted:
247 /// !quant<uniform[StorageType]{Scale}>
248 ///
249 /// StorageType: 'i'|'u' NumBits
250 /// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
251 /// Scale: A legal double value
252 /// ZeroPoint: An integer value
254  : public Type::TypeBase<UniformQuantizedType, QuantizedType,
255  detail::UniformQuantizedTypeStorage> {
256 public:
257  using Base::Base;
258  using Base::getChecked;
259 
260  /// Gets an instance of the type with all parameters specified but not
261  /// checked.
262  static UniformQuantizedType get(unsigned flags, Type storageType,
263  Type expressedType, double scale,
264  int64_t zeroPoint, int64_t storageTypeMin,
265  int64_t storageTypeMax);
266 
267  /// Gets an instance of the type with all specified parameters checked.
268  /// Returns a nullptr convertible type on failure.
269  static UniformQuantizedType
271  Type storageType, Type expressedType, double scale,
272  int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
273 
274  /// Verifies construction invariants and issues errors/warnings.
276  unsigned flags, Type storageType,
277  Type expressedType, double scale,
278  int64_t zeroPoint, int64_t storageTypeMin,
279  int64_t storageTypeMax);
280 
281  /// Gets the scale term. The scale designates the difference between the real
282  /// values corresponding to consecutive quantized values differing by 1.
283  double getScale() const;
284 
285  /// Gets the storage value corresponding to the real value 0 in the affine
286  /// equation.
287  int64_t getZeroPoint() const;
288 
289  // Fixed point values are real numbers divided by a scale.
290  // Currently, only signed storage types are treated as fixed point.
291  // A fixed point value can be obtained from an affine value by subtracting
292  // the zeroPoint.
293  // In the future, this may be explicit versus implied by type and zeroPoint.
294  bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
295 };
296 
297 /// Represents per-axis (also known as per-channel) quantization.
298 ///
299 /// Syntax synopsis:
300 /// Per-axis, all parameters expressed:
301 /// !quant<uniform[StorageType:ExpressedType:QuantizedDim]{QuantParams}>
302 /// Per-axis, optional parameters omitted:
303 /// !quant<uniform[StorageType]{Scale}>
304 ///
305 /// StorageType: 'i'|'u' NumBits
306 /// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
307 /// QuantizedDim: An integer value
308 /// QuantParams: (Scale ':' ZeroPoint)+
309 /// Scale: A legal double value
310 /// ZeroPoint: An integer value
312  : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType,
313  detail::UniformQuantizedPerAxisTypeStorage> {
314 public:
315  using Base::Base;
316  using Base::getChecked;
317 
318  /// Gets an instance of the type with all parameters specified but not
319  /// checked.
321  get(unsigned flags, Type storageType, Type expressedType,
322  ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
323  int32_t quantizedDimension, int64_t storageTypeMin,
324  int64_t storageTypeMax);
325 
326  /// Gets an instance of the type with all specified parameters checked.
327  /// Returns a nullptr convertible type on failure.
330  Type storageType, Type expressedType, ArrayRef<double> scales,
331  ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
332  int64_t storageTypeMin, int64_t storageTypeMax);
333 
334  /// Verifies construction invariants and issues errors/warnings.
336  unsigned flags, Type storageType,
337  Type expressedType, ArrayRef<double> scales,
338  ArrayRef<int64_t> zeroPoints,
339  int32_t quantizedDimension,
340  int64_t storageTypeMin, int64_t storageTypeMax);
341 
342  /// Gets the quantization scales. The scales designate the difference between
343  /// the real values corresponding to consecutive quantized values differing
344  /// by 1. The ith scale corresponds to the ith slice in the
345  /// quantized_dimension.
346  ArrayRef<double> getScales() const;
347 
348  /// Gets the storage values corresponding to the real value 0 in the affine
349  /// equation. The ith zero point corresponds to the ith slice in the
350  /// quantized_dimension.
352 
353  /// Specifies the dimension of the Tensor's shape that the scales and
354  /// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
355  /// with quantization params:
356  /// scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1
357  /// will be quantized across the second dimension of t.
358  /// t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
359  /// t[:, 1, :, :] will have scale[1]=2.0, zero_point[1]=2
360  /// t[:, 2, :, :] will have scale[2]=3.0, zero_point[2]=3
361  int32_t getQuantizedDimension() const;
362 
363  /// Fixed point values are real numbers divided by a scale.
364  /// Currently, only signed storage types are treated as fixed point.
365  /// A fixed point value can be obtained from an affine value by subtracting
366  /// the zeroPoint.
367  /// In the future, this may be explicit versus implied by type and zeroPoint.
368  bool isFixedPoint() const {
369  if (!isSigned())
370  return false;
371  return !llvm::is_contained(getZeroPoints(), 0);
372  }
373 };
374 
375 /// A quantized type that infers its range from given min/max values.
376 ///
377 /// Typical syntax:
378 /// quant.calibrated<f32<-0.922,0.981>>
380  : public Type::TypeBase<CalibratedQuantizedType, QuantizedType,
381  detail::CalibratedQuantizedTypeStorage> {
382 public:
383  using Base::Base;
384  using Base::getChecked;
385 
386  /// Gets an instance of the type with all parameters specified but not
387  /// checked.
388  static CalibratedQuantizedType get(Type expressedType, double min,
389  double max);
390 
391  /// Gets an instance of the type with all specified parameters checked.
392  /// Returns a nullptr convertible type on failure.
395  double min, double max);
396 
397  /// Verifies construction invariants and issues errors/warnings.
399  Type expressedType, double min, double max);
400  double getMin() const;
401  double getMax() const;
402 };
403 
404 } // namespace quant
405 } // namespace mlir
406 
407 #endif // MLIR_DIALECT_QUANT_QUANTTYPES_H
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:52
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
Definition: Types.cpp:59
bool isIndex() const
Definition: Types.cpp:56
constexpr Type()=default
bool isF32() const
Definition: Types.cpp:51
bool isF16() const
Definition: Types.cpp:49
bool isBF16() const
Definition: Types.cpp:48
Utility class for implementing users of storage classes uniqued by a StorageUniquer.
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition: QuantTypes.h:196
static AnyQuantizedType get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:217
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:236
static AnyQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:226
A quantized type that infers its range from given min/max values.
Definition: QuantTypes.h:381
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:384
static CalibratedQuantizedType get(Type expressedType, double min, double max)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:371
static CalibratedQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:376
Base class for all quantized types known to this dialect.
Definition: QuantTypes.h:49
Type getExpressedType() const
Gets the original expressed type that this quantized type approximates.
Definition: QuantTypes.cpp:81
static constexpr unsigned MaxStorageBits
The maximum number of bits supported for storage types.
Definition: QuantTypes.h:55
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Definition: QuantTypes.cpp:32
static Type castToStorageType(Type quantizedType)
Casts from a type based on a QuantizedType to a corresponding type based on the storageType (returns ...
Definition: QuantTypes.cpp:127
Type castExpressedToStorageType(Type candidateType)
Casts from a type based on the expressedType to the equivalent type based on storageType by way of th...
Definition: QuantTypes.cpp:209
static Type castToExpressedType(Type quantizedType)
Casts from a type based on QuantizedType to a corresponding type based on the expressedType (returns ...
Definition: QuantTypes.cpp:182
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
Definition: QuantTypes.h:102
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)
Returns the element type as a QuantizedType or nullptr if it is not a quantized type.
Definition: QuantTypes.cpp:94
unsigned getFlags() const
Gets the flags associated with this type.
Definition: QuantTypes.cpp:23
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
Definition: QuantTypes.cpp:71
static int64_t getDefaultMaximumForInteger(bool isSigned, unsigned integralWidth)
Gets the maximum possible value stored by a storageType.
Definition: QuantTypes.h:77
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
Definition: QuantTypes.cpp:75
static bool classof(Type type)
Support method to enable LLVM-style type casting.
Definition: QuantTypes.cpp:27
Type castFromStorageType(Type candidateType)
Casts from a type based on the storageType to a corresponding type based on this type (returns nullpt...
Definition: QuantTypes.cpp:103
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Definition: QuantTypes.cpp:67
static int64_t getDefaultMinimumForInteger(bool isSigned, unsigned integralWidth)
Gets the minimum possible value stored by a storageType.
Definition: QuantTypes.h:67
Type getStorageType() const
Gets the underlying type used to store values.
Definition: QuantTypes.cpp:63
Type castFromExpressedType(Type candidateType)
Casts from a type based on the expressedType to a corresponding type based on this type (returns null...
Definition: QuantTypes.cpp:154
bool isCompatibleExpressedType(Type candidateExpressedType)
Returns whether the candidateExpressedType is a match for this QuantizedType.
Definition: QuantTypes.cpp:85
Represents per-axis (also known as per-channel) quantization.
Definition: QuantTypes.h:313
static UniformQuantizedPerAxisType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:314
bool isFixedPoint() const
Fixed point values are real numbers divided by a scale.
Definition: QuantTypes.h:368
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:324
static UniformQuantizedPerAxisType get(unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:304
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to.
Definition: QuantTypes.cpp:367
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:363
ArrayRef< double > getScales() const
Gets the quantization scales.
Definition: QuantTypes.cpp:359
Represents a family of uniform, quantized types.
Definition: QuantTypes.h:255
double getScale() const
Gets the scale term.
Definition: QuantTypes.cpp:298
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:300
static UniformQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:262
static UniformQuantizedType get(unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:253
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:271
This header declares functions that assist transformations in the MemRef dialect.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26