MLIR 23.0.0git
QuantTypes.h
Go to the documentation of this file.
1//===- QuantTypes.h - Quantization Ops and Types ----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_QUANT_IR_QUANTTYPES_H
10#define MLIR_DIALECT_QUANT_IR_QUANTTYPES_H
11
12#include "mlir/IR/Attributes.h"
13#include "mlir/IR/Builders.h"
15#include "mlir/IR/Dialect.h"
17#include "mlir/IR/Types.h"
18#include "llvm/Support/MathExtras.h"
19
20namespace mlir {
21namespace quant {
33
34/// Enumeration of bit-mapped flags related to quantized types.
37 // Indicates that the storage type should be interpreted as a signed
38 // integer. The default is to interpret it as an unsigned value.
39 Signed = 1,
40};
41} // namespace QuantizationFlags
42
43/// Base class for all quantized types known to this dialect.
44/// All quantized types have:
45/// - storageType: The (narrower) numeric type that is being used to
46/// approximate some expressed type.
47/// - expressedType: The type that is being approximated.
48///
49/// The base class provides generic support for manipulating the types based
50/// on these fields.
51class QuantizedType : public Type {
52public:
54 using Type::Type;
55
56 /// The maximum number of bits supported for storage types.
57 static constexpr unsigned MaxStorageBits = 32;
58
59 static LogicalResult
61 Type storageType, Type expressedType, int64_t storageTypeMin,
62 int64_t storageTypeMax);
63
64 /// Support method to enable LLVM-style type casting.
65 static bool classof(Type type);
66
67 /// Gets the minimum possible value stored by a storageType. storageTypeMin
68 /// must be greater than or equal to this value.
70 unsigned integralWidth) {
71 if (isSigned) {
72 return llvm::minIntN(integralWidth);
73 }
74 return 0;
75 }
76
77 /// Gets the maximum possible value stored by a storageType. storageTypeMax
78 /// must be less than or equal to this value.
80 unsigned integralWidth) {
81 if (isSigned) {
82 return llvm::maxIntN(integralWidth);
83 }
84 return llvm::maxUIntN(integralWidth);
85 }
86
87 /// Gets the original expressed type that this quantized type approximates.
88 /// Note that this presumes that the quantized type was always derived from
89 /// a floating point type, which in the broadest definition, is not true (i.e.
90 /// it could be some form of integral, fixed type or affine type in its own
91 /// right); however, at the high level, no examples of such usage are
92 /// presently known and the restriction serves some useful purposes (such as
93 /// always being able to reverse a transformation or measure error). In most
94 /// cases, this will be f32.
95 Type getExpressedType() const;
96
97 /// Gets the flags associated with this type. Typically a more specific
98 /// accessor is appropriate.
99 unsigned getFlags() const;
100
101 // Convenience helpers.
102 /// Whether the storage type should be interpreted as a signed quantity
103 /// (true) or an unsigned value (false).
104 bool isSigned() const {
105 return (getFlags() & QuantizationFlags::Signed) ==
107 }
108
109 /// Gets the underlying type used to store values. Note that this may
110 /// be signed or unsigned. Use the isSigned() accessor to differentiate.
111 Type getStorageType() const;
112
113 /// The minimum value that storageType can take.
115
116 /// The maximum value that storageType can take.
118
119 /// Return whether the storage type has explicit min or max boundaries
120 /// different from the minimum and maximum representable values.
121 bool hasStorageTypeBounds() const;
122
123 /// Gets the integral bit width that the underlying storage type can exactly
124 /// represent. For integral storage types, this will just be their width.
125 unsigned getStorageTypeIntegralWidth() const;
126
127 /// Returns whether the candidateExpressedType is a match for this
128 /// QuantizedType. This will be true if the candidate type is either a
129 /// primitive type or a container type whose element type equals this
130 /// QuantizedType's expressed type.
131 /// Examples of compatible candidateExpressedType:
132 /// !quant.uniform<i8:f32, 1.0> =~ f32
133 /// !quant.uniform<i8:f32, 1.0> =~ tensor<4xf32>
134 bool isCompatibleExpressedType(Type candidateExpressedType);
135
136 /// Returns the element type as a QuantizedType or nullptr if it is not
137 /// a quantized type. If the type is primitive, returns that. If it is a
138 /// container (vector/tensor), return the element type.
139 /// Examples:
140 /// !quant.uniform<i8:f32, 1.0> -> !quant.uniform<i8:f32, 1.0>
141 /// tensor<4x!quant.uniform<i8:f32, 1.0>> -> !quant.uniform<i8:f32, 1.0>
142 static QuantizedType getQuantizedElementType(Type primitiveOrContainerType);
143
144 /// Casts from a type based on the storageType to a corresponding type based
145 /// on this type (returns nullptr if the cast is not valid).
146 /// Examples:
147 /// `candidate type` -> `return type`
148 /// i8 -> !quant.uniform<i8:f32, 1.0>
149 /// tensor<4xi8> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
150 /// vector<4xi8> -> vector<4x!quant.uniform<i8:f32, 1.0>>
151 /// It is assumed above that this type's quantization is `<i8:f32, 1.0>`.
152 Type castFromStorageType(Type candidateType);
153
154 /// Casts from a type based on a QuantizedType to a corresponding type based
155 /// on the storageType (returns nullptr if the cast is not valid).
156 /// This is the inverse of castFromStorageType().
157 static Type castToStorageType(Type quantizedType);
158
159 /// Casts from a type based on the expressedType to a corresponding type based
160 /// on this type (returns nullptr if the cast is not valid).
161 /// Examples:
162 /// f32 -> !quant.uniform<i8:f32, 1.0>
163 /// tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
164 /// vector<4xf32> -> vector<4x!quant.uniform<i8:f32, 1.0>>
165 Type castFromExpressedType(Type candidateType);
166
167 /// Casts from a type based on QuantizedType to a corresponding type based
168 /// on the expressedType (returns nullptr if the cast is not valid).
169 /// This is the inverse of castFromExpressedType.
170 static Type castToExpressedType(Type quantizedType);
171
172 /// Casts from a type based on the expressedType to the equivalent type
173 /// based on storageType by way of this QuantizedType. Equivalent to:
174 /// QuantizedType::castToStorageType(castFromExpressedType(candidateType))
175 /// (but with validity checks).
176 /// Example (for this = !quant.uniform<i8:f32, 1.0>):
177 /// tensor<4xf32> -> tensor<4xi8>
178 Type castExpressedToStorageType(Type candidateType);
179
180private:
181 /// Hide the following methods inherited from `Type`. It is almost certainly
182 /// a bug to call them from a `QuantizedType` object. Users should call
183 /// `getStorageType` or `getExpressedType` to get the underlying types
184 /// they want to inspect.
185 using Type::isBF16;
186 using Type::isF16;
187 using Type::isF32;
188 using Type::isF64;
189 using Type::isIndex;
190 using Type::isInteger;
191};
192
193/// A quantized type that maps storage to/from expressed types in an
194/// unspecified way.
195///
196/// Typical syntax:
197/// quant.any<i8:f32>
198/// quant.any<i8>
199/// quant.any<i8<-16,15>>
200///
201/// Note that for the any type, the expressed type is optional.
203 : public Type::TypeBase<AnyQuantizedType, QuantizedType,
204 detail::AnyQuantizedTypeStorage> {
205public:
206 using Base::Base;
207 using Base::getChecked;
208
209 static constexpr StringLiteral name = "quant.any";
210
211 /// Gets an instance of the type with all parameters specified but not
212 /// checked.
213 static AnyQuantizedType get(unsigned flags, Type storageType,
214 Type expressedType, int64_t storageTypeMin,
215 int64_t storageTypeMax);
216
217 /// Gets an instance of the type with all specified parameters checked.
218 /// Returns a nullptr convertible type on failure.
219 static AnyQuantizedType
221 Type storageType, Type expressedType, int64_t storageTypeMin,
222 int64_t storageTypeMax);
223
224 /// Verifies construction invariants and issues errors/warnings.
225 static LogicalResult
227 Type storageType, Type expressedType, int64_t storageTypeMin,
228 int64_t storageTypeMax);
229};
230
231/// Represents a family of uniform, quantized types.
232///
233/// Each instance of this type expresses a mapping between real values (most
234/// often expressed in floating point f32) and quantized values (either fixed
235/// point or affine).
236///
237/// The relationship is:
238/// real_value = scale * (quantized_value - zero_point)
239///
240/// It is used as part of high level graph transformations that have the goal
241/// of re-expressing parts of a computation in terms of this common form for
242/// more efficient execution at runtime. In addition, it is designed to be
243/// expressive enough to facilitate lowering to precise types and operations
244/// in target hardware.
245///
246/// As a high-level type, focused on intermediate passes, this type holds
247/// opinions consistent with high-level usage. If lowering math kernels below
248/// the high level arithmetic ops (i.e. to LLVM IR or hardware specific
249/// instruction sets), it is expected that the information expressed here
250/// will be used to drive low level codegen and target specific type selection,
251/// but this type will likely be erased in the process.
252///
253/// Syntax synopsis:
254/// Per-layer, all parameters expressed:
255/// !quant<uniform[StorageType:ExpressedType]{Scale:ZeroPoint}>
256/// Per-layer, optional parameters omitted:
257/// !quant<uniform[StorageType]{Scale}>
258///
259/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8', 'quantile'
260/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
261/// Scale: A legal double value
262/// ZeroPoint: An integer value
264 : public Type::TypeBase<UniformQuantizedType, QuantizedType,
265 detail::UniformQuantizedTypeStorage> {
266public:
267 using Base::Base;
268 using Base::getChecked;
269
270 static constexpr StringLiteral name = "quant.uniform";
271
272 /// Gets an instance of the type with all parameters specified but not
273 /// checked.
274 static UniformQuantizedType get(unsigned flags, Type storageType,
275 Type expressedType, double scale,
276 int64_t zeroPoint, int64_t storageTypeMin,
277 int64_t storageTypeMax);
278
279 /// Gets an instance of the type with all specified parameters checked.
280 /// Returns a nullptr convertible type on failure.
283 Type storageType, Type expressedType, double scale,
284 int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
285
286 /// Verifies construction invariants and issues errors/warnings.
287 static LogicalResult
289 Type storageType, Type expressedType, double scale,
290 int64_t zeroPoint, int64_t storageTypeMin,
291 int64_t storageTypeMax);
292
293 /// Gets the scale term. The scale designates the difference between the real
294 /// values corresponding to consecutive quantized values differing by 1.
295 double getScale() const;
296
297 /// Gets the storage value corresponding to the real value 0 in the affine
298 /// equation.
299 int64_t getZeroPoint() const;
300
301 // Fixed point values are real numbers divided by a scale.
302 // Currently, only signed storage types are treated as fixed point.
303 // A fixed point value can be obtained from an affine value by subtracting
304 // the zeroPoint.
305 // In the future, this may be explicit versus implied by type and zeroPoint.
306 bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
307};
308
309/// Represents per-axis (also known as per-channel) quantization.
310///
311/// Syntax synopsis:
312/// Per-axis, all parameters expressed:
313/// !quant<uniform[StorageType:ExpressedType:QuantizedDim]{QuantParams}>
314/// Per-axis, optional parameters omitted:
315/// !quant<uniform[StorageType]{Scale}>
316///
317/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8', 'quantile'
318/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
319/// QuantizedDim: An integer value
320/// QuantParams: (Scale ':' ZeroPoint)+
321/// Scale: A legal double value
322/// ZeroPoint: An integer value
324 : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType,
325 detail::UniformQuantizedPerAxisTypeStorage> {
326public:
327 using Base::Base;
328 using Base::getChecked;
329
330 static constexpr StringLiteral name = "quant.uniform_per_axis";
331
332 /// Gets an instance of the type with all parameters specified but not
333 /// checked.
335 get(unsigned flags, Type storageType, Type expressedType,
336 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
337 int32_t quantizedDimension, int64_t storageTypeMin,
338 int64_t storageTypeMax);
339
340 /// Gets an instance of the type with all specified parameters checked.
341 /// Returns a nullptr convertible type on failure.
344 Type storageType, Type expressedType, ArrayRef<double> scales,
345 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
346 int64_t storageTypeMin, int64_t storageTypeMax);
347
348 /// Verifies construction invariants and issues errors/warnings.
349 static LogicalResult
351 Type storageType, Type expressedType,
352 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
353 int32_t quantizedDimension, int64_t storageTypeMin,
354 int64_t storageTypeMax);
355
356 /// Gets the quantization scales. The scales designate the difference between
357 /// the real values corresponding to consecutive quantized values differing
358 /// by 1. The ith scale corresponds to the ith slice in the
359 /// quantized_dimension.
361
362 /// Gets the storage values corresponding to the real value 0 in the affine
363 /// equation. The ith zero point corresponds to the ith slice in the
364 /// quantized_dimension.
366
367 /// Specifies the dimension of the Tensor's shape that the scales and
368 /// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
369 /// with quantization params:
370 /// scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1
371 /// will be quantized across the second dimension of t.
372 /// t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
373 /// t[:, 1, :, :] will have scale[1]=2.0, zero_point[1]=2
374 /// t[:, 2, :, :] will have scale[2]=3.0, zero_point[2]=3
375 int32_t getQuantizedDimension() const;
376
377 /// Fixed point values are real numbers divided by a scale.
378 /// Currently, only signed storage types are treated as fixed point.
379 /// A fixed point value can be obtained from an affine value by subtracting
380 /// the zeroPoint.
381 /// In the future, this may be explicit versus implied by type and zeroPoint.
382 bool isFixedPoint() const {
383 if (!isSigned())
384 return false;
385 return !llvm::is_contained(getZeroPoints(), 0);
386 }
387};
388
389/// Represents sub-channel (also known as blockwise) quantization.
390///
391/// Syntax synopsis:
392/// UniformQuantizedSubChannelType ::= '!quant.uniform' '<'
393/// storageType ('<' storageMin ':' storageMax '>')? ':'
394/// expressedType ':' BlockSizeInfo ',' ScaleZeroTensor '>'
395/// BlockSizeInfo: '{' '}' | '{' AxisBlock (',' AxisBlock)* '}'
396/// AxisBlock ::= AxisSpec ':' BlockSizeSpec
397/// ScaleZeroTensor ::= ScaleZeroDenseExp | ScaleZeroList
398/// ScaleZeroDenseExp ::= '{' ScaleZeroTensor (',' ScaleZeroTensor)* '}'
399/// ScaleZeroList ::= ScaleZero (',' ScaleZero)*
400/// ScaleZero ::= Scale (':' ZeroPoint)?
401///
402/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8', 'quantile'
403/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
404/// AxisSpec: An integer value
405/// BlockSizeSpec: An integer value
406/// Scale: An attribute (usually floating-point value)
407/// ZeroPoint: An attribute (usually integer value)
409 : public Type::TypeBase<UniformQuantizedSubChannelType, QuantizedType,
410 detail::UniformQuantizedSubChannelTypeStorage> {
411public:
412 using Base::Base;
413 using Base::getChecked;
414
415 static constexpr StringLiteral name = "quant.uniform_sub_channel";
416
417 /// Gets an instance of the type with all parameters specified but not
418 /// checked.
420 get(unsigned flags, Type storageType, Type expressedType,
421 DenseElementsAttr scales, DenseElementsAttr zeroPoints,
422 ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
423 int64_t storageTypeMin, int64_t storageTypeMax);
424
425 /// Gets an instance of the type with all specified parameters checked.
426 /// Returns a nullptr convertible type on failure.
429 Type storageType, Type expressedType, DenseElementsAttr scales,
430 DenseElementsAttr zeroPoints,
431 ArrayRef<int32_t> quantizedDimensions,
432 ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
433 int64_t storageTypeMax);
434
435 /// Verifies construction invariants and issues errors/warnings.
436 static LogicalResult
438 Type storageType, Type expressedType,
439 DenseElementsAttr scales, DenseElementsAttr zeroPoints,
440 ArrayRef<int32_t> quantizedDimensions,
441 ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
442 int64_t storageTypeMax);
443
444 /// Gets the quantization scales. The scales are organized in a
445 /// multi-dimensional tensor. The size of each dimension in the scales tensor
446 /// is determined by the number of blocks along the corresponding dimension in
447 /// the quantized data tensor.
448 ///
449 /// For example, if the quantized data tensor has shape [X0, X1, ..., XR-1]
450 /// and the block sizes are [B0, B1, ..., BR-1], then the scales tensor will
451 /// have shape [X0/B0, X1/B1, ..., XR-1/BR-1].
452 ///
453 /// The scale value for a specific element in the quantized data tensor at
454 /// position [i0, i1, ..., iR-1] is determined by accessing the corresponding
455 /// element in the scales tensor at position [i0/B0, i1/B1, ..., iR-1/BR-1].
457
458 /// Gets the quantization zero-points. The zero-points are organized in a
459 /// multi-dimensional tensor. The size of each dimension in the zero-point
460 /// tensor is determined by the number of blocks along the corresponding
461 /// dimension in the quantized data tensor.
462 ///
463 /// For example, if the quantized data tensor has shape [X0, X1, ..., XR-1]
464 /// and the block sizes are [B0, B1, ..., BR-1], then the zero-point tensor
465 /// will have shape [X0/B0, X1/B1, ..., XR-1/BR-1].
466 ///
467 /// The zero-point value for a specific element in the quantized data tensor
468 /// at position [i0, i1, ..., iR-1] is determined by accessing the
469 /// corresponding element in the zero-point tensor at position [i0/B0, i1/B1,
470 /// ..., iR-1/BR-1].
472
473 /// Gets the quantized dimensions. Each element in the returned list
474 /// represents an axis of the quantized data tensor that has a specified block
475 /// size. The order of elements corresponds to the order of block sizes
476 /// returned by `getBlockSizes()`.
477 ///
478 /// It means that the data tensor is quantized along the `i`-th dimension in
479 /// the returned list using the `i`-th block size from `getBlockSizes()`.
480 ///
481 /// Note that the type expression does not have to specify the block size for
482 /// all axes in the data tensor. Any unspecified block size for an axis `i`
483 /// defaults to the tensor dimension size of that axis.
484 ///
485 /// For example, for a quantized type:
486 /// `tensor<8x4x2x!quant.uniform<i8:f32:{1:2, 0:8}, {{1.0, 2.0}, {3.0, 4.0}}>>`
487 ///
488 /// `getQuantizedDimensions()` returns [1, 0].
489 /// `getBlockSizes()` returns [2, 8].
490 ///
491 /// This indicates that:
492 /// * Axis 1 (second dimension) is quantized with a block size of 2.
493 /// * Axis 0 (first dimension) is quantized with a block size of 8.
494 /// Since axis 2 is not specified, it implicitly has a block size equal to
495 /// the size of the third dimension (which is 2 in this case).
497
498 /// Gets the block sizes for the quantized dimensions. The `i`-th element in
499 /// the returned list corresponds to the block size for the `i`-th dimension
500 /// in the list returned by `getQuantizedDimensions()`.
501 ///
502 /// See `getQuantizedDimensions()` for more details and examples.
504
505 /// Gets the block size information. This returns a list of pairs, where each
506 /// pair represents a quantized dimension and its corresponding block size.
507 ///
508 /// For example, for the type:
509 /// `tensor<8x4x!quant.uniform<i8:f32:{1:2, 0:8}, {{2.0, 3.0}}>>`
510 ///
511 /// This method returns:
512 /// `[(1, 2), (0, 8)]`
513 ///
514 /// This list indicates that axis 1 has a block size of 2, and axis 0 has a
515 /// block size of 8.
517};
518
519/// A quantized type that infers its range from given min/max values.
520///
521/// Typical syntax:
522/// quant.calibrated<f32<-0.922,0.981>>
524 : public Type::TypeBase<CalibratedQuantizedType, QuantizedType,
525 detail::CalibratedQuantizedTypeStorage> {
526public:
527 using Base::Base;
528 using Base::getChecked;
529
530 static constexpr StringLiteral name = "quant.calibrated";
531
532 /// Gets an instance of the type with all parameters specified but not
533 /// checked.
534 static CalibratedQuantizedType get(Type expressedType, double min,
535 double max);
536
537 /// Gets an instance of the type with all specified parameters checked.
538 /// Returns a nullptr convertible type on failure.
541 double min, double max);
542
543 /// Verifies construction invariants and issues errors/warnings.
544 static LogicalResult
546 Type expressedType, double min, double max);
547 double getMin() const;
548 double getMax() const;
549};
550
551/*Syntax:
552
553 ```
554 quantile-type ::= `!quant.quantile` `<` type `:` type `,` `{` float-list `}`
555 (`,` `<` int `,` int `>`)? `>`
556 ```
557
558 A quantile type represents a quantile-based floating point encoding, where
559 discrete storage values are fully defined by the floating-point entries
560 of a quantile lookup table (LUT) of F8/F16/F32/F64 values.
561
562 Optionally, explicit minimum and maximum storage values can be specified
563 after the LUT as `<min,max>`.
564
565 This type is used for weight compression schemes like NF4 (NormalizedFloat4)
566 and similar quantile-based formats.
567
568 Example:
569
570 MLIR:
571 !quant.quantile<ui4:f16, {-1.0,-0.696,0.0,0.079,1.0}>
572 !quant.quantile<ui4:f16, {-1.0,-0.696,0.0,0.079,1.0}, <-8,7>>
573
574 As an additional explanation for better understanding and readability of the
575 above example, the quantile type can be broken down as follows:
576 - `!quant.quantile`: This indicates that we are defining a quantile type.
577 - `<ui4:f16`: This specifies the storage type and the quantile type. In this
578 case, `ui4` indicates an unsigned 4-bit integer storage type, and `f16`
579 indicates that the quantile values are represented as 16-bit floating-point
580 numbers.
581 - `{-1.0,-0.696,0.0,0.079,1.0}`: This is the quantile lookup table (LUT)
582 that defines the discrete storage values. Each value in the LUT corresponds
583 to a specific quantized value that can be stored in the `ui4` storage type.
584 - `, <-8,7>`: This optional part specifies the explicit minimum and maximum
585 storage values. In this case, the minimum storage value is -8 and the maximum
586 storage value is 7.
587*/
588
590 : public Type::TypeBase<QuantileType, QuantizedType,
591 detail::QuantileTypeStorage,
592 mlir::QuantStorageTypeInterface::Trait> {
593public:
595 using Base::Base;
596
597 // Get the underlying type used to store raw values.
598 Type getStorageType() const;
599
600 // Get primitive expressed type of data in quantiles.
601 // Note that we may convert FP8 data to FP16 for storage,
602 // but we should treat its expressed type as FP8 rather than FP16.
603 Type getQuantileType() const;
604
605 /// Return the quantile table of this float type.
607
608 /// Return the explicit storage minimum, if set.
609 std::optional<int64_t> getStorageMin() const;
610
611 /// Return the explicit storage maximum, if set.
612 std::optional<int64_t> getStorageMax() const;
613
614 // Get a quantile float type with specified quantile table.
615 static QuantileType get(mlir::MLIRContext *ctx, Type storageType,
616 Type quantileType, ArrayRef<double> quantiles = {},
617 std::optional<int64_t> storageMin = std::nullopt,
618 std::optional<int64_t> storageMax = std::nullopt);
619
620 static QuantileType
622 mlir::MLIRContext *ctx, Type storageType, Type quantileType,
623 ArrayRef<double> quantiles,
624 std::optional<int64_t> storageMin = std::nullopt,
625 std::optional<int64_t> storageMax = std::nullopt);
626
627 static LogicalResult verifyInvariants(
629 Type quantileType, ArrayRef<double> quantiles,
630 std::optional<int64_t> storageMin, std::optional<int64_t> storageMax);
631
632 /// Methods to support type inquiry through isa, cast, and dyn_cast.
633 static bool classof(mlir::Type type);
634
635 // Printer
636 void print(mlir::AsmPrinter &printer) const;
637
638 static constexpr llvm::StringLiteral getMnemonic() { return {"quantile"}; }
639
640 static constexpr llvm::StringLiteral name = "quantile";
641
642 // Returns true if the type defaults to signed (e.g., si8, i8 or float types),
643 // false otherwise
644 bool shouldDefaultToSigned() const;
645
646 // Get the bit width of the storage type.
647 unsigned getStorageWidth() const;
648
649 // Get the default minimum and maximum values for the storage type.
650 int64_t getDefaultMinimum([[maybe_unused]] bool isSigned) const;
651 int64_t getDefaultMaximum([[maybe_unused]] bool isSigned) const;
652
653 // Get the string representation of the storage type
654 std::string getStorageTypeName([[maybe_unused]] bool isSigned) const;
655
656 // Get whether the type is a packed quantile float type
657 bool isPacked() const;
658
659 // Get the logical bit width of the quantile float type, which is the bit
660 // width of the represented floating point value.
661 unsigned getLogicalBitWidth() const;
662
663 // Get the number of quantized values stored in one byte for this quantile
664 // float type.
665 unsigned getElementsPerByte() const;
666
667 // Get the preferred alignment in bytes for this quantile float type, if any.
668 std::optional<unsigned> getPreferredAlignmentBytes() const;
669};
670} // namespace quant
671} // namespace mlir
672
673#endif // MLIR_DIALECT_QUANT_IR_QUANTTYPES_H
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
This base class exposes generic asm printer hooks, usable across the various derived printers.
An attribute that represents a reference to a dense vector or tensor object.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
bool isIndex() const
Definition Types.cpp:56
constexpr Type()=default
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
detail::StorageUserBase< ConcreteType, BaseType, StorageType, detail::TypeUniquer, Traits... > TypeBase
Utility class for implementing types.
Definition Types.h:79
bool isF16() const
Definition Types.cpp:38
bool isBF16() const
Definition Types.cpp:37
StorageUserBase< ConcreteType, BaseType, StorageType, detail::TypeUniquer, Traits... > Base
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition QuantTypes.h:204
static AnyQuantizedType get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
static constexpr StringLiteral name
Definition QuantTypes.h:209
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
static AnyQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
A quantized type that infers its range from given min/max values.
Definition QuantTypes.h:525
static constexpr StringLiteral name
Definition QuantTypes.h:530
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Verifies construction invariants and issues errors/warnings.
static CalibratedQuantizedType get(Type expressedType, double min, double max)
Gets an instance of the type with all parameters specified but not checked.
static CalibratedQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Gets an instance of the type with all specified parameters checked.
unsigned getLogicalBitWidth() const
static QuantileType getChecked(function_ref< InFlightDiagnostic()> emitError, mlir::MLIRContext *ctx, Type storageType, Type quantileType, ArrayRef< double > quantiles, std::optional< int64_t > storageMin=std::nullopt, std::optional< int64_t > storageMax=std::nullopt)
static bool classof(mlir::Type type)
Methods for support type inquiry through isa, cast, and dyn_cast.
unsigned getStorageWidth() const
std::optional< int64_t > getStorageMin() const
Return the explicit storage minimum, if set.
int64_t getDefaultMinimum(bool isSigned) const
unsigned getElementsPerByte() const
static constexpr llvm::StringLiteral name
Definition QuantTypes.h:640
std::string getStorageTypeName(bool isSigned) const
detail::QuantileTypeStorage ImplType
Definition QuantTypes.h:594
bool shouldDefaultToSigned() const
int64_t getDefaultMaximum(bool isSigned) const
static QuantileType get(mlir::MLIRContext *ctx, Type storageType, Type quantileType, ArrayRef< double > quantiles={}, std::optional< int64_t > storageMin=std::nullopt, std::optional< int64_t > storageMax=std::nullopt)
ArrayRef< double > getQuantiles() const
Return the quantile table of this float type.
std::optional< unsigned > getPreferredAlignmentBytes() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type storageType, Type quantileType, ArrayRef< double > quantiles, std::optional< int64_t > storageMin, std::optional< int64_t > storageMax)
std::optional< int64_t > getStorageMax() const
Return the explicit storage maximum, if set.
static constexpr llvm::StringLiteral getMnemonic()
Definition QuantTypes.h:638
void print(mlir::AsmPrinter &printer) const
Base class for all quantized types known to this dialect.
Definition QuantTypes.h:51
Type getExpressedType() const
Gets the original expressed type that this quantized type approximates.
static constexpr unsigned MaxStorageBits
The maximum number of bits supported for storage types.
Definition QuantTypes.h:57
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
static Type castToStorageType(Type quantizedType)
Casts from a type based on a QuantizedType to a corresponding type based on the storageType (returns ...
Type castExpressedToStorageType(Type candidateType)
Casts from a type based on the expressedType to the equivalent type based on storageType by way of th...
detail::QuantizedTypeStorage ImplType
Definition QuantTypes.h:53
static Type castToExpressedType(Type quantizedType)
Casts from a type based on QuantizedType to a corresponding type based on the expressedType (returns ...
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
Definition QuantTypes.h:104
constexpr Type()=default
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)
Returns the element type as a QuantizedType or nullptr if it is not a quantized type.
unsigned getFlags() const
Gets the flags associated with this type.
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
static int64_t getDefaultMaximumForInteger(bool isSigned, unsigned integralWidth)
Gets the maximum possible value stored by a storageType.
Definition QuantTypes.h:79
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
static bool classof(Type type)
Support method to enable LLVM-style type casting.
Type castFromStorageType(Type candidateType)
Casts from a type based on the storageType to a corresponding type based on this type (returns nullpt...
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
static int64_t getDefaultMinimumForInteger(bool isSigned, unsigned integralWidth)
Gets the minimum possible value stored by a storageType.
Definition QuantTypes.h:69
Type getStorageType() const
Gets the underlying type used to store values.
Type castFromExpressedType(Type candidateType)
Casts from a type based on the expressedType to a corresponding type based on this type (returns null...
bool isCompatibleExpressedType(Type candidateExpressedType)
Returns whether the candidateExpressedType is a match for this QuantizedType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Represents per-axis (also known as per-channel) quantization.
Definition QuantTypes.h:325
static constexpr StringLiteral name
Definition QuantTypes.h:330
static UniformQuantizedPerAxisType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
bool isFixedPoint() const
Fixed point values are real numbers divided by a scale.
Definition QuantTypes.h:382
static UniformQuantizedPerAxisType get(unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to.
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation.
ArrayRef< double > getScales() const
Gets the quantization scales.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Represents sub-channel (also known as blockwise) quantization.
Definition QuantTypes.h:410
static constexpr StringLiteral name
Definition QuantTypes.h:415
ArrayRef< int32_t > getQuantizedDimensions() const
Gets the quantized dimensions.
DenseElementsAttr getZeroPoints() const
Gets the quantization zero-points.
ArrayRef< int64_t > getBlockSizes() const
Gets the block sizes for the quantized dimensions.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, ArrayRef< int32_t > quantizedDimensions, ArrayRef< int64_t > blockSizes, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
const SmallVector< std::pair< int32_t, int64_t > > getBlockSizeInfo() const
Gets the block size information.
static UniformQuantizedSubChannelType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, ArrayRef< int32_t > quantizedDimensions, ArrayRef< int64_t > blockSizes, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
static UniformQuantizedSubChannelType get(unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, ArrayRef< int32_t > quantizedDimensions, ArrayRef< int64_t > blockSizes, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
DenseElementsAttr getScales() const
Gets the quantization scales.
Represents a family of uniform, quantized types.
Definition QuantTypes.h:265
double getScale() const
Gets the scale term.
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
static constexpr StringLiteral name
Definition QuantTypes.h:270
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
static UniformQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
static UniformQuantizedType get(unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Enumeration of bit-mapped flags related to quantized types.
Definition QuantTypes.h:35
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147