MLIR 23.0.0git
TypeParser.cpp
Go to the documentation of this file.
//===- TypeParser.cpp - Quantization Type Parser ----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/IR/Types.h"
15#include "llvm/ADT/APFloat.h"
16#include "llvm/ADT/SmallVectorExtras.h"
17
18using namespace mlir;
19using namespace quant;
20
21static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
22 auto typeLoc = parser.getCurrentLocation();
23 Type type;
24
25 // Parse storage type (alpha_ident, integer_literal).
26 StringRef identifier;
27 unsigned storageTypeWidth = 0;
29 if (result.has_value()) {
30 if (!succeeded(*result))
31 return nullptr;
32
33 if (auto quantStorageTypeInterface =
34 llvm::dyn_cast<QuantStorageTypeInterface>(type)) {
35 // Returns true if the type defaults to signed (e.g., si8, i8 or float
36 // types), false if it defaults to unsigned.
37 isSigned = quantStorageTypeInterface.shouldDefaultToSigned();
38 storageTypeWidth = quantStorageTypeInterface.getStorageWidth();
39 } else {
40 parser.emitError(typeLoc, "illegal storage type prefix");
41 return nullptr;
42 }
43 } else if (succeeded(parser.parseKeyword(&identifier))) {
44 // Otherwise, this must be an unsigned integer (`u` integer-literal)
45 if (identifier.consume_front("u")) {
46 if (identifier.getAsInteger(10, storageTypeWidth)) {
47 parser.emitError(typeLoc, "expected storage type width");
48 return nullptr;
49 }
50 isSigned = false;
51 type = parser.getBuilder().getIntegerType(storageTypeWidth);
52 } else {
53 parser.emitError(typeLoc, "illegal storage type prefix");
54 return nullptr;
55 }
56 } else {
57 return nullptr;
58 }
59
60 if (storageTypeWidth == 0 ||
61 storageTypeWidth > QuantizedType::MaxStorageBits) {
62 parser.emitError(typeLoc, "illegal storage type size: ")
63 << storageTypeWidth;
64 return nullptr;
65 }
66
67 return type;
68}
69
70static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
71 bool isSigned, int64_t &storageTypeMin,
72 int64_t &storageTypeMax) {
73 auto quantStorageTypeInterface =
74 llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
75
76 int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum(isSigned);
77 int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum(isSigned);
78
79 if (failed(parser.parseOptionalLess())) {
80 storageTypeMin = defaultMin;
81 storageTypeMax = defaultMax;
82 return success();
83 }
84
85 // Explicit storage min and storage max.
86 SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
87 if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
88 parser.getCurrentLocation(&maxLoc) ||
89 parser.parseInteger(storageTypeMax) || parser.parseGreater())
90 return failure();
91 if (storageTypeMin < defaultMin) {
92 return parser.emitError(minLoc, "illegal storage type minimum: ")
93 << storageTypeMin;
94 }
95 if (storageTypeMax > defaultMax) {
96 return parser.emitError(maxLoc, "illegal storage type maximum: ")
97 << storageTypeMax;
98 }
99 return success();
100}
101
103 double &min, double &max) {
104 auto typeLoc = parser.getCurrentLocation();
105 FloatType type;
106
107 if (failed(parser.parseType(type))) {
108 parser.emitError(typeLoc, "expecting float expressed type");
109 return nullptr;
110 }
111
112 // Calibrated min and max values.
113 if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() ||
114 parser.parseFloat(max) || parser.parseGreater()) {
115 parser.emitError(typeLoc, "calibrated values must be present");
116 return nullptr;
117 }
118 return type;
119}
120
121/// Parses an AnyQuantizedType.
122///
123/// any ::= `any<` storage-spec (expressed-type-spec)?`>`
124/// storage-spec ::= storage-type (`<` storage-range `>`)?
125/// storage-range ::= integer-literal `:` integer-literal
126/// storage-type ::= (`i` | `u`) integer-literal
127/// expressed-type-spec ::= `:` `f` integer-literal
129 Type storageType;
130 FloatType expressedType;
131 unsigned typeFlags = 0;
132 int64_t storageTypeMin;
133 int64_t storageTypeMax;
134
135 // Type specification.
136 if (parser.parseLess())
137 return nullptr;
138
139 // Storage type.
140 bool isSigned = false;
141 storageType = parseStorageType(parser, isSigned);
142 if (!storageType) {
143 return nullptr;
144 }
145 if (isSigned) {
146 typeFlags |= QuantizationFlags::Signed;
147 }
148
149 // Storage type range.
150 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
151 storageTypeMax)) {
152 return nullptr;
153 }
154
155 // Optional expressed type.
156 if (succeeded(parser.parseOptionalColon())) {
157 if (parser.parseType(expressedType)) {
158 return nullptr;
159 }
160 }
161
162 if (parser.parseGreater()) {
163 return nullptr;
164 }
165
166 return parser.getChecked<AnyQuantizedType>(
167 typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
168}
169
170/// Checks if the given scale value is within the valid range of the expressed
171/// type. The `expressedType` argument is the floating-point type used for
172/// expressing the quantized values, and `scale` is the double value to check.
173static LogicalResult
175 Type expressedType, double scale) {
176 auto floatType = cast<FloatType>(expressedType);
177 double minScale =
178 APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
179 double maxScale =
180 APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
181 if (scale < minScale || scale > maxScale)
182 return emitError() << "scale " << scale << " out of expressed type range ["
183 << minScale << ", " << maxScale << "]";
184 return success();
185}
186
187/// Parses a quantization parameter, which is either a scale value (float) or a
188/// scale-zero point pair (float:integer). `expressedType`, expressing the type
189/// of scale values, is used to validate the scale. The parsed scale and zero
190/// point (if any) are stored in `scale` and `zeroPoint`.
191static ParseResult parseQuantParams(DialectAsmParser &parser,
192 Type expressedType, double &scale,
193 int64_t &zeroPoint) {
194
195 if (parser.parseFloat(scale)) {
196 return failure();
197 }
198
200 [&]() { return parser.emitError(parser.getCurrentLocation()); },
201 expressedType, scale))) {
202 return failure();
203 }
204
205 zeroPoint = 0;
206 if (failed(parser.parseOptionalColon())) {
207 return success();
208 }
209
210 return parser.parseInteger(zeroPoint);
211}
212
213/// Parses block size information for sub-channel quantization, assuming the
214/// leading '{' has already been parsed. The block size information is provided
215/// as a comma-separated list of "Axis:BlockSize" pairs, terminated by a '}'.
216///
217/// The parsed axis indices are stored in `quantizedDimensions`, and the
218/// corresponding block sizes are stored in `blockSizes`.
219static ParseResult
221 SmallVectorImpl<int32_t> &quantizedDimensions,
222 SmallVectorImpl<int64_t> &blockSizes) {
223 // Empty block-sizes info.
224 if (succeeded(parser.parseOptionalRBrace())) {
225 return success();
226 }
227
228 auto parseBlockSizeElements = [&]() -> ParseResult {
229 quantizedDimensions.resize(quantizedDimensions.size() + 1);
230 blockSizes.resize(blockSizes.size() + 1);
231 if (parser.parseInteger(quantizedDimensions.back()) ||
232 parser.parseColon() || parser.parseInteger(blockSizes.back()))
233 return failure();
234 return success();
235 };
236
237 if (parser.parseCommaSeparatedList(parseBlockSizeElements) ||
238 parser.parseRBrace()) {
239 return failure();
240 }
241
242 return success();
243}
244
245/// Parses a bracketed list of quantization parameters, returning the dimensions
246/// of the parsed sub-tensors in `dims`. The dimension of the list is prepended
247/// to the dimensions of the sub-tensors. This function assumes that the initial
248/// left brace has already been parsed. For example:
249///
250/// parseQuantParamListUntilRBrace(1.0:1, 2.0:4, 3.0:4}) -> Success,
251/// dims = [3], scales = [1.0, 2.0, 3.0], zeroPoints = [1, 4, 4]
252///
253/// parseQuantParamListUntilRBrace({1.0, 2.0}, {3.0:1, 4.0:9}}) -> Success,
254/// dims = [2, 2], scales = [1.0, 2.0, 3.0, 4.0], zeroPoints = [0, 0, 1,
255/// 9]
256///
257/// This function expects all sub-tensors to have the same rank.
258static ParseResult
261 SmallVectorImpl<int64_t> &zeroPoints,
263 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
264 const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
265 if (prevDims == newDims)
266 return success();
267 return parser.emitError(parser.getCurrentLocation())
268 << "tensor literal is invalid; ranks are not consistent "
269 "between elements";
270 };
271
272 bool first = true;
274 unsigned size = 0;
275
276 auto parseOneElement = [&]() -> ParseResult {
278 if (succeeded(parser.parseOptionalLBrace())) {
279 if (parseQuantParamListUntilRBrace(parser, expressedType, scales,
280 zeroPoints, thisDims))
281 return failure();
282 } else {
283 zeroPoints.resize(zeroPoints.size() + 1);
284 scales.resize(scales.size() + 1);
285 if (parseQuantParams(parser, expressedType, scales.back(),
286 zeroPoints.back())) {
287 return failure();
288 }
289 }
290 ++size;
291 if (!first)
292 return checkDims(newDims, thisDims);
293 newDims = thisDims;
294 first = false;
295 return success();
296 };
297
298 if (parser.parseCommaSeparatedList(parseOneElement) || parser.parseRBrace()) {
299 return failure();
300 }
301
302 // Return the sublists' dimensions with 'size' prepended.
303 dims.clear();
304 dims.push_back(size);
305 dims.append(newDims.begin(), newDims.end());
306
307 return success();
308}
309
310/// Parses a UniformQuantizedType.
311///
312/// uniform_type ::= uniform_per_layer
313/// | uniform_per_axis
314/// | uniform_sub_channel
315/// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
316/// `,` scale-zero `>`
317/// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
318/// axis-spec `,` `{` scale-zero-list `}` `>`
319/// uniform_sub_channel ::= `uniform<` storage-spec expressed-type-spec
320/// block-size-info `,` scale-zero-tensor `>`
321/// storage-spec ::= storage-type (`<` storage-range `>`)?
322/// storage-range ::= integer-literal `:` integer-literal
323/// storage-type ::= (`i` | `u`) integer-literal | `f8E5M2` | `f8E4M3FN`
324// | `f4E2M1FN` | 'quantile'
325/// expressed-type-spec ::= `:` `f` integer-literal
326/// axis-spec ::= `:` integer-literal
327/// scale-zero ::= scale (`:` zero-point)?
328/// scale ::= float-literal
329/// zero-point ::= integer-literal
330/// scale-zero-list ::= scale-zero (`,` scale-zero)*
331/// block-size-info ::= `{` `}` | `{` axis-block `:` (`,` axis-block)* `}`
332/// axis-block ::= axis-spec `:` block-size-spec
333/// block-size-spec ::= integer-literal
334/// scale-zero-tensor ::= scale-zero-dense-exp | scale-zero-list
335/// scale-zero-dense-exp ::= `{`
336/// scale-zero-tensor (`,` scale-zero-tensor)*
337/// `}`
339 Type storageType;
340 FloatType expressedType;
341 unsigned typeFlags = 0;
342 int64_t storageTypeMin;
343 int64_t storageTypeMax;
344 bool isPerAxis = false;
345 bool isSubChannel = false;
346 SmallVector<int32_t, 1> quantizedDimensions;
347 SmallVector<int64_t, 1> blockSizes;
349 SmallVector<int64_t, 1> zeroPoints;
350
351 // Type specification.
352 if (parser.parseLess()) {
353 return nullptr;
354 }
355
356 // Storage type.
357 bool isSigned = false;
358 storageType = parseStorageType(parser, isSigned);
359 if (!storageType) {
360 return nullptr;
361 }
362 if (isSigned) {
363 typeFlags |= QuantizationFlags::Signed;
364 }
365
366 // Storage type range.
367 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
368 storageTypeMax)) {
369 return nullptr;
370 }
371
372 // Expressed type.
373 if (parser.parseColon() || parser.parseType(expressedType)) {
374 return nullptr;
375 }
376
377 // Optionally parse quantized dimension for per-axis or sub-channel
378 // quantization.
379 if (succeeded(parser.parseOptionalColon())) {
380 if (succeeded(parser.parseOptionalLBrace())) {
381 isSubChannel = true;
382 if (parseBlockSizeInfoUntilRBrace(parser, quantizedDimensions,
383 blockSizes)) {
384 return nullptr;
385 }
386 } else {
387 isPerAxis = true;
388 quantizedDimensions.resize(1);
389 if (parser.parseInteger(quantizedDimensions.back())) {
390 return nullptr;
391 }
392 }
393 }
394
395 // Comma leading into range_spec.
396 if (parser.parseComma()) {
397 return nullptr;
398 }
399
400 // Quantization parameter (scales/zeroPoints) specification.
401 bool isPerTensor = !isPerAxis && !isSubChannel;
403 if (isPerTensor) {
404 zeroPoints.resize(zeroPoints.size() + 1);
405 scales.resize(scales.size() + 1);
406 if (parseQuantParams(parser, expressedType, scales.back(),
407 zeroPoints.back())) {
408 return nullptr;
409 }
410
411 } else {
412 if (parser.parseLBrace() ||
413 parseQuantParamListUntilRBrace(parser, expressedType, scales,
414 zeroPoints, dims)) {
415 return nullptr;
416 }
417 }
418
419 if (parser.parseGreater()) {
420 return nullptr;
421 }
422
423 if (isPerAxis) {
425 typeFlags, storageType, expressedType, scales, zeroPoints,
426 quantizedDimensions[0], storageTypeMin, storageTypeMax);
427 }
428 if (isSubChannel) {
429 SmallVector<APFloat> apFloatScales =
430 llvm::map_to_vector(scales, [&](double scale) -> APFloat {
431 APFloat apFloatScale(scale);
432 bool unused;
433 apFloatScale.convert(expressedType.getFloatSemantics(),
434 APFloat::rmNearestTiesToEven, &unused);
435 return apFloatScale;
436 });
437 SmallVector<APInt> apIntZeroPoints =
438 llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> APInt {
439 return APInt(storageType.getIntOrFloatBitWidth(), zeroPoint);
440 });
441 auto scalesRef = mlir::DenseElementsAttr::get(
442 RankedTensorType::get(dims, expressedType), apFloatScales);
443 auto zeroPointsRef = mlir::DenseElementsAttr::get(
444 RankedTensorType::get(dims, storageType), apIntZeroPoints);
446 typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
447 quantizedDimensions, blockSizes, storageTypeMin, storageTypeMax);
448 }
449
450 return parser.getChecked<UniformQuantizedType>(
451 typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
452 storageTypeMin, storageTypeMax);
453}
454
455/// Parses an CalibratedQuantizedType.
456///
457/// calibrated ::= `calibrated<` expressed-spec `>`
458/// expressed-spec ::= expressed-type `<` calibrated-range `>`
459/// expressed-type ::= `f` integer-literal
460/// calibrated-range ::= float-literal `:` float-literal
462 FloatType expressedType;
463 double min;
464 double max;
465
466 // Type specification.
467 if (parser.parseLess())
468 return nullptr;
469
470 // Expressed type.
471 expressedType = parseExpressedTypeAndRange(parser, min, max);
472 if (!expressedType) {
473 return nullptr;
474 }
475
476 if (parser.parseGreater()) {
477 return nullptr;
478 }
479
480 return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
481}
482
484 Type storageType;
485 Type quantileType;
486 SmallVector<double, 1> quantiles;
487
488 if (parser.parseLess())
489 return nullptr;
490 if (parser.parseType(storageType))
491 return nullptr;
492 if (parser.parseColon())
493 return nullptr;
494 if (parser.parseType(quantileType))
495 return nullptr;
496 if (parser.parseComma())
497 return nullptr;
498 if (parser.parseLBrace())
499 return nullptr;
500
501 // Allow empty braces `{}` — verify() will catch the empty quantile error.
502 if (failed(parser.parseOptionalRBrace())) {
503 do {
504 quantiles.emplace_back();
505 if (parser.parseFloat(quantiles.back()))
506 return nullptr;
507 } while (succeeded(parser.parseOptionalComma()));
508
509 if (parser.parseRBrace())
510 return nullptr;
511 }
512
513 // Optionally parse explicit storage range: `, min:max` (inside the outer
514 // `<>`).
515 std::optional<int64_t> storageMin, storageMax;
516 if (succeeded(parser.parseOptionalComma())) {
517 if (parser.parseLess())
518 return nullptr;
519 int64_t minVal, maxVal;
520 if (parser.parseInteger(minVal) || parser.parseColon() ||
521 parser.parseInteger(maxVal))
522 return nullptr;
523 storageMin = minVal;
524 storageMax = maxVal;
525 if (parser.parseGreater())
526 return nullptr;
527 }
528
529 if (parser.parseGreater())
530 return nullptr;
531
532 mlir::MLIRContext *ctx = parser.getContext();
533 return parser.getChecked<QuantileType>(ctx, storageType, quantileType,
534 quantiles, storageMin, storageMax);
535}
536
537/// Parse a type registered to this dialect.
538Type QuantDialect::parseType(DialectAsmParser &parser) const {
539 // All types start with an identifier that we switch on.
540 StringRef typeNameSpelling;
541 if (failed(parser.parseKeyword(&typeNameSpelling)))
542 return nullptr;
543
544 if (typeNameSpelling == "uniform")
545 return parseUniformType(parser);
546 if (typeNameSpelling == "any")
547 return parseAnyType(parser);
548 if (typeNameSpelling == "calibrated")
549 return parseCalibratedType(parser);
550 if (typeNameSpelling == "quantile")
551 return parseQuantileType(parser);
552
553 parser.emitError(parser.getNameLoc(),
554 "unknown quantized type " + typeNameSpelling);
555 return nullptr;
556}
557
559 // storage type
560 auto quantStorageTypeInterface =
561 llvm::dyn_cast<QuantStorageTypeInterface>(type.getStorageType());
562
563 out << quantStorageTypeInterface.getStorageTypeName(type.isSigned());
564
565 // storageTypeMin and storageTypeMax if not default.
566 if (type.hasStorageTypeBounds()) {
567 out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
568 << ">";
569 }
570}
571
572static void printQuantParams(double scale, int64_t zeroPoint,
573 DialectAsmPrinter &out) {
574 out << scale;
575 if (zeroPoint != 0) {
576 out << ":" << zeroPoint;
577 }
578}
579
580static void
581printBlockSizeInfo(ArrayRef<std::pair<int32_t, int64_t>> blockSizeInfo,
582 DialectAsmPrinter &out) {
583 out << "{";
584 llvm::interleaveComma(
585 llvm::seq<size_t>(0, blockSizeInfo.size()), out, [&](size_t index) {
586 out << blockSizeInfo[index].first << ":" << blockSizeInfo[index].second;
587 });
588 out << "}";
589}
590
591/// Helper that prints a AnyQuantizedType.
593 DialectAsmPrinter &out) {
594 out << "any<";
595 printStorageType(type, out);
596 if (Type expressedType = type.getExpressedType()) {
597 out << ":" << expressedType;
598 }
599 out << ">";
600}
601
602/// Helper that prints a UniformQuantizedType.
604 DialectAsmPrinter &out) {
605 out << "uniform<";
606 printStorageType(type, out);
607 out << ":" << type.getExpressedType() << ", ";
608
609 // scheme specific parameters
610 printQuantParams(type.getScale(), type.getZeroPoint(), out);
611 out << ">";
612}
613
614/// Helper that prints a UniformQuantizedPerAxisType.
616 DialectAsmPrinter &out) {
617 out << "uniform<";
618 printStorageType(type, out);
619 out << ":" << type.getExpressedType() << ":";
620 out << type.getQuantizedDimension();
621 out << ", ";
622
623 // scheme specific parameters
624 ArrayRef<double> scales = type.getScales();
625 ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
626 out << "{";
627 llvm::interleave(
628 llvm::seq<size_t>(0, scales.size()), out,
629 [&](size_t index) {
630 printQuantParams(scales[index], zeroPoints[index], out);
631 },
632 ",");
633 out << "}>";
634}
635
636/// Prints quantization parameters as a nested list of `scale`[:`zero_point`]
637/// elements. The nesting corresponds to the `shape` dimensions.
638///
639/// Elements are delimited by commas, and the inner dimensions are enclosed in
640/// braces. `zero_point` is only printed if it is non-zero. For example:
641///
642/// printDenseQuantizationParameters(scales=[1.0, 2.0, 3.0, 4.0],
643/// zeroPoints=[0, 0, 1, 9],
644/// shape=[2, 2])
645///
646/// would print:
647///
648/// {{1.0, 2.0}, {3.0:1, 4.0:9}}
650 ArrayRef<APInt> zeroPoints,
652 DialectAsmPrinter &out) {
653 int64_t rank = shape.size();
654 SmallVector<unsigned, 4> counter(rank, 0);
655 unsigned openBrackets = 0;
656
657 auto incrementCounterAndDelimit = [&]() {
658 ++counter[rank - 1];
659 for (unsigned i = rank - 1; i > 0; --i) {
660 if (counter[i] >= shape[i]) {
661 counter[i] = 0;
662 ++counter[i - 1];
663 --openBrackets;
664 out << '}';
665 }
666 }
667 };
668
669 for (unsigned idx = 0, e = scales.size(); idx < e; ++idx) {
670 if (idx != 0)
671 out << ", ";
672 while (openBrackets++ < rank)
673 out << '{';
674 openBrackets = rank;
675 out << scales[idx];
676 if (zeroPoints[idx] != 0) {
677 out << ":" << zeroPoints[idx];
678 }
679 incrementCounterAndDelimit();
680 }
681 while (openBrackets-- > 0)
682 out << '}';
683}
684
685/// Helper that prints a UniformQuantizedSubChannelType.
686static void
688 DialectAsmPrinter &out) {
689 out << "uniform<";
690 printStorageType(type, out);
691 out << ":" << type.getExpressedType() << ":";
693 out << ", ";
694
695 auto scalesItr = type.getScales().getValues<APFloat>();
696 auto zeroPointsItr = type.getZeroPoints().getValues<APInt>();
697 SmallVector<APFloat> scales(scalesItr.begin(), scalesItr.end());
698 SmallVector<APInt> zeroPoints(zeroPointsItr.begin(), zeroPointsItr.end());
699 printDenseQuantizationParameters(scales, zeroPoints,
700 type.getScales().getType().getShape(), out);
701 out << ">";
702}
703
704/// Helper that prints a CalibratedQuantizedType.
706 DialectAsmPrinter &out) {
707 out << "calibrated<" << type.getExpressedType();
708 out << "<" << type.getMin() << ":" << type.getMax() << ">";
709 out << ">";
710}
711
713 out << "quantile<";
714 out << type.getStorageType();
715 out << ":";
716 out << type.getQuantileType();
717 out << ", {";
718 ArrayRef<double> quantiles = type.getQuantiles();
719 llvm::interleave(
720 llvm::seq<size_t>(0, quantiles.size()), out,
721 [&](size_t index) { out << quantiles[index]; }, ",");
722 out << "}";
723 if (auto minVal = type.getStorageMin())
724 if (auto maxVal = type.getStorageMax())
725 out << ", <" << *minVal << ":" << *maxVal << ">";
726 out << ">";
727}
728
729/// Print a type registered to this dialect.
730void QuantDialect::printType(Type type, DialectAsmPrinter &os) const {
731 if (auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
732 printAnyQuantizedType(anyType, os);
733 else if (auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))
734 printUniformQuantizedType(uniformType, os);
735 else if (auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type))
736 printUniformQuantizedPerAxisType(perAxisType, os);
737 else if (auto perAxisType =
738 llvm::dyn_cast<UniformQuantizedSubChannelType>(type))
740 else if (auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type))
741 printCalibratedQuantizedType(calibratedType, os);
742 else if (auto quantileType = llvm::dyn_cast<QuantileType>(type))
743 printQuantileType(quantileType, os);
744 else
745 llvm_unreachable("Unhandled quantized type");
746}
return success()
static void printBlockSizeInfo(ArrayRef< std::pair< int32_t, int64_t > > blockSizeInfo, DialectAsmPrinter &out)
static ParseResult parseQuantParams(DialectAsmParser &parser, Type expressedType, double &scale, int64_t &zeroPoint)
Parses a quantization parameter, which is either a scale value (float) or a scale-zero point pair (fl...
static void printAnyQuantizedType(AnyQuantizedType type, DialectAsmPrinter &out)
Helper that prints a AnyQuantizedType.
static Type parseStorageType(DialectAsmParser &parser, bool &isSigned)
static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, double &min, double &max)
static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType, bool isSigned, int64_t &storageTypeMin, int64_t &storageTypeMax)
static void printUniformQuantizedSubChannelType(UniformQuantizedSubChannelType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedSubChannelType.
static Type parseUniformType(DialectAsmParser &parser)
Parses a UniformQuantizedType.
static void printQuantileType(QuantileType type, DialectAsmPrinter &out)
static Type parseAnyType(DialectAsmParser &parser)
Parses an AnyQuantizedType.
static void printStorageType(QuantizedType type, DialectAsmPrinter &out)
static ParseResult parseBlockSizeInfoUntilRBrace(DialectAsmParser &parser, SmallVectorImpl< int32_t > &quantizedDimensions, SmallVectorImpl< int64_t > &blockSizes)
Parses block size information for sub-channel quantization, assuming the leading '{' has already been...
static void printQuantParams(double scale, int64_t zeroPoint, DialectAsmPrinter &out)
static Type parseCalibratedType(DialectAsmParser &parser)
Parses an CalibratedQuantizedType.
static ParseResult parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType, SmallVectorImpl< double > &scales, SmallVectorImpl< int64_t > &zeroPoints, SmallVectorImpl< int64_t > &dims)
Parses a bracketed list of quantization parameters, returning the dimensions of the parsed sub-tensor...
static void printCalibratedQuantizedType(CalibratedQuantizedType type, DialectAsmPrinter &out)
Helper that prints a CalibratedQuantizedType.
static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedPerAxisType.
static void printUniformQuantizedType(UniformQuantizedType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedType.
static LogicalResult isScaleInExpressedTypeRange(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double scale)
Checks if the given scale value is within the valid range of the expressed type.
static Type parseQuantileType(DialectAsmParser &parser)
static void printDenseQuantizationParameters(ArrayRef< APFloat > scales, ArrayRef< APInt > zeroPoints, ArrayRef< int64_t > shape, DialectAsmPrinter &out)
Prints quantization parameters as a nested list of scale[:zero_point] elements.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseLBrace()=0
Parse a { token.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual OptionalParseResult parseOptionalType(Type &result)=0
Parse an optional type.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseFloat(double &result)=0
Parse a floating point value from the stream.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class implements Optional functionality for ParseResult.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition QuantTypes.h:204
A quantized type that infers its range from given min/max values.
Definition QuantTypes.h:525
std::optional< int64_t > getStorageMin() const
Return the explicit storage minimum, if set.
ArrayRef< double > getQuantiles() const
Return the quantile table of this float type.
std::optional< int64_t > getStorageMax() const
Return the explicit storage maximum, if set.
Base class for all quantized types known to this dialect.
Definition QuantTypes.h:51
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
Definition QuantTypes.h:104
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Type getStorageType() const
Gets the underlying type used for to store values.
Represents per-axis (also known as per-channel quantization).
Definition QuantTypes.h:325
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to.
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation.
ArrayRef< double > getScales() const
Gets the quantization scales.
Represents sub-channel (also known as blockwise quantization).
Definition QuantTypes.h:410
DenseElementsAttr getZeroPoints() const
Gets the quantization zero-points.
const SmallVector< std::pair< int32_t, int64_t > > getBlockSizeInfo() const
Gets the block size information.
DenseElementsAttr getScales() const
Gets the quantization scales.
Represents a family of uniform, quantized types.
Definition QuantTypes.h:265
double getScale() const
Gets the scale term.
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147