MLIR 23.0.0git
TypeParser.cpp
Go to the documentation of this file.
//===- TypeParser.cpp - Quantization Type Parser ----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/MathExtras.h"

using namespace mlir;
using namespace quant;
20
21static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) {
22 auto typeLoc = parser.getCurrentLocation();
23 Type type;
24
25 // Parse storage type (alpha_ident, integer_literal).
26 StringRef identifier;
27 unsigned storageTypeWidth = 0;
29 if (result.has_value()) {
30 if (!succeeded(*result))
31 return nullptr;
32
33 if (auto quantStorageTypeInterface =
34 llvm::dyn_cast<QuantStorageTypeInterface>(type)) {
35 // Returns true if the type defaults to signed (e.g., si8, i8 or float
36 // types), false if it defaults to unsigned.
37 isSigned = quantStorageTypeInterface.shouldDefaultToSigned();
38 storageTypeWidth = quantStorageTypeInterface.getStorageWidth();
39 } else {
40 parser.emitError(typeLoc, "illegal storage type prefix");
41 return nullptr;
42 }
43 } else if (succeeded(parser.parseKeyword(&identifier))) {
44 // Otherwise, this must be an unsigned integer (`u` integer-literal)
45 if (identifier.consume_front("u")) {
46 if (identifier.getAsInteger(10, storageTypeWidth)) {
47 parser.emitError(typeLoc, "expected storage type width");
48 return nullptr;
49 }
50 isSigned = false;
51 type = parser.getBuilder().getIntegerType(storageTypeWidth);
52 } else {
53 parser.emitError(typeLoc, "illegal storage type prefix");
54 return nullptr;
55 }
56 } else {
57 return nullptr;
58 }
59
60 if (storageTypeWidth == 0 ||
61 storageTypeWidth > QuantizedType::MaxStorageBits) {
62 parser.emitError(typeLoc, "illegal storage type size: ")
63 << storageTypeWidth;
64 return nullptr;
65 }
66
67 return type;
68}
69
70static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType,
71 bool isSigned, int64_t &storageTypeMin,
72 int64_t &storageTypeMax) {
73 auto quantStorageTypeInterface =
74 llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
75
76 int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum(isSigned);
77 int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum(isSigned);
78
79 if (failed(parser.parseOptionalLess())) {
80 storageTypeMin = defaultMin;
81 storageTypeMax = defaultMax;
82 return success();
83 }
84
85 // Explicit storage min and storage max.
86 SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
87 if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
88 parser.getCurrentLocation(&maxLoc) ||
89 parser.parseInteger(storageTypeMax) || parser.parseGreater())
90 return failure();
91 if (storageTypeMin < defaultMin) {
92 return parser.emitError(minLoc, "illegal storage type minimum: ")
93 << storageTypeMin;
94 }
95 if (storageTypeMax > defaultMax) {
96 return parser.emitError(maxLoc, "illegal storage type maximum: ")
97 << storageTypeMax;
98 }
99 return success();
100}
101
103 double &min, double &max) {
104 auto typeLoc = parser.getCurrentLocation();
105 FloatType type;
106
107 if (failed(parser.parseType(type))) {
108 parser.emitError(typeLoc, "expecting float expressed type");
109 return nullptr;
110 }
111
112 // Calibrated min and max values.
113 if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() ||
114 parser.parseFloat(max) || parser.parseGreater()) {
115 parser.emitError(typeLoc, "calibrated values must be present");
116 return nullptr;
117 }
118 return type;
119}
120
121/// Parses an AnyQuantizedType.
122///
123/// any ::= `any<` storage-spec (expressed-type-spec)?`>`
124/// storage-spec ::= storage-type (`<` storage-range `>`)?
125/// storage-range ::= integer-literal `:` integer-literal
126/// storage-type ::= (`i` | `u`) integer-literal
127/// expressed-type-spec ::= `:` `f` integer-literal
129 Type storageType;
130 FloatType expressedType;
131 unsigned typeFlags = 0;
132 int64_t storageTypeMin;
133 int64_t storageTypeMax;
134
135 // Type specification.
136 if (parser.parseLess())
137 return nullptr;
138
139 // Storage type.
140 bool isSigned = false;
141 storageType = parseStorageType(parser, isSigned);
142 if (!storageType) {
143 return nullptr;
144 }
145 if (isSigned) {
146 typeFlags |= QuantizationFlags::Signed;
147 }
148
149 // Storage type range.
150 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
151 storageTypeMax)) {
152 return nullptr;
153 }
154
155 // Optional expressed type.
156 if (succeeded(parser.parseOptionalColon())) {
157 if (parser.parseType(expressedType)) {
158 return nullptr;
159 }
160 }
161
162 if (parser.parseGreater()) {
163 return nullptr;
164 }
165
166 return parser.getChecked<AnyQuantizedType>(
167 typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
168}
169
170/// Checks if the given scale value is within the valid range of the expressed
171/// type. The `expressedType` argument is the floating-point type used for
172/// expressing the quantized values, and `scale` is the double value to check.
173static LogicalResult
175 Type expressedType, double scale) {
176 auto floatType = cast<FloatType>(expressedType);
177 double minScale =
178 APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
179 double maxScale =
180 APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
181 if (scale < minScale || scale > maxScale)
182 return emitError() << "scale " << scale << " out of expressed type range ["
183 << minScale << ", " << maxScale << "]";
184 return success();
185}
186
187/// Parses a quantization parameter, which is either a scale value (float) or a
188/// scale-zero point pair (float:integer). `expressedType`, expressing the type
189/// of scale values, is used to validate the scale. The parsed scale and zero
190/// point (if any) are stored in `scale` and `zeroPoint`.
191static ParseResult parseQuantParams(DialectAsmParser &parser,
192 Type expressedType, double &scale,
193 int64_t &zeroPoint) {
194
195 if (parser.parseFloat(scale)) {
196 return failure();
197 }
198
200 [&]() { return parser.emitError(parser.getCurrentLocation()); },
201 expressedType, scale))) {
202 return failure();
203 }
204
205 zeroPoint = 0;
206 if (failed(parser.parseOptionalColon())) {
207 return success();
208 }
209
210 return parser.parseInteger(zeroPoint);
211}
212
213/// Parses block size information for sub-channel quantization, assuming the
214/// leading '{' has already been parsed. The block size information is provided
215/// as a comma-separated list of "Axis:BlockSize" pairs, terminated by a '}'.
216///
217/// The parsed axis indices are stored in `quantizedDimensions`, and the
218/// corresponding block sizes are stored in `blockSizes`.
219static ParseResult
221 SmallVectorImpl<int32_t> &quantizedDimensions,
222 SmallVectorImpl<int64_t> &blockSizes) {
223 // Empty block-sizes info.
224 if (succeeded(parser.parseOptionalRBrace())) {
225 return success();
226 }
227
228 auto parseBlockSizeElements = [&]() -> ParseResult {
229 quantizedDimensions.resize(quantizedDimensions.size() + 1);
230 blockSizes.resize(blockSizes.size() + 1);
231 if (parser.parseInteger(quantizedDimensions.back()) ||
232 parser.parseColon() || parser.parseInteger(blockSizes.back()))
233 return failure();
234 return success();
235 };
236
237 if (parser.parseCommaSeparatedList(parseBlockSizeElements) ||
238 parser.parseRBrace()) {
239 return failure();
240 }
241
242 return success();
243}
244
245/// Parses a bracketed list of quantization parameters, returning the dimensions
246/// of the parsed sub-tensors in `dims`. The dimension of the list is prepended
247/// to the dimensions of the sub-tensors. This function assumes that the initial
248/// left brace has already been parsed. For example:
249///
250/// parseQuantParamListUntilRBrace(1.0:1, 2.0:4, 3.0:4}) -> Success,
251/// dims = [3], scales = [1.0, 2.0, 3.0], zeroPoints = [1, 4, 4]
252///
253/// parseQuantParamListUntilRBrace({1.0, 2.0}, {3.0:1, 4.0:9}}) -> Success,
254/// dims = [2, 2], scales = [1.0, 2.0, 3.0, 4.0], zeroPoints = [0, 0, 1,
255/// 9]
256///
257/// This function expects all sub-tensors to have the same rank.
258static ParseResult
261 SmallVectorImpl<int64_t> &zeroPoints,
263 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
264 const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
265 if (prevDims == newDims)
266 return success();
267 return parser.emitError(parser.getCurrentLocation())
268 << "tensor literal is invalid; ranks are not consistent "
269 "between elements";
270 };
271
272 bool first = true;
274 unsigned size = 0;
275
276 auto parseOneElement = [&]() -> ParseResult {
278 if (succeeded(parser.parseOptionalLBrace())) {
279 if (parseQuantParamListUntilRBrace(parser, expressedType, scales,
280 zeroPoints, thisDims))
281 return failure();
282 } else {
283 zeroPoints.resize(zeroPoints.size() + 1);
284 scales.resize(scales.size() + 1);
285 if (parseQuantParams(parser, expressedType, scales.back(),
286 zeroPoints.back())) {
287 return failure();
288 }
289 }
290 ++size;
291 if (!first)
292 return checkDims(newDims, thisDims);
293 newDims = thisDims;
294 first = false;
295 return success();
296 };
297
298 if (parser.parseCommaSeparatedList(parseOneElement) || parser.parseRBrace()) {
299 return failure();
300 }
301
302 // Return the sublists' dimensions with 'size' prepended.
303 dims.clear();
304 dims.push_back(size);
305 dims.append(newDims.begin(), newDims.end());
306
307 return success();
308}
309
310/// Parses a UniformQuantizedType.
311///
312/// uniform_type ::= uniform_per_layer
313/// | uniform_per_axis
314/// | uniform_sub_channel
315/// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
316/// `,` scale-zero `>`
317/// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
318/// axis-spec `,` `{` scale-zero-list `}` `>`
319/// uniform_sub_channel ::= `uniform<` storage-spec expressed-type-spec
320/// block-size-info `,` scale-zero-tensor `>`
321/// storage-spec ::= storage-type (`<` storage-range `>`)?
322/// storage-range ::= integer-literal `:` integer-literal
323/// storage-type ::= (`i` | `u`) integer-literal
324/// expressed-type-spec ::= `:` `f` integer-literal
325/// axis-spec ::= `:` integer-literal
326/// scale-zero ::= scale (`:` zero-point)?
327/// scale ::= float-literal
328/// zero-point ::= integer-literal
329/// scale-zero-list ::= scale-zero (`,` scale-zero)*
330/// block-size-info ::= `{` `}` | `{` axis-block `:` (`,` axis-block)* `}`
331/// axis-block ::= axis-spec `:` block-size-spec
332/// block-size-spec ::= integer-literal
333/// scale-zero-tensor ::= scale-zero-dense-exp | scale-zero-list
334/// scale-zero-dense-exp ::= `{`
335/// scale-zero-tensor (`,` scale-zero-tensor)*
336/// `}`
338 Type storageType;
339 FloatType expressedType;
340 unsigned typeFlags = 0;
341 int64_t storageTypeMin;
342 int64_t storageTypeMax;
343 bool isPerAxis = false;
344 bool isSubChannel = false;
345 SmallVector<int32_t, 1> quantizedDimensions;
346 SmallVector<int64_t, 1> blockSizes;
348 SmallVector<int64_t, 1> zeroPoints;
349
350 // Type specification.
351 if (parser.parseLess()) {
352 return nullptr;
353 }
354
355 // Storage type.
356 bool isSigned = false;
357 storageType = parseStorageType(parser, isSigned);
358 if (!storageType) {
359 return nullptr;
360 }
361 if (isSigned) {
362 typeFlags |= QuantizationFlags::Signed;
363 }
364
365 // Storage type range.
366 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
367 storageTypeMax)) {
368 return nullptr;
369 }
370
371 // Expressed type.
372 if (parser.parseColon() || parser.parseType(expressedType)) {
373 return nullptr;
374 }
375
376 // Optionally parse quantized dimension for per-axis or sub-channel
377 // quantization.
378 if (succeeded(parser.parseOptionalColon())) {
379 if (succeeded(parser.parseOptionalLBrace())) {
380 isSubChannel = true;
381 if (parseBlockSizeInfoUntilRBrace(parser, quantizedDimensions,
382 blockSizes)) {
383 return nullptr;
384 }
385 } else {
386 isPerAxis = true;
387 quantizedDimensions.resize(1);
388 if (parser.parseInteger(quantizedDimensions.back())) {
389 return nullptr;
390 }
391 }
392 }
393
394 // Comma leading into range_spec.
395 if (parser.parseComma()) {
396 return nullptr;
397 }
398
399 // Quantization parameter (scales/zeroPoints) specification.
400 bool isPerTensor = !isPerAxis && !isSubChannel;
402 if (isPerTensor) {
403 zeroPoints.resize(zeroPoints.size() + 1);
404 scales.resize(scales.size() + 1);
405 if (parseQuantParams(parser, expressedType, scales.back(),
406 zeroPoints.back())) {
407 return nullptr;
408 }
409
410 } else {
411 if (parser.parseLBrace() ||
412 parseQuantParamListUntilRBrace(parser, expressedType, scales,
413 zeroPoints, dims)) {
414 return nullptr;
415 }
416 }
417
418 if (parser.parseGreater()) {
419 return nullptr;
420 }
421
422 if (isPerAxis) {
424 typeFlags, storageType, expressedType, scales, zeroPoints,
425 quantizedDimensions[0], storageTypeMin, storageTypeMax);
426 }
427 if (isSubChannel) {
428 SmallVector<APFloat> apFloatScales =
429 llvm::map_to_vector(scales, [&](double scale) -> APFloat {
430 APFloat apFloatScale(scale);
431 bool unused;
432 apFloatScale.convert(expressedType.getFloatSemantics(),
433 APFloat::rmNearestTiesToEven, &unused);
434 return apFloatScale;
435 });
436 SmallVector<APInt> apIntZeroPoints =
437 llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> APInt {
438 return APInt(storageType.getIntOrFloatBitWidth(), zeroPoint);
439 });
440 auto scalesRef = mlir::DenseElementsAttr::get(
441 RankedTensorType::get(dims, expressedType), apFloatScales);
442 auto zeroPointsRef = mlir::DenseElementsAttr::get(
443 RankedTensorType::get(dims, storageType), apIntZeroPoints);
445 typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
446 quantizedDimensions, blockSizes, storageTypeMin, storageTypeMax);
447 }
448
449 return parser.getChecked<UniformQuantizedType>(
450 typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
451 storageTypeMin, storageTypeMax);
452}
453
454/// Parses an CalibratedQuantizedType.
455///
456/// calibrated ::= `calibrated<` expressed-spec `>`
457/// expressed-spec ::= expressed-type `<` calibrated-range `>`
458/// expressed-type ::= `f` integer-literal
459/// calibrated-range ::= float-literal `:` float-literal
461 FloatType expressedType;
462 double min;
463 double max;
464
465 // Type specification.
466 if (parser.parseLess())
467 return nullptr;
468
469 // Expressed type.
470 expressedType = parseExpressedTypeAndRange(parser, min, max);
471 if (!expressedType) {
472 return nullptr;
473 }
474
475 if (parser.parseGreater()) {
476 return nullptr;
477 }
478
479 return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
480}
481
482/// Parse a type registered to this dialect.
483Type QuantDialect::parseType(DialectAsmParser &parser) const {
484 // All types start with an identifier that we switch on.
485 StringRef typeNameSpelling;
486 if (failed(parser.parseKeyword(&typeNameSpelling)))
487 return nullptr;
488
489 if (typeNameSpelling == "uniform")
490 return parseUniformType(parser);
491 if (typeNameSpelling == "any")
492 return parseAnyType(parser);
493 if (typeNameSpelling == "calibrated")
494 return parseCalibratedType(parser);
495
496 parser.emitError(parser.getNameLoc(),
497 "unknown quantized type " + typeNameSpelling);
498 return nullptr;
499}
500
502 // storage type
503 auto quantStorageTypeInterface =
504 llvm::dyn_cast<QuantStorageTypeInterface>(type.getStorageType());
505
506 out << quantStorageTypeInterface.getStorageTypeName(type.isSigned());
507
508 // storageTypeMin and storageTypeMax if not default.
509 if (type.hasStorageTypeBounds()) {
510 out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
511 << ">";
512 }
513}
514
515static void printQuantParams(double scale, int64_t zeroPoint,
516 DialectAsmPrinter &out) {
517 out << scale;
518 if (zeroPoint != 0) {
519 out << ":" << zeroPoint;
520 }
521}
522
523static void
524printBlockSizeInfo(ArrayRef<std::pair<int32_t, int64_t>> blockSizeInfo,
525 DialectAsmPrinter &out) {
526 out << "{";
527 llvm::interleaveComma(
528 llvm::seq<size_t>(0, blockSizeInfo.size()), out, [&](size_t index) {
529 out << blockSizeInfo[index].first << ":" << blockSizeInfo[index].second;
530 });
531 out << "}";
532}
533
534/// Helper that prints a AnyQuantizedType.
536 DialectAsmPrinter &out) {
537 out << "any<";
538 printStorageType(type, out);
539 if (Type expressedType = type.getExpressedType()) {
540 out << ":" << expressedType;
541 }
542 out << ">";
543}
544
545/// Helper that prints a UniformQuantizedType.
547 DialectAsmPrinter &out) {
548 out << "uniform<";
549 printStorageType(type, out);
550 out << ":" << type.getExpressedType() << ", ";
551
552 // scheme specific parameters
553 printQuantParams(type.getScale(), type.getZeroPoint(), out);
554 out << ">";
555}
556
557/// Helper that prints a UniformQuantizedPerAxisType.
559 DialectAsmPrinter &out) {
560 out << "uniform<";
561 printStorageType(type, out);
562 out << ":" << type.getExpressedType() << ":";
563 out << type.getQuantizedDimension();
564 out << ", ";
565
566 // scheme specific parameters
567 ArrayRef<double> scales = type.getScales();
568 ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
569 out << "{";
570 llvm::interleave(
571 llvm::seq<size_t>(0, scales.size()), out,
572 [&](size_t index) {
573 printQuantParams(scales[index], zeroPoints[index], out);
574 },
575 ",");
576 out << "}>";
577}
578
579/// Prints quantization parameters as a nested list of `scale`[:`zero_point`]
580/// elements. The nesting corresponds to the `shape` dimensions.
581///
582/// Elements are delimited by commas, and the inner dimensions are enclosed in
583/// braces. `zero_point` is only printed if it is non-zero. For example:
584///
585/// printDenseQuantizationParameters(scales=[1.0, 2.0, 3.0, 4.0],
586/// zeroPoints=[0, 0, 1, 9],
587/// shape=[2, 2])
588///
589/// would print:
590///
591/// {{1.0, 2.0}, {3.0:1, 4.0:9}}
593 ArrayRef<APInt> zeroPoints,
595 DialectAsmPrinter &out) {
596 int64_t rank = shape.size();
597 SmallVector<unsigned, 4> counter(rank, 0);
598 unsigned openBrackets = 0;
599
600 auto incrementCounterAndDelimit = [&]() {
601 ++counter[rank - 1];
602 for (unsigned i = rank - 1; i > 0; --i) {
603 if (counter[i] >= shape[i]) {
604 counter[i] = 0;
605 ++counter[i - 1];
606 --openBrackets;
607 out << '}';
608 }
609 }
610 };
611
612 for (unsigned idx = 0, e = scales.size(); idx < e; ++idx) {
613 if (idx != 0)
614 out << ", ";
615 while (openBrackets++ < rank)
616 out << '{';
617 openBrackets = rank;
618 out << scales[idx];
619 if (zeroPoints[idx] != 0) {
620 out << ":" << zeroPoints[idx];
621 }
622 incrementCounterAndDelimit();
623 }
624 while (openBrackets-- > 0)
625 out << '}';
626}
627
628/// Helper that prints a UniformQuantizedSubChannelType.
629static void
631 DialectAsmPrinter &out) {
632 out << "uniform<";
633 printStorageType(type, out);
634 out << ":" << type.getExpressedType() << ":";
636 out << ", ";
637
638 auto scalesItr = type.getScales().getValues<APFloat>();
639 auto zeroPointsItr = type.getZeroPoints().getValues<APInt>();
640 SmallVector<APFloat> scales(scalesItr.begin(), scalesItr.end());
641 SmallVector<APInt> zeroPoints(zeroPointsItr.begin(), zeroPointsItr.end());
642 printDenseQuantizationParameters(scales, zeroPoints,
643 type.getScales().getType().getShape(), out);
644 out << ">";
645}
646
647/// Helper that prints a CalibratedQuantizedType.
649 DialectAsmPrinter &out) {
650 out << "calibrated<" << type.getExpressedType();
651 out << "<" << type.getMin() << ":" << type.getMax() << ">";
652 out << ">";
653}
654
655/// Print a type registered to this dialect.
656void QuantDialect::printType(Type type, DialectAsmPrinter &os) const {
657 if (auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
658 printAnyQuantizedType(anyType, os);
659 else if (auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))
660 printUniformQuantizedType(uniformType, os);
661 else if (auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type))
662 printUniformQuantizedPerAxisType(perAxisType, os);
663 else if (auto perAxisType =
664 llvm::dyn_cast<UniformQuantizedSubChannelType>(type))
666 else if (auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type))
667 printCalibratedQuantizedType(calibratedType, os);
668 else
669 llvm_unreachable("Unhandled quantized type");
670}
return success()
static void printBlockSizeInfo(ArrayRef< std::pair< int32_t, int64_t > > blockSizeInfo, DialectAsmPrinter &out)
static ParseResult parseQuantParams(DialectAsmParser &parser, Type expressedType, double &scale, int64_t &zeroPoint)
Parses a quantization parameter, which is either a scale value (float) or a scale-zero point pair (fl...
static void printAnyQuantizedType(AnyQuantizedType type, DialectAsmPrinter &out)
Helper that prints a AnyQuantizedType.
static Type parseStorageType(DialectAsmParser &parser, bool &isSigned)
static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, double &min, double &max)
static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType, bool isSigned, int64_t &storageTypeMin, int64_t &storageTypeMax)
static void printUniformQuantizedSubChannelType(UniformQuantizedSubChannelType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedSubChannelType.
static Type parseUniformType(DialectAsmParser &parser)
Parses a UniformQuantizedType.
static Type parseAnyType(DialectAsmParser &parser)
Parses an AnyQuantizedType.
static void printStorageType(QuantizedType type, DialectAsmPrinter &out)
static ParseResult parseBlockSizeInfoUntilRBrace(DialectAsmParser &parser, SmallVectorImpl< int32_t > &quantizedDimensions, SmallVectorImpl< int64_t > &blockSizes)
Parses block size information for sub-channel quantization, assuming the leading '{' has already been...
static void printQuantParams(double scale, int64_t zeroPoint, DialectAsmPrinter &out)
static Type parseCalibratedType(DialectAsmParser &parser)
Parses an CalibratedQuantizedType.
static ParseResult parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType, SmallVectorImpl< double > &scales, SmallVectorImpl< int64_t > &zeroPoints, SmallVectorImpl< int64_t > &dims)
Parses a bracketed list of quantization parameters, returning the dimensions of the parsed sub-tensor...
static void printCalibratedQuantizedType(CalibratedQuantizedType type, DialectAsmPrinter &out)
Helper that prints a CalibratedQuantizedType.
static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedPerAxisType.
static void printUniformQuantizedType(UniformQuantizedType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedType.
static LogicalResult isScaleInExpressedTypeRange(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double scale)
Checks if the given scale value is within the valid range of the expressed type.
static void printDenseQuantizationParameters(ArrayRef< APFloat > scales, ArrayRef< APInt > zeroPoints, ArrayRef< int64_t > shape, DialectAsmPrinter &out)
Prints quantization parameters as a nested list of scale[:zero_point] elements.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseLBrace()=0
Parse a { token.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual OptionalParseResult parseOptionalType(Type &result)=0
Parse an optional type.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseFloat(double &result)=0
Parse a floating point value from the stream.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a diagnostic that is inflight and set to be reported.
This class implements Optional functionality for ParseResult.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition QuantTypes.h:203
A quantized type that infers its range from given min/max values.
Definition QuantTypes.h:524
Base class for all quantized types known to this dialect.
Definition QuantTypes.h:50
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
Definition QuantTypes.h:103
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Type getStorageType() const
Gets the underlying type used for to store values.
Represents per-axis (also known as per-channel quantization).
Definition QuantTypes.h:324
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to.
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation.
ArrayRef< double > getScales() const
Gets the quantization scales.
Represents sub-channel (also known as blockwise quantization).
Definition QuantTypes.h:409
DenseElementsAttr getZeroPoints() const
Gets the quantization zero-points.
const SmallVector< std::pair< int32_t, int64_t > > getBlockSizeInfo() const
Gets the block size information.
DenseElementsAttr getScales() const
Gets the quantization scales.
Represents a family of uniform, quantized types.
Definition QuantTypes.h:264
double getScale() const
Gets the scale term.
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144