MLIR 23.0.0git
QuantTypes.cpp
Go to the documentation of this file.
//===- QuantTypes.cpp - Quantization Type Implementation --------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "TypeDetail.h"
15
17#include "mlir/IR/MLIRContext.h"
18
19using namespace mlir;
20using namespace mlir::quant;
21using namespace mlir::quant::detail;
22
23namespace {
24
25// Return the minimum scale representable in a given float type
26double getMinScale(Type expressedType) {
27 auto floatType = cast<FloatType>(expressedType);
28 return APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
29}
30
31// Return the maximum scale representable in a given float type
32double getMaxScale(Type expressedType) {
33 auto floatType = cast<FloatType>(expressedType);
34 return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
35}
36
37} // namespace
38
39unsigned QuantizedType::getFlags() const {
40 return static_cast<ImplType *>(impl)->flags;
41}
42
44 return llvm::isa<QuantDialect>(type.getDialect());
45}
46
47LogicalResult
49 unsigned flags, Type storageType,
50 Type expressedType, int64_t storageTypeMin,
51 int64_t storageTypeMax) {
52 if (auto quantStorageTypeInterface =
53 llvm::dyn_cast<QuantStorageTypeInterface>(storageType)) {
54 unsigned integralWidth = quantStorageTypeInterface.getStorageWidth();
55
56 // Verify storage width.
57 if (integralWidth == 0 || integralWidth > MaxStorageBits)
58 return emitError() << "illegal storage type size: " << integralWidth;
59
61 int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum(isSigned);
62 int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum(isSigned);
63
64 if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin ||
65 storageTypeMax > defaultMax) {
66 return emitError() << "illegal storage min and storage max: ("
67 << storageTypeMin << ":" << storageTypeMax << ")";
68 }
69
70 return success();
71 }
72
73 return emitError() << "storage type must implement QuantStorageTypeInterface";
74}
75
77 return static_cast<ImplType *>(impl)->storageType;
78}
79
81 return static_cast<ImplType *>(impl)->storageTypeMin;
82}
83
85 return static_cast<ImplType *>(impl)->storageTypeMax;
86}
87
89 Type storageType = static_cast<ImplType *>(impl)->storageType;
90 auto quantStorageTypeInterface =
91 llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
92
93 int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum(isSigned());
94 int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum(isSigned());
95
96 return defaultMin != getStorageTypeMin() || defaultMax != getStorageTypeMax();
97}
98
100 Type storageType = static_cast<ImplType *>(impl)->storageType;
101 auto quantStorageTypeInterface =
102 llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
103
104 return quantStorageTypeInterface.getStorageWidth();
105}
106
108 return static_cast<ImplType *>(impl)->expressedType;
109}
110
111bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
112 if (llvm::isa<ShapedType>(candidateExpressedType)) {
113 return llvm::cast<ShapedType>(candidateExpressedType).getElementType() ==
115 }
116 return candidateExpressedType == getExpressedType();
117}
118
121 if (llvm::isa<ShapedType>(primitiveOrContainerType)) {
122 Type elementType =
123 llvm::cast<ShapedType>(primitiveOrContainerType).getElementType();
124 return llvm::dyn_cast<QuantizedType>(elementType);
125 }
126 return llvm::dyn_cast<QuantizedType>(primitiveOrContainerType);
127}
128
130 if (candidateType == getStorageType()) {
131 // i.e. i8 -> quant<"uniform[i8:f32]{1.0}">
132 return *this;
133 }
134 if (llvm::isa<RankedTensorType>(candidateType)) {
135 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
136 return RankedTensorType::get(
137 llvm::cast<RankedTensorType>(candidateType).getShape(),
139 }
140 if (llvm::isa<UnrankedTensorType>(candidateType)) {
141 // i.e. tensor<xi8> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
142 return UnrankedTensorType::get(getStorageType());
143 }
144 if (llvm::isa<VectorType>(candidateType)) {
145 // i.e. vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
146 return VectorType::get(llvm::cast<VectorType>(candidateType).getShape(),
148 }
149
150 return nullptr;
151}
152
154 if (llvm::isa<QuantizedType>(quantizedType)) {
155 // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
156 return llvm::cast<QuantizedType>(quantizedType).getStorageType();
157 }
158 if (llvm::isa<ShapedType>(quantizedType)) {
159 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
160 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
161 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
162 return nullptr;
163 }
164 Type storageType =
165 llvm::cast<QuantizedType>(sType.getElementType()).getStorageType();
166 if (llvm::isa<RankedTensorType>(quantizedType)) {
167 return RankedTensorType::get(sType.getShape(), storageType);
168 }
169 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
170 return UnrankedTensorType::get(storageType);
171 }
172 if (llvm::isa<VectorType>(quantizedType)) {
173 return VectorType::get(sType.getShape(), storageType);
174 }
175 }
176
177 return nullptr;
178}
179
181 if (candidateType == getExpressedType()) {
182 // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
183 return *this;
184 }
185 if (llvm::isa<ShapedType>(candidateType)) {
186 ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType);
187 if (candidateShapedType.getElementType() != getExpressedType()) {
188 return nullptr;
189 }
190
191 if (llvm::isa<RankedTensorType>(candidateType)) {
192 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
193 return RankedTensorType::get(candidateShapedType.getShape(), *this);
194 }
195 if (llvm::isa<UnrankedTensorType>(candidateType)) {
196 // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
197 return UnrankedTensorType::get(*this);
198 }
199 if (llvm::isa<VectorType>(candidateType)) {
200 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
201 return VectorType::get(candidateShapedType.getShape(), *this);
202 }
203 }
204
205 return nullptr;
206}
207
209 if (llvm::isa<QuantizedType>(quantizedType)) {
210 // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
211 return llvm::cast<QuantizedType>(quantizedType).getExpressedType();
212 }
213 if (llvm::isa<ShapedType>(quantizedType)) {
214 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
215 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
216 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
217 return nullptr;
218 }
219 Type expressedType =
220 llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType();
221 if (llvm::isa<RankedTensorType>(quantizedType)) {
222 return RankedTensorType::get(sType.getShape(), expressedType);
223 }
224 if (llvm::isa<UnrankedTensorType>(quantizedType)) {
225 return UnrankedTensorType::get(expressedType);
226 }
227 if (llvm::isa<VectorType>(quantizedType)) {
228 return VectorType::get(sType.getShape(), expressedType);
229 }
230 }
231
232 return nullptr;
233}
234
236 Type expressedQuantizedType = castFromExpressedType(candidateType);
237 if (!expressedQuantizedType) {
238 return nullptr;
239 }
240 return QuantizedType::castToStorageType(expressedQuantizedType);
241}
242
243AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
244 Type expressedType,
245 int64_t storageTypeMin,
246 int64_t storageTypeMax) {
247 return Base::get(storageType.getContext(), flags, storageType, expressedType,
248 storageTypeMin, storageTypeMax);
249}
250
253 unsigned flags, Type storageType,
254 Type expressedType, int64_t storageTypeMin,
255 int64_t storageTypeMax) {
256 return Base::getChecked(emitError, storageType.getContext(), flags,
257 storageType, expressedType, storageTypeMin,
258 storageTypeMax);
259}
260
261LogicalResult
263 unsigned flags, Type storageType,
264 Type expressedType, int64_t storageTypeMin,
265 int64_t storageTypeMax) {
266 if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
267 expressedType, storageTypeMin,
268 storageTypeMax))) {
269 return failure();
270 }
271
272 // Verify that the expressed type is floating point.
273 // If this restriction is ever eliminated, the parser/printer must be
274 // extended.
275 if (expressedType && !llvm::isa<FloatType>(expressedType))
276 return emitError() << "expressed type must be floating point";
277
278 return success();
279}
280
282 Type expressedType, double scale,
283 int64_t zeroPoint,
284 int64_t storageTypeMin,
285 int64_t storageTypeMax) {
286 return Base::get(storageType.getContext(), flags, storageType, expressedType,
287 scale, zeroPoint, storageTypeMin, storageTypeMax);
288}
289
291 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
292 Type storageType, Type expressedType, double scale, int64_t zeroPoint,
293 int64_t storageTypeMin, int64_t storageTypeMax) {
294 return Base::getChecked(emitError, storageType.getContext(), flags,
295 storageType, expressedType, scale, zeroPoint,
296 storageTypeMin, storageTypeMax);
297}
298
300 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
301 Type storageType, Type expressedType, double scale, int64_t zeroPoint,
302 int64_t storageTypeMin, int64_t storageTypeMax) {
303 if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
304 expressedType, storageTypeMin,
305 storageTypeMax))) {
306 return failure();
307 }
308
309 // Uniform quantization requires fully expressed parameters, including
310 // expressed type.
311 if (!expressedType)
312 return emitError() << "uniform quantization requires expressed type";
313
314 // Verify that the expressed type is floating point.
315 // If this restriction is ever eliminated, the parser/printer must be
316 // extended.
317 if (!llvm::isa<FloatType>(expressedType))
318 return emitError() << "expressed type must be floating point";
319
320 // Verify scale.
321 double minScale = getMinScale(expressedType);
322 double maxScale = getMaxScale(expressedType);
323 if (scale < minScale || scale > maxScale)
324 return emitError() << "scale out of expressed type range [" << minScale
325 << ", " << maxScale << "]";
326
327 return success();
328}
329
330double UniformQuantizedType::getScale() const { return getImpl()->scale; }
331
333 return getImpl()->zeroPoint;
334}
335
337 unsigned flags, Type storageType, Type expressedType,
338 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
339 int32_t quantizedDimension, int64_t storageTypeMin,
340 int64_t storageTypeMax) {
341 return Base::get(storageType.getContext(), flags, storageType, expressedType,
342 scales, zeroPoints, quantizedDimension, storageTypeMin,
343 storageTypeMax);
344}
345
347 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
348 Type storageType, Type expressedType, ArrayRef<double> scales,
349 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
350 int64_t storageTypeMin, int64_t storageTypeMax) {
351 return Base::getChecked(emitError, storageType.getContext(), flags,
352 storageType, expressedType, scales, zeroPoints,
353 quantizedDimension, storageTypeMin, storageTypeMax);
354}
355
357 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
358 Type storageType, Type expressedType, ArrayRef<double> scales,
359 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
360 int64_t storageTypeMin, int64_t storageTypeMax) {
361 if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
362 expressedType, storageTypeMin,
363 storageTypeMax))) {
364 return failure();
365 }
366
367 // Uniform quantization requires fully expressed parameters, including
368 // expressed type.
369 if (!expressedType)
370 return emitError() << "uniform quantization requires expressed type";
371
372 // Verify that the expressed type is floating point.
373 // If this restriction is ever eliminated, the parser/printer must be
374 // extended.
375 if (!llvm::isa<FloatType>(expressedType))
376 return emitError() << "expressed type must be floating point";
377
378 // Ensure that the number of scales and zeroPoints match.
379 if (scales.size() != zeroPoints.size())
380 return emitError() << "illegal number of scales and zeroPoints: "
381 << scales.size() << ", " << zeroPoints.size();
382
383 // Verify scale.
384 double minScale = getMinScale(expressedType);
385 double maxScale = getMaxScale(expressedType);
386 for (double scale : scales) {
387 if (scale < minScale || scale > maxScale)
388 return emitError() << "scale out of expressed type range [" << minScale
389 << ", " << maxScale << "]";
390 }
391
392 // Verify quantized dimension.
393 if (quantizedDimension < 0)
394 return emitError() << "illegal quantized dimension: " << quantizedDimension;
395
396 return success();
397}
398
400 return getImpl()->getScales();
401}
402
404 return getImpl()->getZeroPoints();
405}
406
408 return getImpl()->quantizedDimension;
409}
410
412 unsigned flags, Type storageType, Type expressedType,
413 DenseElementsAttr scales, DenseElementsAttr zeroPoints,
414 ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
415 int64_t storageTypeMin, int64_t storageTypeMax) {
416 return Base::get(storageType.getContext(), flags, storageType, expressedType,
417 scales, zeroPoints, quantizedDimensions, blockSizes,
418 storageTypeMin, storageTypeMax);
419}
420
422 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
423 Type storageType, Type expressedType, DenseElementsAttr scales,
424 DenseElementsAttr zeroPoints, ArrayRef<int32_t> quantizedDimensions,
425 ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
426 int64_t storageTypeMax) {
427 return Base::getChecked(emitError, storageType.getContext(), flags,
428 storageType, expressedType, scales, zeroPoints,
429 quantizedDimensions, blockSizes, storageTypeMin,
430 storageTypeMax);
431}
432
434 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
435 Type storageType, Type expressedType, DenseElementsAttr scales,
436 DenseElementsAttr zeroPoints, ArrayRef<int32_t> quantizedDimensions,
437 ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
438 int64_t storageTypeMax) {
439 if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
440 expressedType, storageTypeMin,
441 storageTypeMax))) {
442 return failure();
443 }
444
445 // Uniform quantization requires fully expressed parameters, including
446 // expressed type.
447 if (!expressedType)
448 return emitError() << "uniform quantization requires expressed type";
449
450 // Verify that the expressed type is floating point.
451 // If this restriction is ever eliminated, the parser/printer must be
452 // extended.
453 if (!llvm::isa<FloatType>(expressedType))
454 return emitError() << "expressed type must be floating point";
455
456 // Verify scale type to match expressedType.
457 if (scales.getType().getElementType() != expressedType) {
458 return emitError() << "type of scale values "
459 << scales.getType().getElementType()
460 << " must match the expressed type " << expressedType;
461 }
462
463 // Verify zero-point type to match storageType.
464 if (zeroPoints.getType().getElementType() != storageType) {
465 return emitError() << "type of zero point values "
466 << zeroPoints.getType().getElementType()
467 << " must match the storage type " << storageType;
468 }
469
470 // Ensure that the shape of scales and zeroPoints match.
471 if (scales.getType().getShape() != zeroPoints.getType().getShape())
472 return emitError() << "shape of scales and zeroPoints ("
473 << scales.getType().getShape() << " vs "
474 << zeroPoints.getType().getShape() << ") does not match";
475
476 // Ensure that the number of quantized-dimensions and block-sizes match.
477 if (quantizedDimensions.size() != blockSizes.size())
478 return emitError() << "number of quantized dimensions and block sizes ("
479 << scales.size() << " vs " << zeroPoints.size()
480 << ") does not match";
481
482 // Verify quantized dimension.
483 for (auto quantizedDimension : quantizedDimensions) {
484 if (quantizedDimension < 0)
485 return emitError() << "illegal quantized dimension: "
486 << quantizedDimension;
487 }
488
489 // Verify block sizes.
490 for (auto blockSize : blockSizes) {
491 if (blockSize <= 0)
492 return emitError() << "illegal block size: " << blockSize;
493 }
494
495 return success();
496}
497
501
503 return getImpl()->getZeroPoints();
504}
505
508 return getImpl()->getQuantizedDimensions();
509}
510
512 return getImpl()->getBlockSizes();
513}
514
518 result.reserve(getQuantizedDimensions().size());
519
520 for (auto [dim, size] :
521 llvm::zip(getQuantizedDimensions(), getBlockSizes())) {
522 result.push_back({dim, size});
523 }
524
525 return result;
526}
527
529 double min, double max) {
530 return Base::get(expressedType.getContext(), expressedType, min, max);
531}
532
535 double min, double max) {
536 return Base::getChecked(emitError, expressedType.getContext(), expressedType,
537 min, max);
538}
539
542 double min, double max) {
543 // Verify that the expressed type is floating point.
544 // If this restriction is ever eliminated, the parser/printer must be
545 // extended.
546 if (!llvm::isa<FloatType>(expressedType))
547 return emitError() << "expressed type must be floating point";
548 if (max <= min)
549 return emitError() << "illegal min and max: (" << min << ":" << max << ")";
550
551 return success();
552}
553
554double CalibratedQuantizedType::getMin() const { return getImpl()->min; }
555
556double CalibratedQuantizedType::getMax() const { return getImpl()->max; }
557
559 mlir::Type quantileType,
560 ArrayRef<double> quantiles,
561 std::optional<int64_t> storageMin,
562 std::optional<int64_t> storageMax) {
563 return Base::get(ctx, storageType, quantileType, quantiles, storageMin,
564 storageMax);
565}
566
569 mlir::Type storageType, mlir::Type quantileType, ArrayRef<double> quantiles,
570 std::optional<int64_t> storageMin, std::optional<int64_t> storageMax) {
571 return Base::getChecked(emitError, ctx, storageType, quantileType, quantiles,
572 storageMin, storageMax);
573}
574
577 Type quantileType, ArrayRef<double> quantiles,
578 std::optional<int64_t> storageMin, std::optional<int64_t> storageMax) {
579 if (!storageType.isIntOrFloat())
580 return emitError() << "storage type must be an integer or float type";
581 if (!llvm::isa<mlir::FloatType>(quantileType))
582 return emitError() << "quantile type must be a float type";
583 if (quantiles.empty())
584 return emitError() << "quantile values must not be empty";
585 if (storageMin.has_value() != storageMax.has_value())
586 return emitError()
587 << "storage min and max must both be specified or both omitted";
588 if (storageMin && storageMax && *storageMin >= *storageMax)
589 return emitError() << "storage min must be less than storage max";
590
591 unsigned width = storageType.getIntOrFloatBitWidth();
592 bool isSigned = !llvm::isa<mlir::IntegerType>(storageType) ||
593 llvm::cast<mlir::IntegerType>(storageType).isSigned();
594 auto effectiveMin =
595 storageMin.value_or(isSigned ? -(1LL << (width - 1)) : 0LL);
596 auto effectiveMax = storageMax.value_or(isSigned ? (1LL << (width - 1)) - 1
597 : (1LL << width) - 1);
598 auto expectedSize = effectiveMax - effectiveMin + 1;
599 if (static_cast<decltype(expectedSize)>(quantiles.size()) != expectedSize)
600 return emitError() << "quantile LUT size (" << quantiles.size()
601 << ") must equal the number of representable storage "
602 "values ("
603 << expectedSize << ")";
604
605 for (double v : quantiles)
606 if (std::isnan(v) || std::isinf(v))
607 return emitError()
608 << "quantile values must be finite (no NaN or infinity)";
609
610 return success();
611}
612
616
618 return static_cast<ImplType *>(impl)->getStorageType();
619}
620
622 return static_cast<ImplType *>(impl)->getQuantileType();
623}
624
626 return static_cast<ImplType *>(impl)->getQuantiles();
627}
628
629std::optional<int64_t> QuantileType::getStorageMin() const {
630 return static_cast<ImplType *>(impl)->getStorageMin();
631}
632
633std::optional<int64_t> QuantileType::getStorageMax() const {
634 return static_cast<ImplType *>(impl)->getStorageMax();
635}
636
638 if (auto intType = mlir::dyn_cast<mlir::IntegerType>(getStorageType()))
639 return intType.isSigned();
640 // Float types default to signed.
641 return true;
642}
643
647
649 if (auto explicitMax = getStorageMax())
650 return *explicitMax;
651 if (isSigned)
652 return (1LL << (getStorageWidth() - 1)) - 1;
653 return (1LL << getStorageWidth()) - 1;
654}
655
657 if (auto explicitMin = getStorageMin())
658 return *explicitMin;
659 if (isSigned)
660 return -(1LL << (getStorageWidth() - 1));
661 return 0;
662}
663
664std::string QuantileType::getStorageTypeName(bool isSigned) const {
665 std::string result = "!quant.quantile<";
666 llvm::raw_string_ostream os(result);
667 os << getStorageType() << ":" << getQuantileType() << ", {";
668 ArrayRef<double> quantiles = this->getQuantiles();
669 llvm::interleave(
670 llvm::seq<size_t>(0, quantiles.size()), os,
671 [&](size_t index) { os << quantiles[index]; }, ",");
672 os << "}";
673 if (auto minVal = getStorageMin())
674 if (auto maxVal = getStorageMax())
675 os << ", <" << *minVal << ":" << *maxVal << ">";
676 os << ">";
677 os.flush();
678 return result;
679}
680
681bool QuantileType::isPacked() const { return getStorageWidth() <= 4; }
682
684
686 unsigned width = getStorageWidth();
687 return width > 0 ? 8 / width : 0;
688}
689
690std::optional<unsigned> QuantileType::getPreferredAlignmentBytes() const {
691 return std::nullopt;
692}
return success()
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
An attribute that represents a reference to a dense vector or tensor object.
int64_t size() const
Returns the number of elements held by this attribute.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
static TypeID get()
Construct a type info object for the given type T.
Definition TypeID.h:245
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition Types.h:107
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
TypeID getTypeID()
Return a unique identifier for the concrete type.
Definition Types.h:101
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition QuantTypes.h:204
static AnyQuantizedType get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
static AnyQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
A quantized type that infers its range from given min/max values.
Definition QuantTypes.h:525
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Verifies construction invariants and issues errors/warnings.
static CalibratedQuantizedType get(Type expressedType, double min, double max)
Gets an instance of the type with all parameters specified but not checked.
static CalibratedQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double min, double max)
Gets an instance of the type with all specified parameters checked.
unsigned getLogicalBitWidth() const
static QuantileType getChecked(function_ref< InFlightDiagnostic()> emitError, mlir::MLIRContext *ctx, Type storageType, Type quantileType, ArrayRef< double > quantiles, std::optional< int64_t > storageMin=std::nullopt, std::optional< int64_t > storageMax=std::nullopt)
static bool classof(mlir::Type type)
Methods to support type inquiry through isa, cast, and dyn_cast.
unsigned getStorageWidth() const
std::optional< int64_t > getStorageMin() const
Return the explicit storage minimum, if set.
int64_t getDefaultMinimum(bool isSigned) const
unsigned getElementsPerByte() const
std::string getStorageTypeName(bool isSigned) const
detail::QuantileTypeStorage ImplType
Definition QuantTypes.h:594
bool shouldDefaultToSigned() const
int64_t getDefaultMaximum(bool isSigned) const
static QuantileType get(mlir::MLIRContext *ctx, Type storageType, Type quantileType, ArrayRef< double > quantiles={}, std::optional< int64_t > storageMin=std::nullopt, std::optional< int64_t > storageMax=std::nullopt)
ArrayRef< double > getQuantiles() const
Return the quantile table of this float type.
std::optional< unsigned > getPreferredAlignmentBytes() const
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, Type storageType, Type quantileType, ArrayRef< double > quantiles, std::optional< int64_t > storageMin, std::optional< int64_t > storageMax)
std::optional< int64_t > getStorageMax() const
Return the explicit storage maximum, if set.
Base class for all quantized types known to this dialect.
Definition QuantTypes.h:51
Type getExpressedType() const
Gets the original expressed type that this quantized type approximates.
static constexpr unsigned MaxStorageBits
The maximum number of bits supported for storage types.
Definition QuantTypes.h:57
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
static Type castToStorageType(Type quantizedType)
Casts from a type based on a QuantizedType to a corresponding type based on the storageType (returns ...
Type castExpressedToStorageType(Type candidateType)
Casts from a type based on the expressedType to the equivalent type based on storageType by way of th...
detail::QuantizedTypeStorage ImplType
Definition QuantTypes.h:53
static Type castToExpressedType(Type quantizedType)
Casts from a type based on QuantizedType to a corresponding type based on the expressedType (returns ...
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
Definition QuantTypes.h:104
constexpr Type()=default
static QuantizedType getQuantizedElementType(Type primitiveOrContainerType)
Returns the element type as a QuantizedType or nullptr if it is not a quantized type.
unsigned getFlags() const
Gets the flags associated with this type.
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
static bool classof(Type type)
Support method to enable LLVM-style type casting.
Type castFromStorageType(Type candidateType)
Casts from a type based on the storageType to a corresponding type based on this type (returns nullpt...
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Type getStorageType() const
Gets the underlying type used to store values.
Type castFromExpressedType(Type candidateType)
Casts from a type based on the expressedType to a corresponding type based on this type (returns null...
bool isCompatibleExpressedType(Type candidateExpressedType)
Returns whether the candidateExpressedType is a match for this QuantizedType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax)
Represents per-axis (also known as per-channel quantization).
Definition QuantTypes.h:325
static UniformQuantizedPerAxisType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
static UniformQuantizedPerAxisType get(unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to.
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation.
ArrayRef< double > getScales() const
Gets the quantization scales.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
Represents sub-channel (also known as blockwise quantization).
Definition QuantTypes.h:410
ArrayRef< int32_t > getQuantizedDimensions() const
Gets the quantized dimensions.
DenseElementsAttr getZeroPoints() const
Gets the quantization zero-points.
ArrayRef< int64_t > getBlockSizes() const
Gets the block sizes for the quantized dimensions.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, ArrayRef< int32_t > quantizedDimensions, ArrayRef< int64_t > blockSizes, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
const SmallVector< std::pair< int32_t, int64_t > > getBlockSizeInfo() const
Gets the block size information.
static UniformQuantizedSubChannelType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, ArrayRef< int32_t > quantizedDimensions, ArrayRef< int64_t > blockSizes, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
static UniformQuantizedSubChannelType get(unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, ArrayRef< int32_t > quantizedDimensions, ArrayRef< int64_t > blockSizes, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
DenseElementsAttr getScales() const
Gets the quantization scales.
Represents a family of uniform, quantized types.
Definition QuantTypes.h:265
double getScale() const
Gets the scale term.
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Verifies construction invariants and issues errors/warnings.
static UniformQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
static UniformQuantizedType get(unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147