MLIR 22.0.0git
QuantTypes.cpp
Go to the documentation of this file.
1//===- QuantTypes.cpp - Quantization Type Implementation ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "TypeDetail.h"
12
14#include "mlir/IR/MLIRContext.h"
15
16using namespace mlir;
17using namespace mlir::quant;
18using namespace mlir::quant::detail;
19
20namespace {
21
22// Return the minimum scale representable in a given float type
23double getMinScale(Type expressedType) {
24 auto floatType = cast<FloatType>(expressedType);
25 return APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
26}
27
28// Return the maximum scale representable in a given float type
29double getMaxScale(Type expressedType) {
30 auto floatType = cast<FloatType>(expressedType);
31 return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
32}
33
34} // namespace
35
36unsigned QuantizedType::getFlags() const {
37 return static_cast<ImplType *>(impl)->flags;
38}
39
// True iff `type` is registered to the quant dialect (LLVM-style isa check on
// the type's owning dialect).
41 return llvm::isa<QuantDialect>(type.getDialect());
42}
// Verifies the invariants shared by every quantized type:
//  - the storage type must be an IntegerType,
//  - its width must be non-zero and at most MaxStorageBits,
//  - storageTypeMin must be strictly less than storageTypeMax, and both must
//    lie within defaultIntegerMin/defaultIntegerMax (the default bounds for
//    the storage integer).
// The first violated invariant is reported through `emitError` and the
// function returns failure; otherwise it returns success.
// NOTE(review): `storageTypeMax - storageTypeMin` is evaluated before the
// bound checks and can overflow int64_t (undefined behavior) for extreme
// caller-supplied values, e.g. storageTypeMin == INT64_MIN with
// storageTypeMax >= 0. `storageTypeMax <= storageTypeMin` expresses the same
// test without that risk.
44LogicalResult
46 unsigned flags, Type storageType,
47 Type expressedType, int64_t storageTypeMin,
48 int64_t storageTypeMax) {
49 // Verify that the storage type is integral.
50 // This restriction may be lifted at some point in favor of using bf16
51 // or f16 as exact representations on hardware where that is advantageous.
52 auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
53 if (!intStorageType)
54 return emitError() << "storage type must be integral";
55 unsigned integralWidth = intStorageType.getWidth();
56
57 // Verify storage width.
58 if (integralWidth == 0 || integralWidth > MaxStorageBits)
59 return emitError() << "illegal storage type size: " << integralWidth;
60
61 // Verify storageTypeMin and storageTypeMax.
62 bool isSigned =
64 int64_t defaultIntegerMin =
66 int64_t defaultIntegerMax =
68 if (storageTypeMax - storageTypeMin <= 0 ||
69 storageTypeMin < defaultIntegerMin ||
70 storageTypeMax > defaultIntegerMax) {
71 return emitError() << "illegal storage min and storage max: ("
72 << storageTypeMin << ":" << storageTypeMax << ")";
73 }
74 return success();
75}
76
// Returns the storage type held in the uniqued storage object.
78 return static_cast<ImplType *>(impl)->storageType;
79}
80
// Returns the lower storage bound (storageTypeMin) held in the uniqued
// storage object.
82 return static_cast<ImplType *>(impl)->storageTypeMin;
83}
84
// Returns the upper storage bound (storageTypeMax) held in the uniqued
// storage object.
86 return static_cast<ImplType *>(impl)->storageTypeMax;
87}
88
// Returns true when either getStorageTypeMin() or getStorageTypeMax() differs
// from the default integer bounds (defaultIntegerMin / defaultIntegerMax)
// derived for the storage integral width.
90 unsigned int integralWidth = getStorageTypeIntegralWidth();
92 int64_t defaultIntegerMin =
94 int64_t defaultIntegerMax =
96 return defaultIntegerMin != getStorageTypeMin() ||
97 defaultIntegerMax != getStorageTypeMax();
98}
99
// Returns the bit width of the stored storage type, obtained through
// getIntOrFloatBitWidth().
101 // NOTE: If ever supporting non-integral storage types, some other scheme
102 // for determining the width will be needed.
103 return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
104}
105
// Returns the expressed type held in the uniqued storage object (may be null
// for types whose verifier permits a missing expressed type).
107 return static_cast<ImplType *>(impl)->expressedType;
108}
109
// Checks whether `candidateExpressedType` matches this type's expressed type.
// For a ShapedType candidate, the candidate's element type is what gets
// compared; any other candidate is compared directly with getExpressedType().
110bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
111 if (llvm::isa<ShapedType>(candidateExpressedType)) {
112 return llvm::cast<ShapedType>(candidateExpressedType).getElementType() ==
114 }
115 return candidateExpressedType == getExpressedType();
116}
117
// Returns `primitiveOrContainerType` viewed as a QuantizedType. A ShapedType
// is looked through to its element type. The result is a null QuantizedType
// when the (element) type is not quantized, because dyn_cast is used.
120 if (llvm::isa<ShapedType>(primitiveOrContainerType)) {
121 Type elementType =
122 llvm::cast<ShapedType>(primitiveOrContainerType).getElementType();
123 return llvm::dyn_cast<QuantizedType>(elementType);
124 }
125 return llvm::dyn_cast<QuantizedType>(primitiveOrContainerType);
126}
127
129 if (candidateType == getStorageType()) {
130 // i.e. i8 -> quant<"uniform[i8:f32]{1.0}">
131 return *this;
132 }
133 if (llvm::isa<RankedTensorType>(candidateType)) {
134 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
135 return RankedTensorType::get(
136 llvm::cast<RankedTensorType>(candidateType).getShape(),
138 }
139 if (llvm::isa<UnrankedTensorType>(candidateType)) {
140 // i.e. tensor<xi8> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
141 return UnrankedTensorType::get(getStorageType());
142 }
143 if (llvm::isa<VectorType>(candidateType)) {
144 // i.e. vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
145 return VectorType::get(llvm::cast<VectorType>(candidateType).getShape(),
147 }
148
149 return nullptr;
150}
151
// Maps a quantized type, or a ranked tensor / unranked tensor / vector whose
// element type is quantized, to the same form built on the storage type.
// Returns nullptr for any other input: non-quantized types, shaped types whose
// element type is not quantized, and shaped types other than those three.
153 if (llvm::isa<QuantizedType>(quantizedType)) {
154 // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
155 return llvm::cast<QuantizedType>(quantizedType).getStorageType();
156 }
157 if (llvm::isa<ShapedType>(quantizedType)) {
158 // i.e. tensor<4x!quant<"uniform[i8:f32]{1.0}">> -> tensor<4xi8>
159 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
160 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
161 return nullptr;
162 }
163 Type storageType =
164 llvm::cast<QuantizedType>(sType.getElementType()).getStorageType();
165 if (llvm::isa<RankedTensorType>(quantizedType)) {
166 return RankedTensorType::get(sType.getShape(), storageType);
167 }
168 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
169 return UnrankedTensorType::get(storageType);
170 }
171 if (llvm::isa<VectorType>(quantizedType)) {
172 return VectorType::get(sType.getShape(), storageType);
173 }
174 }
175
176 return nullptr;
177}
178
// Maps the expressed type, or a ranked tensor / unranked tensor / vector whose
// element type is the expressed type, to the same form built on this quantized
// type. Returns nullptr for any other input, including shaped types whose
// element type differs from the expressed type.
180 if (candidateType == getExpressedType()) {
181 // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
182 return *this;
183 }
184 if (llvm::isa<ShapedType>(candidateType)) {
185 ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType);
186 if (candidateShapedType.getElementType() != getExpressedType()) {
187 return nullptr;
188 }
189
190 if (llvm::isa<RankedTensorType>(candidateType)) {
191 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
192 return RankedTensorType::get(candidateShapedType.getShape(), *this);
193 }
194 if (llvm::isa<UnrankedTensorType>(candidateType)) {
195 // i.e. tensor<*xf32> -> tensor<*x!quant<"uniform[i8:f32]{1.0}">>
196 return UnrankedTensorType::get(*this);
197 }
198 if (llvm::isa<VectorType>(candidateType)) {
199 // i.e. vector<4xf32> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
200 return VectorType::get(candidateShapedType.getShape(), *this);
201 }
202 }
203
204 return nullptr;
205}
206
// Maps a quantized type, or a ranked tensor / unranked tensor / vector whose
// element type is quantized, to the same form built on the expressed type.
// Returns nullptr for any other input.
208 if (llvm::isa<QuantizedType>(quantizedType)) {
209 // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
210 return llvm::cast<QuantizedType>(quantizedType).getExpressedType();
211 }
212 if (llvm::isa<ShapedType>(quantizedType)) {
213 // i.e. tensor<4x!quant<"uniform[i8:f32]{1.0}">> -> tensor<4xf32>
214 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
215 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
216 return nullptr;
217 }
218 Type expressedType =
219 llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType();
220 if (llvm::isa<RankedTensorType>(quantizedType)) {
221 return RankedTensorType::get(sType.getShape(), expressedType);
222 }
223 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
224 return UnrankedTensorType::get(expressedType);
225 }
226 if (llvm::isa<VectorType>(quantizedType)) {
227 return VectorType::get(sType.getShape(), expressedType);
228 }
229 }
230
231 return nullptr;
232}
233
// Composes castFromExpressedType with castToStorageType. Returns nullptr when
// castFromExpressedType rejects `candidateType`.
235 Type expressedQuantizedType = castFromExpressedType(candidateType);
236 if (!expressedQuantizedType) {
237 return nullptr;
238 }
239 return QuantizedType::castToStorageType(expressedQuantizedType);
240}
241
242AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
243 Type expressedType,
244 int64_t storageTypeMin,
245 int64_t storageTypeMax) {
246 return Base::get(storageType.getContext(), flags, storageType, expressedType,
247 storageTypeMin, storageTypeMax);
248}
249
// Checked factory: forwards every parameter, together with `emitError`, to
// Base::getChecked in the context that owns `storageType`.
252 unsigned flags, Type storageType,
253 Type expressedType, int64_t storageTypeMin,
254 int64_t storageTypeMax) {
255 return Base::getChecked(emitError, storageType.getContext(), flags,
256 storageType, expressedType, storageTypeMin,
257 storageTypeMax);
258}
259
// Runs the shared QuantizedType::verifyInvariants checks and then, only when
// an expressed type is supplied, requires it to be a FloatType. A null
// expressed type is accepted here.
260LogicalResult
262 unsigned flags, Type storageType,
263 Type expressedType, int64_t storageTypeMin,
264 int64_t storageTypeMax) {
265 if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
266 expressedType, storageTypeMin,
267 storageTypeMax))) {
268 return failure();
269 }
270
271 // Verify that the expressed type is floating point.
272 // If this restriction is ever eliminated, the parser/printer must be
273 // extended.
274 if (expressedType && !llvm::isa<FloatType>(expressedType))
275 return emitError() << "expressed type must be floating point";
276
277 return success();
278}
279
// Unchecked factory: forwards the scale, zero point and storage bounds to
// Base::get in the context that owns `storageType`.
281 Type expressedType, double scale,
282 int64_t zeroPoint,
283 int64_t storageTypeMin,
284 int64_t storageTypeMax) {
285 return Base::get(storageType.getContext(), flags, storageType, expressedType,
286 scale, zeroPoint, storageTypeMin, storageTypeMax);
287}
288
// Checked factory: forwards every parameter, together with `emitError`, to
// Base::getChecked in the context that owns `storageType`.
290 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
291 Type storageType, Type expressedType, double scale, int64_t zeroPoint,
292 int64_t storageTypeMin, int64_t storageTypeMax) {
293 return Base::getChecked(emitError, storageType.getContext(), flags,
294 storageType, expressedType, scale, zeroPoint,
295 storageTypeMin, storageTypeMax);
296}
297
299 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
300 Type storageType, Type expressedType, double scale, int64_t zeroPoint,
301 int64_t storageTypeMin, int64_t storageTypeMax) {
302 if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
303 expressedType, storageTypeMin,
304 storageTypeMax))) {
305 return failure();
306 }
307
308 // Uniform quantization requires fully expressed parameters, including
309 // expressed type.
310 if (!expressedType)
311 return emitError() << "uniform quantization requires expressed type";
312
313 // Verify that the expressed type is floating point.
314 // If this restriction is ever eliminated, the parser/printer must be
315 // extended.
316 if (!llvm::isa<FloatType>(expressedType))
317 return emitError() << "expressed type must be floating point";
318
319 // Verify scale.
320 double minScale = getMinScale(expressedType);
321 double maxScale = getMaxScale(expressedType);
322 if (scale < minScale || scale > maxScale)
323 return emitError() << "scale out of expressed type range [" << minScale
324 << ", " << maxScale << "]";
325
326 return success();
327}
328
329double UniformQuantizedType::getScale() const { return getImpl()->scale; }
330
// Returns the zero point held in this type's storage.
332 return getImpl()->zeroPoint;
333}
334
// Unchecked factory: forwards the per-axis scales, zero points, quantized
// dimension and storage bounds to Base::get in the context that owns
// `storageType`.
336 unsigned flags, Type storageType, Type expressedType,
337 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
338 int32_t quantizedDimension, int64_t storageTypeMin,
339 int64_t storageTypeMax) {
340 return Base::get(storageType.getContext(), flags, storageType, expressedType,
341 scales, zeroPoints, quantizedDimension, storageTypeMin,
342 storageTypeMax);
343}
344
// Checked factory: forwards every parameter, together with `emitError`, to
// Base::getChecked in the context that owns `storageType`.
346 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
347 Type storageType, Type expressedType, ArrayRef<double> scales,
348 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
349 int64_t storageTypeMin, int64_t storageTypeMax) {
350 return Base::getChecked(emitError, storageType.getContext(), flags,
351 storageType, expressedType, scales, zeroPoints,
352 quantizedDimension, storageTypeMin, storageTypeMax);
353}
354
356 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
357 Type storageType, Type expressedType, ArrayRef<double> scales,
358 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
359 int64_t storageTypeMin, int64_t storageTypeMax) {
360 if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
361 expressedType, storageTypeMin,
362 storageTypeMax))) {
363 return failure();
364 }
365
366 // Uniform quantization requires fully expressed parameters, including
367 // expressed type.
368 if (!expressedType)
369 return emitError() << "uniform quantization requires expressed type";
370
371 // Verify that the expressed type is floating point.
372 // If this restriction is ever eliminated, the parser/printer must be
373 // extended.
374 if (!llvm::isa<FloatType>(expressedType))
375 return emitError() << "expressed type must be floating point";
376
377 // Ensure that the number of scales and zeroPoints match.
378 if (scales.size() != zeroPoints.size())
379 return emitError() << "illegal number of scales and zeroPoints: "
380 << scales.size() << ", " << zeroPoints.size();
381
382 // Verify scale.
383 double minScale = getMinScale(expressedType);
384 double maxScale = getMaxScale(expressedType);
385 for (double scale : scales) {
386 if (scale < minScale || scale > maxScale)
387 return emitError() << "scale out of expressed type range [" << minScale
388 << ", " << maxScale << "]";
389 }
390
391 // Verify quantized dimension.
392 if (quantizedDimension < 0)
393 return emitError() << "illegal quantized dimension: " << quantizedDimension;
394
395 return success();
396}
397
// Returns the per-axis scales via the storage object's getScales().
399 return getImpl()->getScales();
400}
401
// Returns the per-axis zero points via the storage object's getZeroPoints().
403 return getImpl()->getZeroPoints();
404}
405
// Returns the quantized dimension held in this type's storage.
407 return getImpl()->quantizedDimension;
408}
409
// Unchecked factory: forwards the sub-channel scales, zero points, quantized
// dimensions, block sizes and storage bounds to Base::get in the context that
// owns `storageType`.
411 unsigned flags, Type storageType, Type expressedType,
412 DenseElementsAttr scales, DenseElementsAttr zeroPoints,
413 ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
414 int64_t storageTypeMin, int64_t storageTypeMax) {
415 return Base::get(storageType.getContext(), flags, storageType, expressedType,
416 scales, zeroPoints, quantizedDimensions, blockSizes,
417 storageTypeMin, storageTypeMax);
418}
419
// Checked factory: forwards every parameter, together with `emitError`, to
// Base::getChecked in the context that owns `storageType`.
421 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
422 Type storageType, Type expressedType, DenseElementsAttr scales,
423 DenseElementsAttr zeroPoints, ArrayRef<int32_t> quantizedDimensions,
424 ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
425 int64_t storageTypeMax) {
426 return Base::getChecked(emitError, storageType.getContext(), flags,
427 storageType, expressedType, scales, zeroPoints,
428 quantizedDimensions, blockSizes, storageTypeMin,
429 storageTypeMax);
430}
431
433 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
434 Type storageType, Type expressedType, DenseElementsAttr scales,
435 DenseElementsAttr zeroPoints, ArrayRef<int32_t> quantizedDimensions,
436 ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
437 int64_t storageTypeMax) {
438 if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
439 expressedType, storageTypeMin,
440 storageTypeMax))) {
441 return failure();
442 }
443
444 // Uniform quantization requires fully expressed parameters, including
445 // expressed type.
446 if (!expressedType)
447 return emitError() << "uniform quantization requires expressed type";
448
449 // Verify that the expressed type is floating point.
450 // If this restriction is ever eliminated, the parser/printer must be
451 // extended.
452 if (!llvm::isa<FloatType>(expressedType))
453 return emitError() << "expressed type must be floating point";
454
455 // Verify scale type to match expressedType.
456 if (scales.getType().getElementType() != expressedType) {
457 return emitError() << "type of scale values "
458 << scales.getType().getElementType()
459 << " must match the expressed type " << expressedType;
460 }
461
462 // Verify zero-point type to match storageType.
463 if (zeroPoints.getType().getElementType() != storageType) {
464 return emitError() << "type of zero point values "
465 << zeroPoints.getType().getElementType()
466 << " must match the storage type " << storageType;
467 }
468
469 // Ensure that the shape of scales and zeroPoints match.
470 if (scales.getType().getShape() != zeroPoints.getType().getShape())
471 return emitError() << "shape of scales and zeroPoints ("
472 << scales.getType().getShape() << " vs "
473 << zeroPoints.getType().getShape() << ") does not match";
474
475 // Ensure that the number of quantized-dimensions and block-sizes match.
476 if (quantizedDimensions.size() != blockSizes.size())
477 return emitError() << "number of quantized dimensions and block sizes ("
478 << scales.size() << " vs " << zeroPoints.size()
479 << ") does not match";
480
481 // Verify quantized dimension.
482 for (auto quantizedDimension : quantizedDimensions) {
483 if (quantizedDimension < 0)
484 return emitError() << "illegal quantized dimension: "
485 << quantizedDimension;
486 }
487
488 // Verify block sizes.
489 for (auto blockSize : blockSizes) {
490 if (blockSize <= 0)
491 return emitError() << "illegal block size: " << blockSize;
492 }
493
494 return success();
495}
496
500
// Returns the zero-point attribute via the storage object's getZeroPoints().
502 return getImpl()->getZeroPoints();
503}
504
// Returns the quantized dimensions via the storage object's
// getQuantizedDimensions().
507 return getImpl()->getQuantizedDimensions();
508}
509
// Returns the block sizes via the storage object's getBlockSizes().
511 return getImpl()->getBlockSizes();
512}
513
// Builds (quantized dimension, block size) pairs by zipping
// getQuantizedDimensions() with getBlockSizes() in order. llvm::zip stops at
// the shorter sequence; the sub-channel verifier requires both to have the
// same length.
517 result.reserve(getQuantizedDimensions().size());
518
519 for (auto [dim, size] :
520 llvm::zip(getQuantizedDimensions(), getBlockSizes())) {
521 result.push_back({dim, size});
522 }
523
524 return result;
525}
526
// Unchecked factory: uniques the calibrated type in the context that owns
// `expressedType`, forwarding the min/max calibration bounds.
528 double min, double max) {
529 return Base::get(expressedType.getContext(), expressedType, min, max);
530}
531
// Checked factory: forwards `emitError`, the expressed type and the min/max
// bounds to Base::getChecked in the context that owns `expressedType`.
534 double min, double max) {
535 return Base::getChecked(emitError, expressedType.getContext(), expressedType,
536 min, max);
537}
538
// Verifies calibrated-type parameters: the expressed type must be a FloatType
// and `max` must be strictly greater than `min`.
// NOTE(review): `max <= min` is false when either bound is NaN, so NaN bounds
// are not rejected here; `!(min < max)` would also reject them.
541 double min, double max) {
542 // Verify that the expressed type is floating point.
543 // If this restriction is ever eliminated, the parser/printer must be
544 // extended.
545 if (!llvm::isa<FloatType>(expressedType))
546 return emitError() << "expressed type must be floating point";
547 if (max <= min)
548 return emitError() << "illegal min and max: (" << min << ":" << max << ")";
549
550 return success();
551}
552
553double CalibratedQuantizedType::getMin() const { return getImpl()->min; }
554
555double CalibratedQuantizedType::getMax() const { return getImpl()->max; }
return success()
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
An attribute that represents a reference to a dense vector or tensor object.
int64_t size() const
Returns the number of elements held by this attribute.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class represents a diagnostic that is inflight and set to be reported.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:76
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition QuantTypes.h:203
static AnyQuantizedType get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
static AnyQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
A quantized type that infers its range from given min/max values.
Definition QuantTypes.h:524
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Verifies construction invariants and issues errors/warnings.
static CalibratedQuantizedType get(Type expressedType, double min, double max)
Gets an instance of the type with all parameters specified but not checked.
static CalibratedQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Gets an instance of the type with all specified parameters checked.
Base class for all quantized types known to this dialect.
Definition QuantTypes.h:50
Type getExpressedType() const
Gets the original expressed type that this quantized type approximates.
static constexpr unsigned MaxStorageBits
The maximum number of bits supported for storage types.
Definition QuantTypes.h:56
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
static Type castToStorageType(Type quantizedType)
Casts from a type based on a QuantizedType to a corresponding type based on the storageType (returns ...
Type castExpressedToStorageType(Type candidateType)
Casts from a type based on the expressedType to the equivalent type based on storageType by way of th...
detail::QuantizedTypeStorage ImplType
Definition QuantTypes.h:52
static Type castToExpressedType(Type quantizedType)
Casts from a type based on QuantizedType to a corresponding type based on the expressedType (returns ...
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
Definition QuantTypes.h:103
constexpr Type()=default
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)
Returns the element type as a QuantizedType or nullptr if it is not a quantized type.
unsigned getFlags() const
Gets the flags associated with this type.
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
static int64_t getDefaultMaximumForInteger(bool isSigned, unsigned integralWidth)
Gets the maximum possible value stored by a storageType.
Definition QuantTypes.h:78
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
static bool classof(Type type)
Support method to enable LLVM-style type casting.
Type castFromStorageType(Type candidateType)
Casts from a type based on the storageType to a corresponding type based on this type (returns nullpt...
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
static int64_t getDefaultMinimumForInteger(bool isSigned, unsigned integralWidth)
Gets the minimum possible value stored by a storageType.
Definition QuantTypes.h:68
Type getStorageType() const
Gets the underlying type used to store values.
Type castFromExpressedType(Type candidateType)
Casts from a type based on the expressedType to a corresponding type based on this type (returns null...
bool isCompatibleExpressedType(Type candidateExpressedType)
Returns whether the candidateExpressedType is a match for this QuantizedType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Represents per-axis (also known as per-channel quantization).
Definition QuantTypes.h:324
static UniformQuantizedPerAxisType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
static UniformQuantizedPerAxisType get(unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to.
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation.
ArrayRef< double > getScales() const
Gets the quantization scales.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Represents sub-channel (also known as blockwise quantization).
Definition QuantTypes.h:409
ArrayRef< int32_t > getQuantizedDimensions() const
Gets the quantized dimensions.
DenseElementsAttr getZeroPoints() const
Gets the quantization zero-points.
ArrayRef< int64_t > getBlockSizes() const
Gets the block sizes for the quantized dimensions.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, ArrayRef< int32_t > quantizedDimensions, ArrayRef< int64_t > blockSizes, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
const SmallVector< std::pair< int32_t, int64_t > > getBlockSizeInfo() const
Gets the block size information.
static UniformQuantizedSubChannelType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, ArrayRef< int32_t > quantizedDimensions, ArrayRef< int64_t > blockSizes, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
static UniformQuantizedSubChannelType get(unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, ArrayRef< int32_t > quantizedDimensions, ArrayRef< int64_t > blockSizes, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
DenseElementsAttr getScales() const
Gets the quantization scales.
Represents a family of uniform, quantized types.
Definition QuantTypes.h:264
double getScale() const
Gets the scale term.
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
static UniformQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
static UniformQuantizedType get(unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152