MLIR 23.0.0git
QuantTypes.cpp
Go to the documentation of this file.
1//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "TypeDetail.h"
13
15#include "mlir/IR/MLIRContext.h"
16
17using namespace mlir;
18using namespace mlir::quant;
19using namespace mlir::quant::detail;
20
21namespace {
22
23// Return the minimum scale representable in a given float type
24double getMinScale(Type expressedType) {
25 auto floatType = cast<FloatType>(expressedType);
26 return APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
27}
28
29// Return the maximum scale representable in a given float type
30double getMaxScale(Type expressedType) {
31 auto floatType = cast<FloatType>(expressedType);
32 return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
33}
34
35} // namespace
36
37unsigned QuantizedType::getFlags() const {
38 return static_cast<ImplType *>(impl)->flags;
39}
40
42 return llvm::isa<QuantDialect>(type.getDialect());
43}
44
45LogicalResult
47 unsigned flags, Type storageType,
48 Type expressedType, int64_t storageTypeMin,
49 int64_t storageTypeMax) {
50 if (auto quantStorageTypeInterface =
51 llvm::dyn_cast<QuantStorageTypeInterface>(storageType)) {
52 unsigned integralWidth = quantStorageTypeInterface.getStorageWidth();
53
54 // Verify storage width.
55 if (integralWidth == 0 || integralWidth > MaxStorageBits)
56 return emitError() << "illegal storage type size: " << integralWidth;
57
59 int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum(isSigned);
60 int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum(isSigned);
61
62 if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
63 storageTypeMax > defaultMax) {
64 return emitError() << "illegal storage min and storage max: ("
65 << storageTypeMin << ":" << storageTypeMax << ")";
66 }
67
68 return success();
69 }
70
71 return emitError() << "storage type must implement QuantStorageTypeInterface";
72}
73
75 return static_cast<ImplType *>(impl)->storageType;
76}
77
79 return static_cast<ImplType *>(impl)->storageTypeMin;
80}
81
83 return static_cast<ImplType *>(impl)->storageTypeMax;
84}
85
87 Type storageType = static_cast<ImplType *>(impl)->storageType;
88 auto quantStorageTypeInterface =
89 llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
90
91 int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum(isSigned());
92 int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum(isSigned());
93
94 return defaultMin != getStorageTypeMin() || defaultMax != getStorageTypeMax();
95}
96
98 Type storageType = static_cast<ImplType *>(impl)->storageType;
99 auto quantStorageTypeInterface =
100 llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
101
102 return quantStorageTypeInterface.getStorageWidth();
103}
104
106 return static_cast<ImplType *>(impl)->expressedType;
107}
108
109bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
110 if (llvm::isa<ShapedType>(candidateExpressedType)) {
111 return llvm::cast<ShapedType>(candidateExpressedType).getElementType() ==
113 }
114 return candidateExpressedType == getExpressedType();
115}
116
119 if (llvm::isa<ShapedType>(primitiveOrContainerType)) {
120 Type elementType =
121 llvm::cast<ShapedType>(primitiveOrContainerType).getElementType();
122 return llvm::dyn_cast<QuantizedType>(elementType);
123 }
124 return llvm::dyn_cast<QuantizedType>(primitiveOrContainerType);
125}
126
128 if (candidateType == getStorageType()) {
129 // i.e. i8 -> quant<"uniform[i8:f32]{1.0}">
130 return *this;
131 }
132 if (llvm::isa<RankedTensorType>(candidateType)) {
133 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
134 return RankedTensorType::get(
135 llvm::cast<RankedTensorType>(candidateType).getShape(),
137 }
138 if (llvm::isa<UnrankedTensorType>(candidateType)) {
139 // i.e. tensor<xi8> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
140 return UnrankedTensorType::get(getStorageType());
141 }
142 if (llvm::isa<VectorType>(candidateType)) {
143 // i.e. vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
144 return VectorType::get(llvm::cast<VectorType>(candidateType).getShape(),
146 }
147
148 return nullptr;
149}
150
152 if (llvm::isa<QuantizedType>(quantizedType)) {
153 // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
154 return llvm::cast<QuantizedType>(quantizedType).getStorageType();
155 }
156 if (llvm::isa<ShapedType>(quantizedType)) {
157 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
158 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
159 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
160 return nullptr;
161 }
162 Type storageType =
163 llvm::cast<QuantizedType>(sType.getElementType()).getStorageType();
164 if (llvm::isa<RankedTensorType>(quantizedType)) {
165 return RankedTensorType::get(sType.getShape(), storageType);
166 }
167 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
168 return UnrankedTensorType::get(storageType);
169 }
170 if (llvm::isa<VectorType>(quantizedType)) {
171 return VectorType::get(sType.getShape(), storageType);
172 }
173 }
174
175 return nullptr;
176}
177
179 if (candidateType == getExpressedType()) {
180 // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
181 return *this;
182 }
183 if (llvm::isa<ShapedType>(candidateType)) {
184 ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType);
185 if (candidateShapedType.getElementType() != getExpressedType()) {
186 return nullptr;
187 }
188
189 if (llvm::isa<RankedTensorType>(candidateType)) {
190 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
191 return RankedTensorType::get(candidateShapedType.getShape(), *this);
192 }
193 if (llvm::isa<UnrankedTensorType>(candidateType)) {
194 // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
195 return UnrankedTensorType::get(*this);
196 }
197 if (llvm::isa<VectorType>(candidateType)) {
198 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
199 return VectorType::get(candidateShapedType.getShape(), *this);
200 }
201 }
202
203 return nullptr;
204}
205
207 if (llvm::isa<QuantizedType>(quantizedType)) {
208 // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
209 return llvm::cast<QuantizedType>(quantizedType).getExpressedType();
210 }
211 if (llvm::isa<ShapedType>(quantizedType)) {
212 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
213 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
214 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
215 return nullptr;
216 }
217 Type expressedType =
218 llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType();
219 if (llvm::isa<RankedTensorType>(quantizedType)) {
220 return RankedTensorType::get(sType.getShape(), expressedType);
221 }
222 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
223 return UnrankedTensorType::get(expressedType);
224 }
225 if (llvm::isa<VectorType>(quantizedType)) {
226 return VectorType::get(sType.getShape(), expressedType);
227 }
228 }
229
230 return nullptr;
231}
232
234 Type expressedQuantizedType = castFromExpressedType(candidateType);
235 if (!expressedQuantizedType) {
236 return nullptr;
237 }
238 return QuantizedType::castToStorageType(expressedQuantizedType);
239}
240
241AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
242 Type expressedType,
243 int64_t storageTypeMin,
244 int64_t storageTypeMax) {
245 return Base::get(storageType.getContext(), flags, storageType, expressedType,
246 storageTypeMin, storageTypeMax);
247}
248
251 unsigned flags, Type storageType,
252 Type expressedType, int64_t storageTypeMin,
253 int64_t storageTypeMax) {
254 return Base::getChecked(emitError, storageType.getContext(), flags,
255 storageType, expressedType, storageTypeMin,
256 storageTypeMax);
257}
258
259LogicalResult
261 unsigned flags, Type storageType,
262 Type expressedType, int64_t storageTypeMin,
263 int64_t storageTypeMax) {
264 if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
265 expressedType, storageTypeMin,
266 storageTypeMax))) {
267 return failure();
268 }
269
270 // Verify that the expressed type is floating point.
271 // If this restriction is ever eliminated, the parser/printer must be
272 // extended.
273 if (expressedType && !llvm::isa<FloatType>(expressedType))
274 return emitError() << "expressed type must be floating point";
275
276 return success();
277}
278
280 Type expressedType, double scale,
281 int64_t zeroPoint,
282 int64_t storageTypeMin,
283 int64_t storageTypeMax) {
284 return Base::get(storageType.getContext(), flags, storageType, expressedType,
285 scale, zeroPoint, storageTypeMin, storageTypeMax);
286}
287
289 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
290 Type storageType, Type expressedType, double scale, int64_t zeroPoint,
291 int64_t storageTypeMin, int64_t storageTypeMax) {
292 return Base::getChecked(emitError, storageType.getContext(), flags,
293 storageType, expressedType, scale, zeroPoint,
294 storageTypeMin, storageTypeMax);
295}
296
298 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
299 Type storageType, Type expressedType, double scale, int64_t zeroPoint,
300 int64_t storageTypeMin, int64_t storageTypeMax) {
301 if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
302 expressedType, storageTypeMin,
303 storageTypeMax))) {
304 return failure();
305 }
306
307 // Uniform quantization requires fully expressed parameters, including
308 // expressed type.
309 if (!expressedType)
310 return emitError() << "uniform quantization requires expressed type";
311
312 // Verify that the expressed type is floating point.
313 // If this restriction is ever eliminated, the parser/printer must be
314 // extended.
315 if (!llvm::isa<FloatType>(expressedType))
316 return emitError() << "expressed type must be floating point";
317
318 // Verify scale.
319 double minScale = getMinScale(expressedType);
320 double maxScale = getMaxScale(expressedType);
321 if (scale < minScale || scale > maxScale)
322 return emitError() << "scale out of expressed type range [" << minScale
323 << ", " << maxScale << "]";
324
325 return success();
326}
327
328double UniformQuantizedType::getScale() const { return getImpl()->scale; }
329
331 return getImpl()->zeroPoint;
332}
333
335 unsigned flags, Type storageType, Type expressedType,
336 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
337 int32_t quantizedDimension, int64_t storageTypeMin,
338 int64_t storageTypeMax) {
339 return Base::get(storageType.getContext(), flags, storageType, expressedType,
340 scales, zeroPoints, quantizedDimension, storageTypeMin,
341 storageTypeMax);
342}
343
345 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
346 Type storageType, Type expressedType, ArrayRef<double> scales,
347 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
348 int64_t storageTypeMin, int64_t storageTypeMax) {
349 return Base::getChecked(emitError, storageType.getContext(), flags,
350 storageType, expressedType, scales, zeroPoints,
351 quantizedDimension, storageTypeMin, storageTypeMax);
352}
353
355 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
356 Type storageType, Type expressedType, ArrayRef<double> scales,
357 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
358 int64_t storageTypeMin, int64_t storageTypeMax) {
359 if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
360 expressedType, storageTypeMin,
361 storageTypeMax))) {
362 return failure();
363 }
364
365 // Uniform quantization requires fully expressed parameters, including
366 // expressed type.
367 if (!expressedType)
368 return emitError() << "uniform quantization requires expressed type";
369
370 // Verify that the expressed type is floating point.
371 // If this restriction is ever eliminated, the parser/printer must be
372 // extended.
373 if (!llvm::isa<FloatType>(expressedType))
374 return emitError() << "expressed type must be floating point";
375
376 // Ensure that the number of scales and zeroPoints match.
377 if (scales.size() != zeroPoints.size())
378 return emitError() << "illegal number of scales and zeroPoints: "
379 << scales.size() << ", " << zeroPoints.size();
380
381 // Verify scale.
382 double minScale = getMinScale(expressedType);
383 double maxScale = getMaxScale(expressedType);
384 for (double scale : scales) {
385 if (scale < minScale || scale > maxScale)
386 return emitError() << "scale out of expressed type range [" << minScale
387 << ", " << maxScale << "]";
388 }
389
390 // Verify quantized dimension.
391 if (quantizedDimension < 0)
392 return emitError() << "illegal quantized dimension: " << quantizedDimension;
393
394 return success();
395}
396
398 return getImpl()->getScales();
399}
400
402 return getImpl()->getZeroPoints();
403}
404
406 return getImpl()->quantizedDimension;
407}
408
410 unsigned flags, Type storageType, Type expressedType,
411 DenseElementsAttr scales, DenseElementsAttr zeroPoints,
412 ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
413 int64_t storageTypeMin, int64_t storageTypeMax) {
414 return Base::get(storageType.getContext(), flags, storageType, expressedType,
415 scales, zeroPoints, quantizedDimensions, blockSizes,
416 storageTypeMin, storageTypeMax);
417}
418
420 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
421 Type storageType, Type expressedType, DenseElementsAttr scales,
422 DenseElementsAttr zeroPoints, ArrayRef<int32_t> quantizedDimensions,
423 ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
424 int64_t storageTypeMax) {
425 return Base::getChecked(emitError, storageType.getContext(), flags,
426 storageType, expressedType, scales, zeroPoints,
427 quantizedDimensions, blockSizes, storageTypeMin,
428 storageTypeMax);
429}
430
432 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
433 Type storageType, Type expressedType, DenseElementsAttr scales,
434 DenseElementsAttr zeroPoints, ArrayRef<int32_t> quantizedDimensions,
435 ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
436 int64_t storageTypeMax) {
437 if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
438 expressedType, storageTypeMin,
439 storageTypeMax))) {
440 return failure();
441 }
442
443 // Uniform quantization requires fully expressed parameters, including
444 // expressed type.
445 if (!expressedType)
446 return emitError() << "uniform quantization requires expressed type";
447
448 // Verify that the expressed type is floating point.
449 // If this restriction is ever eliminated, the parser/printer must be
450 // extended.
451 if (!llvm::isa<FloatType>(expressedType))
452 return emitError() << "expressed type must be floating point";
453
454 // Verify scale type to match expressedType.
455 if (scales.getType().getElementType() != expressedType) {
456 return emitError() << "type of scale values "
457 << scales.getType().getElementType()
458 << " must match the expressed type " << expressedType;
459 }
460
461 // Verify zero-point type to match storageType.
462 if (zeroPoints.getType().getElementType() != storageType) {
463 return emitError() << "type of zero point values "
464 << zeroPoints.getType().getElementType()
465 << " must match the storage type " << storageType;
466 }
467
468 // Ensure that the shape of scales and zeroPoints match.
469 if (scales.getType().getShape() != zeroPoints.getType().getShape())
470 return emitError() << "shape of scales and zeroPoints ("
471 << scales.getType().getShape() << " vs "
472 << zeroPoints.getType().getShape() << ") does not match";
473
474 // Ensure that the number of quantized-dimensions and block-sizes match.
475 if (quantizedDimensions.size() != blockSizes.size())
476 return emitError() << "number of quantized dimensions and block sizes ("
477 << scales.size() << " vs " << zeroPoints.size()
478 << ") does not match";
479
480 // Verify quantized dimension.
481 for (auto quantizedDimension : quantizedDimensions) {
482 if (quantizedDimension < 0)
483 return emitError() << "illegal quantized dimension: "
484 << quantizedDimension;
485 }
486
487 // Verify block sizes.
488 for (auto blockSize : blockSizes) {
489 if (blockSize <= 0)
490 return emitError() << "illegal block size: " << blockSize;
491 }
492
493 return success();
494}
495
499
501 return getImpl()->getZeroPoints();
502}
503
506 return getImpl()->getQuantizedDimensions();
507}
508
510 return getImpl()->getBlockSizes();
511}
512
516 result.reserve(getQuantizedDimensions().size());
517
518 for (auto [dim, size] :
519 llvm::zip(getQuantizedDimensions(), getBlockSizes())) {
520 result.push_back({dim, size});
521 }
522
523 return result;
524}
525
527 double min, double max) {
528 return Base::get(expressedType.getContext(), expressedType, min, max);
529}
530
533 double min, double max) {
534 return Base::getChecked(emitError, expressedType.getContext(), expressedType,
535 min, max);
536}
537
540 double min, double max) {
541 // Verify that the expressed type is floating point.
542 // If this restriction is ever eliminated, the parser/printer must be
543 // extended.
544 if (!llvm::isa<FloatType>(expressedType))
545 return emitError() << "expressed type must be floating point";
546 if (max <= min)
547 return emitError() << "illegal min and max: (" << min << ":" << max << ")";
548
549 return success();
550}
551
552double CalibratedQuantizedType::getMin() const { return getImpl()->min; }
553
554double CalibratedQuantizedType::getMax() const { return getImpl()->max; }
return success()
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
An attribute that represents a reference to a dense vector or tensor object.
int64_t size() const
Returns the number of elements held by this attribute.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class represents a diagnostic that is inflight and set to be reported.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition QuantTypes.h:203
static AnyQuantizedType get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
static AnyQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
A quantized type that infers its range from given min/max values.
Definition QuantTypes.h:524
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Verifies construction invariants and issues errors/warnings.
static CalibratedQuantizedType get(Type expressedType, double min, double max)
Gets an instance of the type with all parameters specified but not checked.
static CalibratedQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Gets an instance of the type with all specified parameters checked.
Base class for all quantized types known to this dialect.
Definition QuantTypes.h:50
Type getExpressedType() const
Gets the original expressed type that this quantized type approximates.
static constexpr unsigned MaxStorageBits
The maximum number of bits supported for storage types.
Definition QuantTypes.h:56
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
static Type castToStorageType(Type quantizedType)
Casts from a type based on a QuantizedType to a corresponding type based on the storageType (returns ...
Type castExpressedToStorageType(Type candidateType)
Casts from a type based on the expressedType to the equivalent type based on storageType by way of th...
detail::QuantizedTypeStorage ImplType
Definition QuantTypes.h:52
static Type castToExpressedType(Type quantizedType)
Casts from a type based on QuantizedType to a corresponding type based on the expressedType (returns ...
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (false).
Definition QuantTypes.h:103
constexpr Type()=default
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)
Returns the element type as a QuantizedType or nullptr if it is not a quantized type.
unsigned getFlags() const
Gets the flags associated with this type.
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
static bool classof(Type type)
Support method to enable LLVM-style type casting.
Type castFromStorageType(Type candidateType)
Casts from a type based on the storageType to a corresponding type based on this type (returns nullpt...
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Type getStorageType() const
Gets the underlying type used to store values.
Type castFromExpressedType(Type candidateType)
Casts from a type based on the expressedType to a corresponding type based on this type (returns null...
bool isCompatibleExpressedType(Type candidateExpressedType)
Returns whether the candidateExpressedType is a match for this QuantizedType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Represents per-axis (also known as per-channel quantization).
Definition QuantTypes.h:324
static UniformQuantizedPerAxisType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
static UniformQuantizedPerAxisType get(unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to.
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation.
ArrayRef< double > getScales() const
Gets the quantization scales.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Represents sub-channel (also known as blockwise quantization).
Definition QuantTypes.h:409
ArrayRef< int32_t > getQuantizedDimensions() const
Gets the quantized dimensions.
DenseElementsAttr getZeroPoints() const
Gets the quantization zero-points.
ArrayRef< int64_t > getBlockSizes() const
Gets the block sizes for the quantized dimensions.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, ArrayRef< int32_t > quantizedDimensions, ArrayRef< int64_t > blockSizes, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
const SmallVector< std::pair< int32_t, int64_t > > getBlockSizeInfo() const
Gets the block size information.
static UniformQuantizedSubChannelType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, ArrayRef< int32_t > quantizedDimensions, ArrayRef< int64_t > blockSizes, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
static UniformQuantizedSubChannelType get(unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, ArrayRef< int32_t > quantizedDimensions, ArrayRef< int64_t > blockSizes, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
DenseElementsAttr getScales() const
Gets the quantization scales.
Represents a family of uniform, quantized types.
Definition QuantTypes.h:264
double getScale() const
Gets the scale term.
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
static UniformQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
static UniformQuantizedType get(unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144