MLIR 22.0.0git
TypeParser.cpp
Go to the documentation of this file.
//===- TypeParser.cpp - Quantization Type Parser ----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Types.h"
14#include "llvm/ADT/APFloat.h"
15
16using namespace mlir;
17using namespace quant;
18
19static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
20 auto typeLoc = parser.getCurrentLocation();
21 IntegerType type;
22
23 // Parse storage type (alpha_ident, integer_literal).
24 StringRef identifier;
25 unsigned storageTypeWidth = 0;
27 if (result.has_value()) {
28 if (!succeeded(*result))
29 return nullptr;
30 isSigned = !type.isUnsigned();
31 storageTypeWidth = type.getWidth();
32 } else if (succeeded(parser.parseKeyword(&identifier))) {
33 // Otherwise, this must be an unsigned integer (`u` integer-literal).
34 if (!identifier.consume_front("u")) {
35 parser.emitError(typeLoc, "illegal storage type prefix");
36 return nullptr;
37 }
38 if (identifier.getAsInteger(10, storageTypeWidth)) {
39 parser.emitError(typeLoc, "expected storage type width");
40 return nullptr;
41 }
42 isSigned = false;
43 type = parser.getBuilder().getIntegerType(storageTypeWidth);
44 } else {
45 return nullptr;
46 }
47
48 if (storageTypeWidth == 0 ||
49 storageTypeWidth > QuantizedType::MaxStorageBits) {
50 parser.emitError(typeLoc, "illegal storage type size: ")
51 << storageTypeWidth;
52 return nullptr;
53 }
54
55 return type;
56}
57
58static ParseResult parseStorageRange(DialectAsmParser &parser,
59 IntegerType storageType, bool isSigned,
60 int64_t &storageTypeMin,
61 int64_t &storageTypeMax) {
63 isSigned, storageType.getWidth());
65 isSigned, storageType.getWidth());
66 if (failed(parser.parseOptionalLess())) {
67 storageTypeMin = defaultIntegerMin;
68 storageTypeMax = defaultIntegerMax;
69 return success();
70 }
71
72 // Explicit storage min and storage max.
73 SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
74 if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
75 parser.getCurrentLocation(&maxLoc) ||
76 parser.parseInteger(storageTypeMax) || parser.parseGreater())
77 return failure();
78 if (storageTypeMin < defaultIntegerMin) {
79 return parser.emitError(minLoc, "illegal storage type minimum: ")
80 << storageTypeMin;
81 }
82 if (storageTypeMax > defaultIntegerMax) {
83 return parser.emitError(maxLoc, "illegal storage type maximum: ")
84 << storageTypeMax;
85 }
86 return success();
87}
88
90 double &min, double &max) {
91 auto typeLoc = parser.getCurrentLocation();
92 FloatType type;
93
94 if (failed(parser.parseType(type))) {
95 parser.emitError(typeLoc, "expecting float expressed type");
96 return nullptr;
97 }
98
99 // Calibrated min and max values.
100 if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() ||
101 parser.parseFloat(max) || parser.parseGreater()) {
102 parser.emitError(typeLoc, "calibrated values must be present");
103 return nullptr;
104 }
105 return type;
106}
107
108/// Parses an AnyQuantizedType.
109///
110/// any ::= `any<` storage-spec (expressed-type-spec)?`>`
111/// storage-spec ::= storage-type (`<` storage-range `>`)?
112/// storage-range ::= integer-literal `:` integer-literal
113/// storage-type ::= (`i` | `u`) integer-literal
114/// expressed-type-spec ::= `:` `f` integer-literal
116 IntegerType storageType;
117 FloatType expressedType;
118 unsigned typeFlags = 0;
119 int64_t storageTypeMin;
120 int64_t storageTypeMax;
121
122 // Type specification.
123 if (parser.parseLess())
124 return nullptr;
125
126 // Storage type.
127 bool isSigned = false;
128 storageType = parseStorageType(parser, isSigned);
129 if (!storageType) {
130 return nullptr;
131 }
132 if (isSigned) {
133 typeFlags |= QuantizationFlags::Signed;
134 }
135
136 // Storage type range.
137 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
138 storageTypeMax)) {
139 return nullptr;
140 }
141
142 // Optional expressed type.
143 if (succeeded(parser.parseOptionalColon())) {
144 if (parser.parseType(expressedType)) {
145 return nullptr;
146 }
147 }
148
149 if (parser.parseGreater()) {
150 return nullptr;
151 }
152
153 return parser.getChecked<AnyQuantizedType>(
154 typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
155}
156
157/// Checks if the given scale value is within the valid range of the expressed
158/// type. The `expressedType` argument is the floating-point type used for
159/// expressing the quantized values, and `scale` is the double value to check.
160static LogicalResult
162 Type expressedType, double scale) {
163 auto floatType = cast<FloatType>(expressedType);
164 double minScale =
165 APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
166 double maxScale =
167 APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
168 if (scale < minScale || scale > maxScale)
169 return emitError() << "scale " << scale << " out of expressed type range ["
170 << minScale << ", " << maxScale << "]";
171 return success();
172}
173
174/// Parses a quantization parameter, which is either a scale value (float) or a
175/// scale-zero point pair (float:integer). `expressedType`, expressing the type
176/// of scale values, is used to validate the scale. The parsed scale and zero
177/// point (if any) are stored in `scale` and `zeroPoint`.
178static ParseResult parseQuantParams(DialectAsmParser &parser,
179 Type expressedType, double &scale,
180 int64_t &zeroPoint) {
181
182 if (parser.parseFloat(scale)) {
183 return failure();
184 }
185
187 [&]() { return parser.emitError(parser.getCurrentLocation()); },
188 expressedType, scale))) {
189 return failure();
190 }
191
192 zeroPoint = 0;
193 if (failed(parser.parseOptionalColon())) {
194 return success();
195 }
196
197 return parser.parseInteger(zeroPoint);
198}
199
200/// Parses block size information for sub-channel quantization, assuming the
201/// leading '{' has already been parsed. The block size information is provided
202/// as a comma-separated list of "Axis:BlockSize" pairs, terminated by a '}'.
203///
204/// The parsed axis indices are stored in `quantizedDimensions`, and the
205/// corresponding block sizes are stored in `blockSizes`.
206static ParseResult
208 SmallVectorImpl<int32_t> &quantizedDimensions,
209 SmallVectorImpl<int64_t> &blockSizes) {
210 // Empty block-sizes info.
211 if (succeeded(parser.parseOptionalRBrace())) {
212 return success();
213 }
214
215 auto parseBlockSizeElements = [&]() -> ParseResult {
216 quantizedDimensions.resize(quantizedDimensions.size() + 1);
217 blockSizes.resize(blockSizes.size() + 1);
218 if (parser.parseInteger(quantizedDimensions.back()) ||
219 parser.parseColon() || parser.parseInteger(blockSizes.back()))
220 return failure();
221 return success();
222 };
223
224 if (parser.parseCommaSeparatedList(parseBlockSizeElements) ||
225 parser.parseRBrace()) {
226 return failure();
227 }
228
229 return success();
230}
231
232/// Parses a bracketed list of quantization parameters, returning the dimensions
233/// of the parsed sub-tensors in `dims`. The dimension of the list is prepended
234/// to the dimensions of the sub-tensors. This function assumes that the initial
235/// left brace has already been parsed. For example:
236///
237/// parseQuantParamListUntilRBrace(1.0:1, 2.0:4, 3.0:4}) -> Success,
238/// dims = [3], scales = [1.0, 2.0, 3.0], zeroPoints = [1, 4, 4]
239///
240/// parseQuantParamListUntilRBrace({1.0, 2.0}, {3.0:1, 4.0:9}}) -> Success,
241/// dims = [2, 2], scales = [1.0, 2.0, 3.0, 4.0], zeroPoints = [0, 0, 1,
242/// 9]
243///
244/// This function expects all sub-tensors to have the same rank.
245static ParseResult
248 SmallVectorImpl<int64_t> &zeroPoints,
250 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
251 const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
252 if (prevDims == newDims)
253 return success();
254 return parser.emitError(parser.getCurrentLocation())
255 << "tensor literal is invalid; ranks are not consistent "
256 "between elements";
257 };
258
259 bool first = true;
261 unsigned size = 0;
262
263 auto parseOneElement = [&]() -> ParseResult {
265 if (succeeded(parser.parseOptionalLBrace())) {
266 if (parseQuantParamListUntilRBrace(parser, expressedType, scales,
267 zeroPoints, thisDims))
268 return failure();
269 } else {
270 zeroPoints.resize(zeroPoints.size() + 1);
271 scales.resize(scales.size() + 1);
272 if (parseQuantParams(parser, expressedType, scales.back(),
273 zeroPoints.back())) {
274 return failure();
275 }
276 }
277 ++size;
278 if (!first)
279 return checkDims(newDims, thisDims);
280 newDims = thisDims;
281 first = false;
282 return success();
283 };
284
285 if (parser.parseCommaSeparatedList(parseOneElement) || parser.parseRBrace()) {
286 return failure();
287 }
288
289 // Return the sublists' dimensions with 'size' prepended.
290 dims.clear();
291 dims.push_back(size);
292 dims.append(newDims.begin(), newDims.end());
293
294 return success();
295}
296
297/// Parses a UniformQuantizedType.
298///
299/// uniform_type ::= uniform_per_layer
300/// | uniform_per_axis
301/// | uniform_sub_channel
302/// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
303/// `,` scale-zero `>`
304/// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
305/// axis-spec `,` `{` scale-zero-list `}` `>`
306/// uniform_sub_channel ::= `uniform<` storage-spec expressed-type-spec
307/// block-size-info `,` scale-zero-tensor `>`
308/// storage-spec ::= storage-type (`<` storage-range `>`)?
309/// storage-range ::= integer-literal `:` integer-literal
310/// storage-type ::= (`i` | `u`) integer-literal
311/// expressed-type-spec ::= `:` `f` integer-literal
312/// axis-spec ::= `:` integer-literal
313/// scale-zero ::= scale (`:` zero-point)?
314/// scale ::= float-literal
315/// zero-point ::= integer-literal
316/// scale-zero-list ::= scale-zero (`,` scale-zero)*
317/// block-size-info ::= `{` `}` | `{` axis-block `:` (`,` axis-block)* `}`
318/// axis-block ::= axis-spec `:` block-size-spec
319/// block-size-spec ::= integer-literal
320/// scale-zero-tensor ::= scale-zero-dense-exp | scale-zero-list
321/// scale-zero-dense-exp ::= `{`
322/// scale-zero-tensor (`,` scale-zero-tensor)*
323/// `}`
325 IntegerType storageType;
326 FloatType expressedType;
327 unsigned typeFlags = 0;
328 int64_t storageTypeMin;
329 int64_t storageTypeMax;
330 bool isPerAxis = false;
331 bool isSubChannel = false;
332 SmallVector<int32_t, 1> quantizedDimensions;
333 SmallVector<int64_t, 1> blockSizes;
335 SmallVector<int64_t, 1> zeroPoints;
336
337 // Type specification.
338 if (parser.parseLess()) {
339 return nullptr;
340 }
341
342 // Storage type.
343 bool isSigned = false;
344 storageType = parseStorageType(parser, isSigned);
345 if (!storageType) {
346 return nullptr;
347 }
348 if (isSigned) {
349 typeFlags |= QuantizationFlags::Signed;
350 }
351
352 // Storage type range.
353 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
354 storageTypeMax)) {
355 return nullptr;
356 }
357
358 // Expressed type.
359 if (parser.parseColon() || parser.parseType(expressedType)) {
360 return nullptr;
361 }
362
363 // Optionally parse quantized dimension for per-axis or sub-channel
364 // quantization.
365 if (succeeded(parser.parseOptionalColon())) {
366 if (succeeded(parser.parseOptionalLBrace())) {
367 isSubChannel = true;
368 if (parseBlockSizeInfoUntilRBrace(parser, quantizedDimensions,
369 blockSizes)) {
370 return nullptr;
371 }
372 } else {
373 isPerAxis = true;
374 quantizedDimensions.resize(1);
375 if (parser.parseInteger(quantizedDimensions.back())) {
376 return nullptr;
377 }
378 }
379 }
380
381 // Comma leading into range_spec.
382 if (parser.parseComma()) {
383 return nullptr;
384 }
385
386 // Quantization parameter (scales/zeroPoints) specification.
387 bool isPerTensor = !isPerAxis && !isSubChannel;
389 if (isPerTensor) {
390 zeroPoints.resize(zeroPoints.size() + 1);
391 scales.resize(scales.size() + 1);
392 if (parseQuantParams(parser, expressedType, scales.back(),
393 zeroPoints.back())) {
394 return nullptr;
395 }
396
397 } else {
398 if (parser.parseLBrace() ||
399 parseQuantParamListUntilRBrace(parser, expressedType, scales,
400 zeroPoints, dims)) {
401 return nullptr;
402 }
403 }
404
405 if (parser.parseGreater()) {
406 return nullptr;
407 }
408
409 if (isPerAxis) {
411 typeFlags, storageType, expressedType, scales, zeroPoints,
412 quantizedDimensions[0], storageTypeMin, storageTypeMax);
413 }
414 if (isSubChannel) {
415 SmallVector<APFloat> apFloatScales =
416 llvm::to_vector(llvm::map_range(scales, [&](double scale) -> APFloat {
417 APFloat apFloatScale(scale);
418 bool unused;
419 apFloatScale.convert(expressedType.getFloatSemantics(),
420 APFloat::rmNearestTiesToEven, &unused);
421 return apFloatScale;
422 }));
423 SmallVector<APInt> apIntZeroPoints = llvm::to_vector(
424 llvm::map_range(zeroPoints, [&](int64_t zeroPoint) -> APInt {
425 return APInt(storageType.getIntOrFloatBitWidth(), zeroPoint);
426 }));
427 auto scalesRef = mlir::DenseElementsAttr::get(
428 RankedTensorType::get(dims, expressedType), apFloatScales);
429 auto zeroPointsRef = mlir::DenseElementsAttr::get(
430 RankedTensorType::get(dims, storageType), apIntZeroPoints);
432 typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
433 quantizedDimensions, blockSizes, storageTypeMin, storageTypeMax);
434 }
435
436 return parser.getChecked<UniformQuantizedType>(
437 typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
438 storageTypeMin, storageTypeMax);
439}
440
441/// Parses an CalibratedQuantizedType.
442///
443/// calibrated ::= `calibrated<` expressed-spec `>`
444/// expressed-spec ::= expressed-type `<` calibrated-range `>`
445/// expressed-type ::= `f` integer-literal
446/// calibrated-range ::= float-literal `:` float-literal
448 FloatType expressedType;
449 double min;
450 double max;
451
452 // Type specification.
453 if (parser.parseLess())
454 return nullptr;
455
456 // Expressed type.
457 expressedType = parseExpressedTypeAndRange(parser, min, max);
458 if (!expressedType) {
459 return nullptr;
460 }
461
462 if (parser.parseGreater()) {
463 return nullptr;
464 }
465
466 return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
467}
468
469/// Parse a type registered to this dialect.
470Type QuantDialect::parseType(DialectAsmParser &parser) const {
471 // All types start with an identifier that we switch on.
472 StringRef typeNameSpelling;
473 if (failed(parser.parseKeyword(&typeNameSpelling)))
474 return nullptr;
475
476 if (typeNameSpelling == "uniform")
477 return parseUniformType(parser);
478 if (typeNameSpelling == "any")
479 return parseAnyType(parser);
480 if (typeNameSpelling == "calibrated")
481 return parseCalibratedType(parser);
482
483 parser.emitError(parser.getNameLoc(),
484 "unknown quantized type " + typeNameSpelling);
485 return nullptr;
486}
487
489 // storage type
490 unsigned storageWidth = type.getStorageTypeIntegralWidth();
491 bool isSigned = type.isSigned();
492 if (isSigned) {
493 out << "i" << storageWidth;
494 } else {
495 out << "u" << storageWidth;
496 }
497
498 // storageTypeMin and storageTypeMax if not default.
499 if (type.hasStorageTypeBounds()) {
500 out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
501 << ">";
502 }
503}
504
505static void printQuantParams(double scale, int64_t zeroPoint,
506 DialectAsmPrinter &out) {
507 out << scale;
508 if (zeroPoint != 0) {
509 out << ":" << zeroPoint;
510 }
511}
512
513static void
514printBlockSizeInfo(ArrayRef<std::pair<int32_t, int64_t>> blockSizeInfo,
515 DialectAsmPrinter &out) {
516 out << "{";
517 llvm::interleaveComma(
518 llvm::seq<size_t>(0, blockSizeInfo.size()), out, [&](size_t index) {
519 out << blockSizeInfo[index].first << ":" << blockSizeInfo[index].second;
520 });
521 out << "}";
522}
523
524/// Helper that prints a AnyQuantizedType.
526 DialectAsmPrinter &out) {
527 out << "any<";
528 printStorageType(type, out);
529 if (Type expressedType = type.getExpressedType()) {
530 out << ":" << expressedType;
531 }
532 out << ">";
533}
534
535/// Helper that prints a UniformQuantizedType.
537 DialectAsmPrinter &out) {
538 out << "uniform<";
539 printStorageType(type, out);
540 out << ":" << type.getExpressedType() << ", ";
541
542 // scheme specific parameters
543 printQuantParams(type.getScale(), type.getZeroPoint(), out);
544 out << ">";
545}
546
547/// Helper that prints a UniformQuantizedPerAxisType.
549 DialectAsmPrinter &out) {
550 out << "uniform<";
551 printStorageType(type, out);
552 out << ":" << type.getExpressedType() << ":";
553 out << type.getQuantizedDimension();
554 out << ", ";
555
556 // scheme specific parameters
557 ArrayRef<double> scales = type.getScales();
558 ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
559 out << "{";
560 llvm::interleave(
561 llvm::seq<size_t>(0, scales.size()), out,
562 [&](size_t index) {
563 printQuantParams(scales[index], zeroPoints[index], out);
564 },
565 ",");
566 out << "}>";
567}
568
569/// Prints quantization parameters as a nested list of `scale`[:`zero_point`]
570/// elements. The nesting corresponds to the `shape` dimensions.
571///
572/// Elements are delimited by commas, and the inner dimensions are enclosed in
573/// braces. `zero_point` is only printed if it is non-zero. For example:
574///
575/// printDenseQuantizationParameters(scales=[1.0, 2.0, 3.0, 4.0],
576/// zeroPoints=[0, 0, 1, 9],
577/// shape=[2, 2])
578///
579/// would print:
580///
581/// {{1.0, 2.0}, {3.0:1, 4.0:9}}
583 ArrayRef<APInt> zeroPoints,
585 DialectAsmPrinter &out) {
586 int64_t rank = shape.size();
587 SmallVector<unsigned, 4> counter(rank, 0);
588 unsigned openBrackets = 0;
589
590 auto incrementCounterAndDelimit = [&]() {
591 ++counter[rank - 1];
592 for (unsigned i = rank - 1; i > 0; --i) {
593 if (counter[i] >= shape[i]) {
594 counter[i] = 0;
595 ++counter[i - 1];
596 --openBrackets;
597 out << '}';
598 }
599 }
600 };
601
602 for (unsigned idx = 0, e = scales.size(); idx < e; ++idx) {
603 if (idx != 0)
604 out << ", ";
605 while (openBrackets++ < rank)
606 out << '{';
607 openBrackets = rank;
608 out << scales[idx];
609 if (zeroPoints[idx] != 0) {
610 out << ":" << zeroPoints[idx];
611 }
612 incrementCounterAndDelimit();
613 }
614 while (openBrackets-- > 0)
615 out << '}';
616}
617
618/// Helper that prints a UniformQuantizedSubChannelType.
619static void
621 DialectAsmPrinter &out) {
622 out << "uniform<";
623 printStorageType(type, out);
624 out << ":" << type.getExpressedType() << ":";
626 out << ", ";
627
628 auto scalesItr = type.getScales().getValues<APFloat>();
629 auto zeroPointsItr = type.getZeroPoints().getValues<APInt>();
630 SmallVector<APFloat> scales(scalesItr.begin(), scalesItr.end());
631 SmallVector<APInt> zeroPoints(zeroPointsItr.begin(), zeroPointsItr.end());
632 printDenseQuantizationParameters(scales, zeroPoints,
633 type.getScales().getType().getShape(), out);
634 out << ">";
635}
636
637/// Helper that prints a CalibratedQuantizedType.
639 DialectAsmPrinter &out) {
640 out << "calibrated<" << type.getExpressedType();
641 out << "<" << type.getMin() << ":" << type.getMax() << ">";
642 out << ">";
643}
644
645/// Print a type registered to this dialect.
646void QuantDialect::printType(Type type, DialectAsmPrinter &os) const {
647 if (auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
648 printAnyQuantizedType(anyType, os);
649 else if (auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))
650 printUniformQuantizedType(uniformType, os);
651 else if (auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type))
652 printUniformQuantizedPerAxisType(perAxisType, os);
653 else if (auto perAxisType =
654 llvm::dyn_cast<UniformQuantizedSubChannelType>(type))
656 else if (auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type))
657 printCalibratedQuantizedType(calibratedType, os);
658 else
659 llvm_unreachable("Unhandled quantized type");
660}
return success()
static void printBlockSizeInfo(ArrayRef< std::pair< int32_t, int64_t > > blockSizeInfo, DialectAsmPrinter &out)
static ParseResult parseQuantParams(DialectAsmParser &parser, Type expressedType, double &scale, int64_t &zeroPoint)
Parses a quantization parameter, which is either a scale value (float) or a scale-zero point pair (fl...
static void printAnyQuantizedType(AnyQuantizedType type, DialectAsmPrinter &out)
Helper that prints an AnyQuantizedType.
static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, double &min, double &max)
static void printUniformQuantizedSubChannelType(UniformQuantizedSubChannelType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedSubChannelType.
static Type parseUniformType(DialectAsmParser &parser)
Parses a UniformQuantizedType.
static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned)
static Type parseAnyType(DialectAsmParser &parser)
Parses an AnyQuantizedType.
static void printStorageType(QuantizedType type, DialectAsmPrinter &out)
static ParseResult parseBlockSizeInfoUntilRBrace(DialectAsmParser &parser, SmallVectorImpl< int32_t > &quantizedDimensions, SmallVectorImpl< int64_t > &blockSizes)
Parses block size information for sub-channel quantization, assuming the leading '{' has already been...
static void printQuantParams(double scale, int64_t zeroPoint, DialectAsmPrinter &out)
static Type parseCalibratedType(DialectAsmParser &parser)
Parses a CalibratedQuantizedType.
static ParseResult parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType, SmallVectorImpl< double > &scales, SmallVectorImpl< int64_t > &zeroPoints, SmallVectorImpl< int64_t > &dims)
Parses a bracketed list of quantization parameters, returning the dimensions of the parsed sub-tensor...
static ParseResult parseStorageRange(DialectAsmParser &parser, IntegerType storageType, bool isSigned, int64_t &storageTypeMin, int64_t &storageTypeMax)
static void printCalibratedQuantizedType(CalibratedQuantizedType type, DialectAsmPrinter &out)
Helper that prints a CalibratedQuantizedType.
static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedPerAxisType.
static void printUniformQuantizedType(UniformQuantizedType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedType.
static LogicalResult isScaleInExpressedTypeRange(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double scale)
Checks if the given scale value is within the valid range of the expressed type.
static void printDenseQuantizationParameters(ArrayRef< APFloat > scales, ArrayRef< APInt > zeroPoints, ArrayRef< int64_t > shape, DialectAsmPrinter &out)
Prints quantization parameters as a nested list of scale[:zero_point] elements.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseLBrace()=0
Parse a { token.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual OptionalParseResult parseOptionalType(Type &result)=0
Parse an optional type.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseFloat(double &result)=0
Parse a floating point value from the stream.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a diagnostic that is inflight and set to be reported.
This class implements Optional functionality for ParseResult.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition QuantTypes.h:203
A quantized type that infers its range from given min/max values.
Definition QuantTypes.h:524
Base class for all quantized types known to this dialect.
Definition QuantTypes.h:50
static constexpr unsigned MaxStorageBits
The maximum number of bits supported for storage types.
Definition QuantTypes.h:56
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
Definition QuantTypes.h:103
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
static int64_t getDefaultMaximumForInteger(bool isSigned, unsigned integralWidth)
Gets the maximum possible stored by a storageType.
Definition QuantTypes.h:78
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
static int64_t getDefaultMinimumForInteger(bool isSigned, unsigned integralWidth)
Gets the minimum possible stored by a storageType.
Definition QuantTypes.h:68
Represents per-axis (also known as per-channel quantization).
Definition QuantTypes.h:324
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to.
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation.
ArrayRef< double > getScales() const
Gets the quantization scales.
Represents sub-channel (also known as blockwise quantization).
Definition QuantTypes.h:409
DenseElementsAttr getZeroPoints() const
Gets the quantization zero-points.
const SmallVector< std::pair< int32_t, int64_t > > getBlockSizeInfo() const
Gets the block size information.
DenseElementsAttr getScales() const
Gets the quantization scales.
Represents a family of uniform, quantized types.
Definition QuantTypes.h:264
double getScale() const
Gets the scale term.
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152