MLIR  22.0.0git
QuantTypes.cpp
Go to the documentation of this file.
1 //===- QuantTypes.cpp - Quantization Type Implementation --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "TypeDetail.h"
12 
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/MLIRContext.h"
15 
16 using namespace mlir;
17 using namespace mlir::quant;
18 using namespace mlir::quant::detail;
19 
20 namespace {
21 
22 // Return the minimum scale representable in a given float type
23 double getMinScale(Type expressedType) {
24  auto floatType = cast<FloatType>(expressedType);
25  return APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
26 }
27 
28 // Return the maximum scale representable in a given float type
29 double getMaxScale(Type expressedType) {
30  auto floatType = cast<FloatType>(expressedType);
31  return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
32 }
33 
34 } // namespace
35 
36 unsigned QuantizedType::getFlags() const {
37  return static_cast<ImplType *>(impl)->flags;
38 }
39 
41  return llvm::isa<QuantDialect>(type.getDialect());
42 }
43 
44 LogicalResult
46  unsigned flags, Type storageType,
47  Type expressedType, int64_t storageTypeMin,
48  int64_t storageTypeMax) {
49  // Verify that the storage type is integral.
50  // This restriction may be lifted at some point in favor of using bf16
51  // or f16 as exact representations on hardware where that is advantageous.
52  auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
53  if (!intStorageType)
54  return emitError() << "storage type must be integral";
55  unsigned integralWidth = intStorageType.getWidth();
56 
57  // Verify storage width.
58  if (integralWidth == 0 || integralWidth > MaxStorageBits)
59  return emitError() << "illegal storage type size: " << integralWidth;
60 
61  // Verify storageTypeMin and storageTypeMax.
62  bool isSigned =
64  int64_t defaultIntegerMin =
65  getDefaultMinimumForInteger(isSigned, integralWidth);
66  int64_t defaultIntegerMax =
67  getDefaultMaximumForInteger(isSigned, integralWidth);
68  if (storageTypeMax - storageTypeMin <= 0 ||
69  storageTypeMin < defaultIntegerMin ||
70  storageTypeMax > defaultIntegerMax) {
71  return emitError() << "illegal storage min and storage max: ("
72  << storageTypeMin << ":" << storageTypeMax << ")";
73  }
74  return success();
75 }
76 
78  return static_cast<ImplType *>(impl)->storageType;
79 }
80 
82  return static_cast<ImplType *>(impl)->storageTypeMin;
83 }
84 
86  return static_cast<ImplType *>(impl)->storageTypeMax;
87 }
88 
90  unsigned int integralWidth = getStorageTypeIntegralWidth();
91  bool isSignedInteger = isSigned();
92  int64_t defaultIntegerMin =
93  getDefaultMinimumForInteger(isSignedInteger, integralWidth);
94  int64_t defaultIntegerMax =
95  getDefaultMaximumForInteger(isSignedInteger, integralWidth);
96  return defaultIntegerMin != getStorageTypeMin() ||
97  defaultIntegerMax != getStorageTypeMax();
98 }
99 
101  // NOTE: If ever supporting non-integral storage types, some other scheme
102  // for determining the width will be needed.
103  return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
104 }
105 
107  return static_cast<ImplType *>(impl)->expressedType;
108 }
109 
110 bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
111  if (llvm::isa<ShapedType>(candidateExpressedType)) {
112  return llvm::cast<ShapedType>(candidateExpressedType).getElementType() ==
113  getExpressedType();
114  }
115  return candidateExpressedType == getExpressedType();
116 }
117 
119 QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
120  if (llvm::isa<ShapedType>(primitiveOrContainerType)) {
121  Type elementType =
122  llvm::cast<ShapedType>(primitiveOrContainerType).getElementType();
123  return llvm::dyn_cast<QuantizedType>(elementType);
124  }
125  return llvm::dyn_cast<QuantizedType>(primitiveOrContainerType);
126 }
127 
129  if (candidateType == getStorageType()) {
130  // i.e. i8 -> quant<"uniform[i8:f32]{1.0}">
131  return *this;
132  }
133  if (llvm::isa<RankedTensorType>(candidateType)) {
134  // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
135  return RankedTensorType::get(
136  llvm::cast<RankedTensorType>(candidateType).getShape(),
137  getStorageType());
138  }
139  if (llvm::isa<UnrankedTensorType>(candidateType)) {
140  // i.e. tensor<xi8> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
141  return UnrankedTensorType::get(getStorageType());
142  }
143  if (llvm::isa<VectorType>(candidateType)) {
144  // i.e. vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
145  return VectorType::get(llvm::cast<VectorType>(candidateType).getShape(),
146  getStorageType());
147  }
148 
149  return nullptr;
150 }
151 
153  if (llvm::isa<QuantizedType>(quantizedType)) {
154  // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
155  return llvm::cast<QuantizedType>(quantizedType).getStorageType();
156  }
157  if (llvm::isa<ShapedType>(quantizedType)) {
158  // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
159  ShapedType sType = llvm::cast<ShapedType>(quantizedType);
160  if (!llvm::isa<QuantizedType>(sType.getElementType())) {
161  return nullptr;
162  }
163  Type storageType =
164  llvm::cast<QuantizedType>(sType.getElementType()).getStorageType();
165  if (llvm::isa<RankedTensorType>(quantizedType)) {
166  return RankedTensorType::get(sType.getShape(), storageType);
167  }
168  if (llvm::isa<UnrankedTensorType>(quantizedType)) {
169  return UnrankedTensorType::get(storageType);
170  }
171  if (llvm::isa<VectorType>(quantizedType)) {
172  return VectorType::get(sType.getShape(), storageType);
173  }
174  }
175 
176  return nullptr;
177 }
178 
180  if (candidateType == getExpressedType()) {
181  // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
182  return *this;
183  }
184  if (llvm::isa<ShapedType>(candidateType)) {
185  ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType);
186  if (candidateShapedType.getElementType() != getExpressedType()) {
187  return nullptr;
188  }
189 
190  if (llvm::isa<RankedTensorType>(candidateType)) {
191  // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
192  return RankedTensorType::get(candidateShapedType.getShape(), *this);
193  }
194  if (llvm::isa<UnrankedTensorType>(candidateType)) {
195  // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
196  return UnrankedTensorType::get(*this);
197  }
198  if (llvm::isa<VectorType>(candidateType)) {
199  // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
200  return VectorType::get(candidateShapedType.getShape(), *this);
201  }
202  }
203 
204  return nullptr;
205 }
206 
208  if (llvm::isa<QuantizedType>(quantizedType)) {
209  // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
210  return llvm::cast<QuantizedType>(quantizedType).getExpressedType();
211  }
212  if (llvm::isa<ShapedType>(quantizedType)) {
213  // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
214  ShapedType sType = llvm::cast<ShapedType>(quantizedType);
215  if (!llvm::isa<QuantizedType>(sType.getElementType())) {
216  return nullptr;
217  }
218  Type expressedType =
219  llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType();
220  if (llvm::isa<RankedTensorType>(quantizedType)) {
221  return RankedTensorType::get(sType.getShape(), expressedType);
222  }
223  if (llvm::isa<UnrankedTensorType>(quantizedType)) {
224  return UnrankedTensorType::get(expressedType);
225  }
226  if (llvm::isa<VectorType>(quantizedType)) {
227  return VectorType::get(sType.getShape(), expressedType);
228  }
229  }
230 
231  return nullptr;
232 }
233 
235  Type expressedQuantizedType = castFromExpressedType(candidateType);
236  if (!expressedQuantizedType) {
237  return nullptr;
238  }
239  return QuantizedType::castToStorageType(expressedQuantizedType);
240 }
241 
242 AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
243  Type expressedType,
244  int64_t storageTypeMin,
245  int64_t storageTypeMax) {
246  return Base::get(storageType.getContext(), flags, storageType, expressedType,
247  storageTypeMin, storageTypeMax);
248 }
249 
252  unsigned flags, Type storageType,
253  Type expressedType, int64_t storageTypeMin,
254  int64_t storageTypeMax) {
255  return Base::getChecked(emitError, storageType.getContext(), flags,
256  storageType, expressedType, storageTypeMin,
257  storageTypeMax);
258 }
259 
260 LogicalResult
262  unsigned flags, Type storageType,
263  Type expressedType, int64_t storageTypeMin,
264  int64_t storageTypeMax) {
265  if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
266  expressedType, storageTypeMin,
267  storageTypeMax))) {
268  return failure();
269  }
270 
271  // Verify that the expressed type is floating point.
272  // If this restriction is ever eliminated, the parser/printer must be
273  // extended.
274  if (expressedType && !llvm::isa<FloatType>(expressedType))
275  return emitError() << "expressed type must be floating point";
276 
277  return success();
278 }
279 
281  Type expressedType, double scale,
282  int64_t zeroPoint,
283  int64_t storageTypeMin,
284  int64_t storageTypeMax) {
285  return Base::get(storageType.getContext(), flags, storageType, expressedType,
286  scale, zeroPoint, storageTypeMin, storageTypeMax);
287 }
288 
290  function_ref<InFlightDiagnostic()> emitError, unsigned flags,
291  Type storageType, Type expressedType, double scale, int64_t zeroPoint,
292  int64_t storageTypeMin, int64_t storageTypeMax) {
293  return Base::getChecked(emitError, storageType.getContext(), flags,
294  storageType, expressedType, scale, zeroPoint,
295  storageTypeMin, storageTypeMax);
296 }
297 
299  function_ref<InFlightDiagnostic()> emitError, unsigned flags,
300  Type storageType, Type expressedType, double scale, int64_t zeroPoint,
301  int64_t storageTypeMin, int64_t storageTypeMax) {
302  if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
303  expressedType, storageTypeMin,
304  storageTypeMax))) {
305  return failure();
306  }
307 
308  // Uniform quantization requires fully expressed parameters, including
309  // expressed type.
310  if (!expressedType)
311  return emitError() << "uniform quantization requires expressed type";
312 
313  // Verify that the expressed type is floating point.
314  // If this restriction is ever eliminated, the parser/printer must be
315  // extended.
316  if (!llvm::isa<FloatType>(expressedType))
317  return emitError() << "expressed type must be floating point";
318 
319  // Verify scale.
320  double minScale = getMinScale(expressedType);
321  double maxScale = getMaxScale(expressedType);
322  if (scale < minScale || scale > maxScale)
323  return emitError() << "scale out of expressed type range [" << minScale
324  << ", " << maxScale << "]";
325 
326  return success();
327 }
328 
329 double UniformQuantizedType::getScale() const { return getImpl()->scale; }
330 
332  return getImpl()->zeroPoint;
333 }
334 
336  unsigned flags, Type storageType, Type expressedType,
337  ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
338  int32_t quantizedDimension, int64_t storageTypeMin,
339  int64_t storageTypeMax) {
340  return Base::get(storageType.getContext(), flags, storageType, expressedType,
341  scales, zeroPoints, quantizedDimension, storageTypeMin,
342  storageTypeMax);
343 }
344 
346  function_ref<InFlightDiagnostic()> emitError, unsigned flags,
347  Type storageType, Type expressedType, ArrayRef<double> scales,
348  ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
349  int64_t storageTypeMin, int64_t storageTypeMax) {
350  return Base::getChecked(emitError, storageType.getContext(), flags,
351  storageType, expressedType, scales, zeroPoints,
352  quantizedDimension, storageTypeMin, storageTypeMax);
353 }
354 
356  function_ref<InFlightDiagnostic()> emitError, unsigned flags,
357  Type storageType, Type expressedType, ArrayRef<double> scales,
358  ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
359  int64_t storageTypeMin, int64_t storageTypeMax) {
360  if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
361  expressedType, storageTypeMin,
362  storageTypeMax))) {
363  return failure();
364  }
365 
366  // Uniform quantization requires fully expressed parameters, including
367  // expressed type.
368  if (!expressedType)
369  return emitError() << "uniform quantization requires expressed type";
370 
371  // Verify that the expressed type is floating point.
372  // If this restriction is ever eliminated, the parser/printer must be
373  // extended.
374  if (!llvm::isa<FloatType>(expressedType))
375  return emitError() << "expressed type must be floating point";
376 
377  // Ensure that the number of scales and zeroPoints match.
378  if (scales.size() != zeroPoints.size())
379  return emitError() << "illegal number of scales and zeroPoints: "
380  << scales.size() << ", " << zeroPoints.size();
381 
382  // Verify scale.
383  double minScale = getMinScale(expressedType);
384  double maxScale = getMaxScale(expressedType);
385  for (double scale : scales) {
386  if (scale < minScale || scale > maxScale)
387  return emitError() << "scale out of expressed type range [" << minScale
388  << ", " << maxScale << "]";
389  }
390 
391  // Verify quantized dimension.
392  if (quantizedDimension < 0)
393  return emitError() << "illegal quantized dimension: " << quantizedDimension;
394 
395  return success();
396 }
397 
399  return getImpl()->getScales();
400 }
401 
403  return getImpl()->getZeroPoints();
404 }
405 
407  return getImpl()->quantizedDimension;
408 }
409 
411  unsigned flags, Type storageType, Type expressedType,
412  DenseElementsAttr scales, DenseElementsAttr zeroPoints,
413  ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
414  int64_t storageTypeMin, int64_t storageTypeMax) {
415  return Base::get(storageType.getContext(), flags, storageType, expressedType,
416  scales, zeroPoints, quantizedDimensions, blockSizes,
417  storageTypeMin, storageTypeMax);
418 }
419 
421  function_ref<InFlightDiagnostic()> emitError, unsigned flags,
422  Type storageType, Type expressedType, DenseElementsAttr scales,
423  DenseElementsAttr zeroPoints, ArrayRef<int32_t> quantizedDimensions,
424  ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
425  int64_t storageTypeMax) {
426  return Base::getChecked(emitError, storageType.getContext(), flags,
427  storageType, expressedType, scales, zeroPoints,
428  quantizedDimensions, blockSizes, storageTypeMin,
429  storageTypeMax);
430 }
431 
433  function_ref<InFlightDiagnostic()> emitError, unsigned flags,
434  Type storageType, Type expressedType, DenseElementsAttr scales,
435  DenseElementsAttr zeroPoints, ArrayRef<int32_t> quantizedDimensions,
436  ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
437  int64_t storageTypeMax) {
438  if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
439  expressedType, storageTypeMin,
440  storageTypeMax))) {
441  return failure();
442  }
443 
444  // Uniform quantization requires fully expressed parameters, including
445  // expressed type.
446  if (!expressedType)
447  return emitError() << "uniform quantization requires expressed type";
448 
449  // Verify that the expressed type is floating point.
450  // If this restriction is ever eliminated, the parser/printer must be
451  // extended.
452  if (!llvm::isa<FloatType>(expressedType))
453  return emitError() << "expressed type must be floating point";
454 
455  // Verify scale type to match expressedType.
456  if (scales.getType().getElementType() != expressedType) {
457  return emitError() << "type of scale values "
458  << scales.getType().getElementType()
459  << " must match the expressed type " << expressedType;
460  }
461 
462  // Verify zero-point type to match storageType.
463  if (zeroPoints.getType().getElementType() != storageType) {
464  return emitError() << "type of zero point values "
465  << zeroPoints.getType().getElementType()
466  << " must match the storage type " << storageType;
467  }
468 
469  // Ensure that the shape of scales and zeroPoints match.
470  if (scales.getType().getShape() != zeroPoints.getType().getShape())
471  return emitError() << "shape of scales and zeroPoints ("
472  << scales.getType().getShape() << " vs "
473  << zeroPoints.getType().getShape() << ") does not match";
474 
475  // Ensure that the number of quantized-dimensions and block-sizes match.
476  if (quantizedDimensions.size() != blockSizes.size())
477  return emitError() << "number of quantized dimensions and block sizes ("
478  << scales.size() << " vs " << zeroPoints.size()
479  << ") does not match";
480 
481  // Verify quantized dimension.
482  for (auto quantizedDimension : quantizedDimensions) {
483  if (quantizedDimension < 0)
484  return emitError() << "illegal quantized dimension: "
485  << quantizedDimension;
486  }
487 
488  // Verify block sizes.
489  for (auto blockSize : blockSizes) {
490  if (blockSize <= 0)
491  return emitError() << "illegal block size: " << blockSize;
492  }
493 
494  return success();
495 }
496 
498  return getImpl()->getScales();
499 }
500 
502  return getImpl()->getZeroPoints();
503 }
504 
507  return getImpl()->getQuantizedDimensions();
508 }
509 
511  return getImpl()->getBlockSizes();
512 }
513 
517  result.reserve(getQuantizedDimensions().size());
518 
519  for (auto [dim, size] :
520  llvm::zip(getQuantizedDimensions(), getBlockSizes())) {
521  result.push_back({dim, size});
522  }
523 
524  return result;
525 }
526 
528  double min, double max) {
529  return Base::get(expressedType.getContext(), expressedType, min, max);
530 }
531 
533  function_ref<InFlightDiagnostic()> emitError, Type expressedType,
534  double min, double max) {
535  return Base::getChecked(emitError, expressedType.getContext(), expressedType,
536  min, max);
537 }
538 
540  function_ref<InFlightDiagnostic()> emitError, Type expressedType,
541  double min, double max) {
542  // Verify that the expressed type is floating point.
543  // If this restriction is ever eliminated, the parser/printer must be
544  // extended.
545  if (!llvm::isa<FloatType>(expressedType))
546  return emitError() << "expressed type must be floating point";
547  if (max <= min)
548  return emitError() << "illegal min and max: (" << min << ":" << max << ")";
549 
550  return success();
551 }
552 
553 double CalibratedQuantizedType::getMin() const { return getImpl()->min; }
554 
555 double CalibratedQuantizedType::getMax() const { return getImpl()->max; }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
An attribute that represents a reference to a dense vector or tensor object.
int64_t size() const
Returns the number of elements held by this attribute.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition: QuantTypes.h:203
static AnyQuantizedType get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:242
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:261
static AnyQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:251
A quantized type that infers its range from given min/max values.
Definition: QuantTypes.h:524
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:539
static CalibratedQuantizedType get(Type expressedType, double min, double max)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:527
static CalibratedQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:532
Base class for all quantized types known to this dialect.
Definition: QuantTypes.h:50
Type getExpressedType() const
Gets the original expressed type that this quantized type approximates.
Definition: QuantTypes.cpp:106
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
Definition: QuantTypes.cpp:89
static Type castToStorageType(Type quantizedType)
Casts from a type based on a QuantizedType to a corresponding type based on the storageType (returns ...
Definition: QuantTypes.cpp:152
Type castExpressedToStorageType(Type candidateType)
Casts from a type based on the expressedType to the equivalent type based on storageType by way of th...
Definition: QuantTypes.cpp:234
static Type castToExpressedType(Type quantizedType)
Casts from a type based on QuantizedType to a corresponding type based on the expressedType (returns ...
Definition: QuantTypes.cpp:207
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)
Returns the element type as a QuantizedType or nullptr if it is not a quantized type.
Definition: QuantTypes.cpp:119
unsigned getFlags() const
Gets the flags associated with this type.
Definition: QuantTypes.cpp:36
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
Definition: QuantTypes.cpp:85
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
Definition: QuantTypes.cpp:100
static bool classof(Type type)
Support method to enable LLVM-style type casting.
Definition: QuantTypes.cpp:40
Type castFromStorageType(Type candidateType)
Casts from a type based on the storageType to a corresponding type based on this type (returns nullpt...
Definition: QuantTypes.cpp:128
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Definition: QuantTypes.cpp:81
Type getStorageType() const
Gets the underlying type used to store values.
Definition: QuantTypes.cpp:77
Type castFromExpressedType(Type candidateType)
Casts from a type based on the expressedType to a corresponding type based on this type (returns null...
Definition: QuantTypes.cpp:179
bool isCompatibleExpressedType(Type candidateExpressedType)
Returns whether the candidateExpressedType is a match for this QuantizedType.
Definition: QuantTypes.cpp:110
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Definition: QuantTypes.cpp:45
Represents per-axis (also known as per-channel) quantization.
Definition: QuantTypes.h:324
static UniformQuantizedPerAxisType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:345
static UniformQuantizedPerAxisType get(unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:335
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to.
Definition: QuantTypes.cpp:406
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:402
ArrayRef< double > getScales() const
Gets the quantization scales.
Definition: QuantTypes.cpp:398
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:355
Represents sub-channel (also known as blockwise) quantization.
Definition: QuantTypes.h:409
ArrayRef< int32_t > getQuantizedDimensions() const
Gets the quantized dimensions.
Definition: QuantTypes.cpp:506
DenseElementsAttr getZeroPoints() const
Gets the quantization zero-points.
Definition: QuantTypes.cpp:501
ArrayRef< int64_t > getBlockSizes() const
Gets the block sizes for the quantized dimensions.
Definition: QuantTypes.cpp:510
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, ArrayRef< int32_t > quantizedDimensions, ArrayRef< int64_t > blockSizes, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:432
const SmallVector< std::pair< int32_t, int64_t > > getBlockSizeInfo() const
Gets the block size information.
Definition: QuantTypes.cpp:515
static UniformQuantizedSubChannelType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, ArrayRef< int32_t > quantizedDimensions, ArrayRef< int64_t > blockSizes, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:420
static UniformQuantizedSubChannelType get(unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, ArrayRef< int32_t > quantizedDimensions, ArrayRef< int64_t > blockSizes, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:410
DenseElementsAttr getScales() const
Gets the quantization scales.
Definition: QuantTypes.cpp:497
Represents a family of uniform, quantized types.
Definition: QuantTypes.h:264
double getScale() const
Gets the scale term.
Definition: QuantTypes.cpp:329
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:331
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Definition: QuantTypes.cpp:298
static UniformQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:289
static UniformQuantizedType get(unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:280
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...