MLIR  14.0.0git
QuantTypes.h
Go to the documentation of this file.
1 //===- QuantTypes.h - Quantization Ops and Types ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_QUANT_QUANTTYPES_H
10 #define MLIR_DIALECT_QUANT_QUANTTYPES_H
11 
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/Dialect.h"
16 #include "mlir/IR/OpDefinition.h"
17 #include "mlir/IR/Types.h"
18 #include "llvm/Support/MathExtras.h"
19 
20 namespace mlir {
21 namespace quant {
22 
23 class QuantizedIntegerType;
24 
25 namespace detail {
26 
32 
33 } // namespace detail
34 
35 /// Enumeration of bit-mapped flags related to quantized types.
36 namespace QuantizationFlags {
37 enum FlagValue {
38  // Indicates that the storage type should be interpreted as a signed
39  // integer. The default is to interpret it as an unsigned value.
40  Signed = 1,
41 };
42 } // namespace QuantizationFlags
43 
44 /// Base class for all quantized types known to this dialect.
45 /// All quantized types have:
46 /// - storageType: The (narrower) numeric type that is being used to
47 /// approximate some expressed type.
48 /// - expressedType: The type that is being approximated.
49 ///
50 /// The base class provides generic support for manipulating the types based
51 /// on these fields.
52 class QuantizedType : public Type {
53 public:
55  using Type::Type;
56 
57  /// The maximum number of bits supported for storage types.
58  static constexpr unsigned MaxStorageBits = 32;
59 
61  unsigned flags, Type storageType,
62  Type expressedType, int64_t storageTypeMin,
63  int64_t storageTypeMax);
64 
65  /// Support method to enable LLVM-style type casting.
66  static bool classof(Type type);
67 
68  /// Gets the minimum possible value stored by a storageType of the given
69  /// width (0 when unsigned). storageTypeMin must be >= this value.
70  static int64_t getDefaultMinimumForInteger(bool isSigned,
71  unsigned integralWidth) {
72  if (isSigned) {
73  return llvm::minIntN(integralWidth);
74  }
75  return 0;
76  }
77 
78  /// Gets the maximum possible value stored by a storageType of the given
79  /// width and signedness. storageTypeMax must be <= this value.
80  static int64_t getDefaultMaximumForInteger(bool isSigned,
81  unsigned integralWidth) {
82  if (isSigned) {
83  return llvm::maxIntN(integralWidth);
84  }
85  return llvm::maxUIntN(integralWidth);
86  }
87 
88  /// Gets the original expressed type that this quantized type approximates.
89  /// Note that this presumes that the quantized type was always derived from
90  /// a floating point type, which in the broadest definition, is not true (i.e.
91  /// it could be some form of integral, fixed type or affine type in its own
92  /// right); however, at the high level, no examples of such usage are
93  /// presently known and the restriction serves some useful purposes (such as
94  /// always being able to reverse a transformation or measure error). In most
95  /// cases, this will be f32.
96  Type getExpressedType() const;
97 
98  /// Gets the flags associated with this type. Typically a more specific
99  /// accessor is appropriate.
100  unsigned getFlags() const;
101 
102  // Convenience helpers.
103  /// Whether the storage type should be interpreted as a signed quantity
104  /// (true) or an unsigned value (false).
105  bool isSigned() const {
106  return (getFlags() & QuantizationFlags::Signed) ==
108  }
109 
110  /// Gets the underlying type used to store values. Note that this may
111  /// be signed or unsigned. Use the isSigned() accessor to differentiate.
112  Type getStorageType() const;
113 
114  /// The minimum value that storageType can take.
115  int64_t getStorageTypeMin() const;
116 
117  /// The maximum value that storageType can take.
118  int64_t getStorageTypeMax() const;
119 
120  /// Gets the integral bit width that the underlying storage type can exactly
121  /// represent. For integral storage types, this will just be their width.
122  unsigned getStorageTypeIntegralWidth() const;
123 
124  /// Returns whether the candidateExpressedType is a match for this
125  /// QuantizedType. This will be true if the candidate type is either a
126  /// primitive type or a container type whose element type equals this
127  /// QuantizedType's expressed type.
128  /// Examples of compatible candidateExpressedType:
129  /// !quant.uniform<i8:f32, 1.0> =~ f32
130  /// !quant.uniform<i8:f32, 1.0> =~ tensor<4xf32>
131  bool isCompatibleExpressedType(Type candidateExpressedType);
132 
133  /// Returns the element type as a QuantizedType or nullptr if it is not
134  /// a quantized type. If the type is primitive, returns that. If it is a
135  /// container (vector/tensor), return the element type.
136  /// Examples:
137  /// !quant.uniform<i8:f32, 1.0> -> !quant.uniform<i8:f32, 1.0>
138  /// tensor<4x!quant.uniform<i8:f32, 1.0>> -> !quant.uniform<i8:f32, 1.0>
139  static QuantizedType getQuantizedElementType(Type primitiveOrContainerType);
140 
141  /// Casts from a type based on the storageType to a corresponding type based
142  /// on this type (returns nullptr if the cast is not valid).
143  /// Examples:
144  /// i8 -> !quant.uniform<i8:f32, 1.0>
145  /// tensor<4xi8> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
146  /// vector<4xi8> -> vector<4x!quant.uniform<i8:f32, 1.0>>
147  Type castFromStorageType(Type candidateType);
148 
149  /// Casts from a type based on a QuantizedType to a corresponding type based
150  /// on the storageType (returns nullptr if the cast is not valid).
151  /// This is the inverse of castFromStorageType().
152  static Type castToStorageType(Type quantizedType);
153 
154  /// Casts from a type based on the expressedType to a corresponding type based
155  /// on this type (returns nullptr if the cast is not valid).
156  /// Examples:
157  /// f32 -> !quant.uniform<i8:f32, 1.0>
158  /// tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
159  /// vector<4xf32> -> vector<4x!quant.uniform<i8:f32, 1.0>>
160  Type castFromExpressedType(Type candidateType);
161 
162  /// Casts from a type based on QuantizedType to a corresponding type based
163  /// on the expressedType (returns nullptr if the cast is not valid).
164  /// This is the inverse of castFromExpressedType.
165  static Type castToExpressedType(Type quantizedType);
166 
167  /// Casts from a type based on the expressedType to the equivalent type
168  /// based on storageType by way of this QuantizedType. Equivalent to:
169  /// QuantizedType::castToStorageType(castFromExpressedType(candidateType))
170  /// (but with validity checks).
171  /// Example (for this = !quant.uniform<i8:f32, 1.0>):
172  /// tensor<4xf32> -> tensor<4xi8>
173  Type castExpressedToStorageType(Type candidateType);
174 
175 private:
176  /// Hide the following methods inherited from `Type`. It is almost certainly
177  /// a bug to call them from a `QuantizedType` object. Users should call
178  /// `getStorageType` or `getExpressedType` to get the underlying types
179  /// they want to inspect.
180  using Type::isBF16;
181  using Type::isF16;
182  using Type::isF32;
183  using Type::isF64;
184  using Type::isIndex;
185  using Type::isInteger;
186 };
187 
188 /// A quantized type that maps storage to/from expressed types in an
189 /// unspecified way.
190 ///
191 /// Typical syntax:
192 /// quant.any<i8:f32>
193 /// quant.any<i8>
194 /// quant.any<i8<-16,15>>
195 ///
196 /// Note that for the any type, the expressed type is optional.
198  : public Type::TypeBase<AnyQuantizedType, QuantizedType,
199  detail::AnyQuantizedTypeStorage> {
200 public:
201  using Base::Base;
202  using Base::getChecked;
203 
204  /// Gets an instance of the type with all parameters specified but not
205  /// checked.
206  static AnyQuantizedType get(unsigned flags, Type storageType,
207  Type expressedType, int64_t storageTypeMin,
208  int64_t storageTypeMax);
209 
210  /// Gets an instance of the type with all specified parameters checked.
211  /// Returns a nullptr convertible type on failure.
212  static AnyQuantizedType
213  getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
214  Type storageType, Type expressedType, int64_t storageTypeMin,
215  int64_t storageTypeMax);
216 
217  /// Verifies construction invariants and issues errors/warnings.
219  unsigned flags, Type storageType,
220  Type expressedType, int64_t storageTypeMin,
221  int64_t storageTypeMax);
222 };
223 
224 /// Represents a family of uniform, quantized types.
225 ///
226 /// Each instance of this type expresses a mapping between real values (most
227 /// often expressed in floating point f32) and quantized values (either fixed
228 /// point or affine).
229 ///
230 /// The relationship is:
231 /// real_value = scale * (quantized_value - zero_point)
232 ///
233 /// It is used as part of high level graph transformations that have the goal
234 /// of re-expressing parts of a computation in terms of this common form for
235 /// more efficient execution at runtime. In addition, it is designed to be
236 /// expressive enough to facilitate lowering to precise types and operations
237 /// in target hardware.
238 ///
239 /// As a high-level type, focused on intermediate passes, this type holds
240 /// opinions consistent with high-level usage. If lowering math kernels below
241 /// the high level arithmetic ops (i.e. to LLVM IR or hardware specific
242 /// instruction sets), it is expected that the information expressed here
243 /// will be used to drive low level codegen and target specific type selection,
244 /// but this type will likely be erased in the process.
245 ///
246 /// Syntax synopsis:
247 /// Per-layer, all parameters expressed:
248 /// !quant<uniform[StorageType:ExpressedType]{Scale:ZeroPoint}>
249 /// Per-layer, optional parameters omitted:
250 /// !quant<uniform[StorageType]{Scale}>
251 ///
252 /// StorageType: 'i'|'u' NumBits
253 /// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
254 /// Scale: A legal double value
255 /// ZeroPoint: An integer value
257  : public Type::TypeBase<UniformQuantizedType, QuantizedType,
258  detail::UniformQuantizedTypeStorage> {
259 public:
260  using Base::Base;
261  using Base::getChecked;
262 
263  /// Gets an instance of the type with all parameters specified but not
264  /// checked.
265  static UniformQuantizedType get(unsigned flags, Type storageType,
266  Type expressedType, double scale,
267  int64_t zeroPoint, int64_t storageTypeMin,
268  int64_t storageTypeMax);
269 
270  /// Gets an instance of the type with all specified parameters checked.
271  /// Returns a nullptr convertible type on failure.
272  static UniformQuantizedType
273  getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
274  Type storageType, Type expressedType, double scale,
275  int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
276 
277  /// Verifies construction invariants and issues errors/warnings.
279  unsigned flags, Type storageType,
280  Type expressedType, double scale,
281  int64_t zeroPoint, int64_t storageTypeMin,
282  int64_t storageTypeMax);
283 
284  /// Gets the scale term. The scale designates the difference between the real
285  /// values corresponding to consecutive quantized values differing by 1.
286  double getScale() const;
287 
288  /// Gets the storage value corresponding to the real value 0 in the affine
289  /// equation.
290  int64_t getZeroPoint() const;
291 
292  /// Fixed point values are real numbers divided by a scale.
293  /// Currently, only signed storage types are treated as fixed point.
294  /// A fixed point value can be obtained from an affine value by subtracting
295  /// the zeroPoint.
296  /// In the future, this may be explicit versus implied by type and zeroPoint.
297  bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
298 };
299 
300 /// Represents per-axis (also known as per-channel quantization).
301 ///
302 /// Syntax synopsis:
303 /// Per-axis, all parameters expressed:
304 /// !quant<uniform[StorageType:ExpressedType:QuantizedDim]{QuantParams}>
305 /// Per-axis, optional parameters omitted:
306 /// !quant<uniform[StorageType]{Scale}>
307 ///
308 /// StorageType: 'i'|'u' NumBits
309 /// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
310 /// QuantizedDim: An integer value
311 /// QuantParams: (Scale ':' ZeroPoint)+
312 /// Scale: A legal double value
313 /// ZeroPoint: An integer value
315  : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType,
316  detail::UniformQuantizedPerAxisTypeStorage> {
317 public:
318  using Base::Base;
319  using Base::getChecked;
320 
321  /// Gets an instance of the type with all parameters specified but not
322  /// checked.
324  get(unsigned flags, Type storageType, Type expressedType,
325  ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
326  int32_t quantizedDimension, int64_t storageTypeMin,
327  int64_t storageTypeMax);
328 
329  /// Gets an instance of the type with all specified parameters checked.
330  /// Returns a nullptr convertible type on failure.
332  getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
333  Type storageType, Type expressedType, ArrayRef<double> scales,
334  ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
335  int64_t storageTypeMin, int64_t storageTypeMax);
336 
337  /// Verifies construction invariants and issues errors/warnings.
339  unsigned flags, Type storageType,
340  Type expressedType, ArrayRef<double> scales,
341  ArrayRef<int64_t> zeroPoints,
342  int32_t quantizedDimension,
343  int64_t storageTypeMin, int64_t storageTypeMax);
344 
345  /// Gets the quantization scales. The scales designate the difference between
346  /// the real values corresponding to consecutive quantized values differing
347  /// by 1. The ith scale corresponds to the ith slice in the
348  /// quantized_dimension.
349  ArrayRef<double> getScales() const;
350 
351  /// Gets the storage values corresponding to the real value 0 in the affine
352  /// equation. The ith zero point corresponds to the ith slice in the
353  /// quantized_dimension.
354  ArrayRef<int64_t> getZeroPoints() const;
355 
356  /// Specifies the dimension of the Tensor's shape that the scales and
357  /// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
358  /// with quantization params:
359  /// scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1
360  /// will be quantized across the second dimension of t.
361  /// t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
362  /// t[:, 1, :, :] will have scale[1]=2.0, zero_point[1]=2
363  /// t[:, 2, :, :] will have scale[2]=3.0, zero_point[2]=3
364  int32_t getQuantizedDimension() const;
365 
366  /// Fixed point values are real numbers divided by a scale.
367  /// Currently, only signed storage types are treated as fixed point, and
368  /// only when every per-axis zeroPoint is 0 (a fixed point value is the
369  /// affine value minus its zeroPoint, so a zero offset means no shift).
370  /// In the future, this may be explicit versus implied by type and zeroPoint.
371  bool isFixedPoint() const {
372  if (!isSigned())
373  return false;
374  return llvm::all_of(getZeroPoints(),
375  [](int64_t zeroPoint) { return zeroPoint == 0; });
376  }
377 };
378 
379 /// A quantized type that infers its range from given min/max values.
380 ///
381 /// Typical syntax:
382 /// quant.calibrated<f32<-0.922,0.981>>
384  : public Type::TypeBase<CalibratedQuantizedType, QuantizedType,
385  detail::CalibratedQuantizedTypeStorage> {
386 public:
387  using Base::Base;
388  using Base::getChecked;
389 
390  /// Gets an instance of the type with all parameters specified but not
391  /// checked.
392  static CalibratedQuantizedType get(Type expressedType, double min,
393  double max);
394 
395  /// Gets an instance of the type with all specified parameters checked.
396  /// Returns a nullptr convertible type on failure.
398  getChecked(function_ref<InFlightDiagnostic()> emitError, Type expressedType,
399  double min, double max);
400 
401  /// Verifies construction invariants and issues errors/warnings.
403  Type expressedType, double min, double max);
404  double getMin() const;
405  double getMax() const;
406 };
407 
408 } // namespace quant
409 } // namespace mlir
410 
411 #endif // MLIR_DIALECT_QUANT_QUANTTYPES_H
Include the generated interface declarations.
bool isF32() const
Definition: Types.cpp:23
static Value min(ImplicitLocOpBuilder &builder, Value a, Value b)
constexpr Type()
Definition: Types.h:84
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:301
static int64_t getDefaultMinimumForInteger(bool isSigned, unsigned integralWidth)
Gets the minimum possible stored by a storageType.
Definition: QuantTypes.h:70
LogicalResult verify(Operation *op)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:353
Represents a family of uniform, quantized types.
Definition: QuantTypes.h:256
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
Definition: Types.cpp:31
A quantized type that infers its range from given min/max values.
Definition: QuantTypes.h:383
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
bool isFixedPoint() const
Fixed point values are real numbers divided by a scale.
Definition: QuantTypes.h:371
A quantized type that maps storage to/from expressed types in an unspecified way. ...
Definition: QuantTypes.h:197
bool isF16() const
Definition: Types.cpp:22
bool isIndex() const
Definition: Types.cpp:28
Represents per-axis (also known as per-channel quantization).
Definition: QuantTypes.h:314
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
Definition: QuantTypes.h:105
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
bool isF64() const
Definition: Types.cpp:24
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Base class for all quantized types known to this dialect.
Definition: QuantTypes.h:52
Utility class for implementing users of storage classes uniqued by a StorageUniquer.
static int64_t getDefaultMaximumForInteger(bool isSigned, unsigned integralWidth)
Gets the maximum possible stored by a storageType.
Definition: QuantTypes.h:80
bool isBF16() const
Definition: Types.cpp:21
static Value max(ImplicitLocOpBuilder &builder, Value a, Value b)