MLIR 23.0.0git
TypeParser.cpp
Go to the documentation of this file.
1//===- TypeParser.cpp - Quantization Type Parser ----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Types.h"
14#include "llvm/ADT/APFloat.h"
15#include "llvm/ADT/SmallVectorExtras.h"
16
17using namespace mlir;
18using namespace quant;
19
20static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
21 auto typeLoc = parser.getCurrentLocation();
22 IntegerType type;
23
24 // Parse storage type (alpha_ident, integer_literal).
25 StringRef identifier;
26 unsigned storageTypeWidth = 0;
28 if (result.has_value()) {
29 if (!succeeded(*result))
30 return nullptr;
31 isSigned = !type.isUnsigned();
32 storageTypeWidth = type.getWidth();
33 } else if (succeeded(parser.parseKeyword(&identifier))) {
34 // Otherwise, this must be an unsigned integer (`u` integer-literal).
35 if (!identifier.consume_front("u")) {
36 parser.emitError(typeLoc, "illegal storage type prefix");
37 return nullptr;
38 }
39 if (identifier.getAsInteger(10, storageTypeWidth)) {
40 parser.emitError(typeLoc, "expected storage type width");
41 return nullptr;
42 }
43 isSigned = false;
44 type = parser.getBuilder().getIntegerType(storageTypeWidth);
45 } else {
46 return nullptr;
47 }
48
49 if (storageTypeWidth == 0 ||
50 storageTypeWidth > QuantizedType::MaxStorageBits) {
51 parser.emitError(typeLoc, "illegal storage type size: ")
52 << storageTypeWidth;
53 return nullptr;
54 }
55
56 return type;
57}
58
59static ParseResult parseStorageRange(DialectAsmParser &parser,
60 IntegerType storageType, bool isSigned,
61 int64_t &storageTypeMin,
62 int64_t &storageTypeMax) {
63 int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
64 isSigned, storageType.getWidth());
65 int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
66 isSigned, storageType.getWidth());
67 if (failed(parser.parseOptionalLess())) {
68 storageTypeMin = defaultIntegerMin;
69 storageTypeMax = defaultIntegerMax;
70 return success();
71 }
72
73 // Explicit storage min and storage max.
74 SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
75 if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
76 parser.getCurrentLocation(&maxLoc) ||
77 parser.parseInteger(storageTypeMax) || parser.parseGreater())
78 return failure();
79 if (storageTypeMin < defaultIntegerMin) {
80 return parser.emitError(minLoc, "illegal storage type minimum: ")
81 << storageTypeMin;
82 }
83 if (storageTypeMax > defaultIntegerMax) {
84 return parser.emitError(maxLoc, "illegal storage type maximum: ")
85 << storageTypeMax;
86 }
87 return success();
88}
89
91 double &min, double &max) {
92 auto typeLoc = parser.getCurrentLocation();
93 FloatType type;
94
95 if (failed(parser.parseType(type))) {
96 parser.emitError(typeLoc, "expecting float expressed type");
97 return nullptr;
98 }
99
100 // Calibrated min and max values.
101 if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() ||
102 parser.parseFloat(max) || parser.parseGreater()) {
103 parser.emitError(typeLoc, "calibrated values must be present");
104 return nullptr;
105 }
106 return type;
107}
108
109/// Parses an AnyQuantizedType.
110///
111/// any ::= `any<` storage-spec (expressed-type-spec)?`>`
112/// storage-spec ::= storage-type (`<` storage-range `>`)?
113/// storage-range ::= integer-literal `:` integer-literal
114/// storage-type ::= (`i` | `u`) integer-literal
115/// expressed-type-spec ::= `:` `f` integer-literal
117 IntegerType storageType;
118 FloatType expressedType;
119 unsigned typeFlags = 0;
120 int64_t storageTypeMin;
121 int64_t storageTypeMax;
122
123 // Type specification.
124 if (parser.parseLess())
125 return nullptr;
126
127 // Storage type.
128 bool isSigned = false;
129 storageType = parseStorageType(parser, isSigned);
130 if (!storageType) {
131 return nullptr;
132 }
133 if (isSigned) {
134 typeFlags |= QuantizationFlags::Signed;
135 }
136
137 // Storage type range.
138 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
139 storageTypeMax)) {
140 return nullptr;
141 }
142
143 // Optional expressed type.
144 if (succeeded(parser.parseOptionalColon())) {
145 if (parser.parseType(expressedType)) {
146 return nullptr;
147 }
148 }
149
150 if (parser.parseGreater()) {
151 return nullptr;
152 }
153
154 return parser.getChecked<AnyQuantizedType>(
155 typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
156}
157
158/// Checks if the given scale value is within the valid range of the expressed
159/// type. The `expressedType` argument is the floating-point type used for
160/// expressing the quantized values, and `scale` is the double value to check.
161static LogicalResult
163 Type expressedType, double scale) {
164 auto floatType = cast<FloatType>(expressedType);
165 double minScale =
166 APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
167 double maxScale =
168 APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
169 if (scale < minScale || scale > maxScale)
170 return emitError() << "scale " << scale << " out of expressed type range ["
171 << minScale << ", " << maxScale << "]";
172 return success();
173}
174
175/// Parses a quantization parameter, which is either a scale value (float) or a
176/// scale-zero point pair (float:integer). `expressedType`, expressing the type
177/// of scale values, is used to validate the scale. The parsed scale and zero
178/// point (if any) are stored in `scale` and `zeroPoint`.
179static ParseResult parseQuantParams(DialectAsmParser &parser,
180 Type expressedType, double &scale,
181 int64_t &zeroPoint) {
182
183 if (parser.parseFloat(scale)) {
184 return failure();
185 }
186
188 [&]() { return parser.emitError(parser.getCurrentLocation()); },
189 expressedType, scale))) {
190 return failure();
191 }
192
193 zeroPoint = 0;
194 if (failed(parser.parseOptionalColon())) {
195 return success();
196 }
197
198 return parser.parseInteger(zeroPoint);
199}
200
201/// Parses block size information for sub-channel quantization, assuming the
202/// leading '{' has already been parsed. The block size information is provided
203/// as a comma-separated list of "Axis:BlockSize" pairs, terminated by a '}'.
204///
205/// The parsed axis indices are stored in `quantizedDimensions`, and the
206/// corresponding block sizes are stored in `blockSizes`.
207static ParseResult
209 SmallVectorImpl<int32_t> &quantizedDimensions,
210 SmallVectorImpl<int64_t> &blockSizes) {
211 // Empty block-sizes info.
212 if (succeeded(parser.parseOptionalRBrace())) {
213 return success();
214 }
215
216 auto parseBlockSizeElements = [&]() -> ParseResult {
217 quantizedDimensions.resize(quantizedDimensions.size() + 1);
218 blockSizes.resize(blockSizes.size() + 1);
219 if (parser.parseInteger(quantizedDimensions.back()) ||
220 parser.parseColon() || parser.parseInteger(blockSizes.back()))
221 return failure();
222 return success();
223 };
224
225 if (parser.parseCommaSeparatedList(parseBlockSizeElements) ||
226 parser.parseRBrace()) {
227 return failure();
228 }
229
230 return success();
231}
232
233/// Parses a bracketed list of quantization parameters, returning the dimensions
234/// of the parsed sub-tensors in `dims`. The dimension of the list is prepended
235/// to the dimensions of the sub-tensors. This function assumes that the initial
236/// left brace has already been parsed. For example:
237///
238/// parseQuantParamListUntilRBrace(1.0:1, 2.0:4, 3.0:4}) -> Success,
239/// dims = [3], scales = [1.0, 2.0, 3.0], zeroPoints = [1, 4, 4]
240///
241/// parseQuantParamListUntilRBrace({1.0, 2.0}, {3.0:1, 4.0:9}}) -> Success,
242/// dims = [2, 2], scales = [1.0, 2.0, 3.0, 4.0], zeroPoints = [0, 0, 1,
243/// 9]
244///
245/// This function expects all sub-tensors to have the same rank.
246static ParseResult
249 SmallVectorImpl<int64_t> &zeroPoints,
251 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
252 const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
253 if (prevDims == newDims)
254 return success();
255 return parser.emitError(parser.getCurrentLocation())
256 << "tensor literal is invalid; ranks are not consistent "
257 "between elements";
258 };
259
260 bool first = true;
262 unsigned size = 0;
263
264 auto parseOneElement = [&]() -> ParseResult {
266 if (succeeded(parser.parseOptionalLBrace())) {
267 if (parseQuantParamListUntilRBrace(parser, expressedType, scales,
268 zeroPoints, thisDims))
269 return failure();
270 } else {
271 zeroPoints.resize(zeroPoints.size() + 1);
272 scales.resize(scales.size() + 1);
273 if (parseQuantParams(parser, expressedType, scales.back(),
274 zeroPoints.back())) {
275 return failure();
276 }
277 }
278 ++size;
279 if (!first)
280 return checkDims(newDims, thisDims);
281 newDims = thisDims;
282 first = false;
283 return success();
284 };
285
286 if (parser.parseCommaSeparatedList(parseOneElement) || parser.parseRBrace()) {
287 return failure();
288 }
289
290 // Return the sublists' dimensions with 'size' prepended.
291 dims.clear();
292 dims.push_back(size);
293 dims.append(newDims.begin(), newDims.end());
294
295 return success();
296}
297
298/// Parses a UniformQuantizedType.
299///
300/// uniform_type ::= uniform_per_layer
301/// | uniform_per_axis
302/// | uniform_sub_channel
303/// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
304/// `,` scale-zero `>`
305/// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
306/// axis-spec `,` `{` scale-zero-list `}` `>`
307/// uniform_sub_channel ::= `uniform<` storage-spec expressed-type-spec
308/// block-size-info `,` scale-zero-tensor `>`
309/// storage-spec ::= storage-type (`<` storage-range `>`)?
310/// storage-range ::= integer-literal `:` integer-literal
311/// storage-type ::= (`i` | `u`) integer-literal
312/// expressed-type-spec ::= `:` `f` integer-literal
313/// axis-spec ::= `:` integer-literal
314/// scale-zero ::= scale (`:` zero-point)?
315/// scale ::= float-literal
316/// zero-point ::= integer-literal
317/// scale-zero-list ::= scale-zero (`,` scale-zero)*
318/// block-size-info ::= `{` `}` | `{` axis-block `:` (`,` axis-block)* `}`
319/// axis-block ::= axis-spec `:` block-size-spec
320/// block-size-spec ::= integer-literal
321/// scale-zero-tensor ::= scale-zero-dense-exp | scale-zero-list
322/// scale-zero-dense-exp ::= `{`
323/// scale-zero-tensor (`,` scale-zero-tensor)*
324/// `}`
326 IntegerType storageType;
327 FloatType expressedType;
328 unsigned typeFlags = 0;
329 int64_t storageTypeMin;
330 int64_t storageTypeMax;
331 bool isPerAxis = false;
332 bool isSubChannel = false;
333 SmallVector<int32_t, 1> quantizedDimensions;
334 SmallVector<int64_t, 1> blockSizes;
336 SmallVector<int64_t, 1> zeroPoints;
337
338 // Type specification.
339 if (parser.parseLess()) {
340 return nullptr;
341 }
342
343 // Storage type.
344 bool isSigned = false;
345 storageType = parseStorageType(parser, isSigned);
346 if (!storageType) {
347 return nullptr;
348 }
349 if (isSigned) {
350 typeFlags |= QuantizationFlags::Signed;
351 }
352
353 // Storage type range.
354 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
355 storageTypeMax)) {
356 return nullptr;
357 }
358
359 // Expressed type.
360 if (parser.parseColon() || parser.parseType(expressedType)) {
361 return nullptr;
362 }
363
364 // Optionally parse quantized dimension for per-axis or sub-channel
365 // quantization.
366 if (succeeded(parser.parseOptionalColon())) {
367 if (succeeded(parser.parseOptionalLBrace())) {
368 isSubChannel = true;
369 if (parseBlockSizeInfoUntilRBrace(parser, quantizedDimensions,
370 blockSizes)) {
371 return nullptr;
372 }
373 } else {
374 isPerAxis = true;
375 quantizedDimensions.resize(1);
376 if (parser.parseInteger(quantizedDimensions.back())) {
377 return nullptr;
378 }
379 }
380 }
381
382 // Comma leading into range_spec.
383 if (parser.parseComma()) {
384 return nullptr;
385 }
386
387 // Quantization parameter (scales/zeroPoints) specification.
388 bool isPerTensor = !isPerAxis && !isSubChannel;
390 if (isPerTensor) {
391 zeroPoints.resize(zeroPoints.size() + 1);
392 scales.resize(scales.size() + 1);
393 if (parseQuantParams(parser, expressedType, scales.back(),
394 zeroPoints.back())) {
395 return nullptr;
396 }
397
398 } else {
399 if (parser.parseLBrace() ||
400 parseQuantParamListUntilRBrace(parser, expressedType, scales,
401 zeroPoints, dims)) {
402 return nullptr;
403 }
404 }
405
406 if (parser.parseGreater()) {
407 return nullptr;
408 }
409
410 if (isPerAxis) {
412 typeFlags, storageType, expressedType, scales, zeroPoints,
413 quantizedDimensions[0], storageTypeMin, storageTypeMax);
414 }
415 if (isSubChannel) {
416 SmallVector<APFloat> apFloatScales =
417 llvm::map_to_vector(scales, [&](double scale) -> APFloat {
418 APFloat apFloatScale(scale);
419 bool unused;
420 apFloatScale.convert(expressedType.getFloatSemantics(),
421 APFloat::rmNearestTiesToEven, &unused);
422 return apFloatScale;
423 });
424 SmallVector<APInt> apIntZeroPoints =
425 llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> APInt {
426 return APInt(storageType.getIntOrFloatBitWidth(), zeroPoint);
427 });
428 auto scalesRef = mlir::DenseElementsAttr::get(
429 RankedTensorType::get(dims, expressedType), apFloatScales);
430 auto zeroPointsRef = mlir::DenseElementsAttr::get(
431 RankedTensorType::get(dims, storageType), apIntZeroPoints);
433 typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
434 quantizedDimensions, blockSizes, storageTypeMin, storageTypeMax);
435 }
436
437 return parser.getChecked<UniformQuantizedType>(
438 typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
439 storageTypeMin, storageTypeMax);
440}
441
442/// Parses an CalibratedQuantizedType.
443///
444/// calibrated ::= `calibrated<` expressed-spec `>`
445/// expressed-spec ::= expressed-type `<` calibrated-range `>`
446/// expressed-type ::= `f` integer-literal
447/// calibrated-range ::= float-literal `:` float-literal
449 FloatType expressedType;
450 double min;
451 double max;
452
453 // Type specification.
454 if (parser.parseLess())
455 return nullptr;
456
457 // Expressed type.
458 expressedType = parseExpressedTypeAndRange(parser, min, max);
459 if (!expressedType) {
460 return nullptr;
461 }
462
463 if (parser.parseGreater()) {
464 return nullptr;
465 }
466
467 return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
468}
469
470/// Parse a type registered to this dialect.
471Type QuantDialect::parseType(DialectAsmParser &parser) const {
472 // All types start with an identifier that we switch on.
473 StringRef typeNameSpelling;
474 if (failed(parser.parseKeyword(&typeNameSpelling)))
475 return nullptr;
476
477 if (typeNameSpelling == "uniform")
478 return parseUniformType(parser);
479 if (typeNameSpelling == "any")
480 return parseAnyType(parser);
481 if (typeNameSpelling == "calibrated")
482 return parseCalibratedType(parser);
483
484 parser.emitError(parser.getNameLoc(),
485 "unknown quantized type " + typeNameSpelling);
486 return nullptr;
487}
488
490 // storage type
491 unsigned storageWidth = type.getStorageTypeIntegralWidth();
492 bool isSigned = type.isSigned();
493 if (isSigned) {
494 out << "i" << storageWidth;
495 } else {
496 out << "u" << storageWidth;
497 }
498
499 // storageTypeMin and storageTypeMax if not default.
500 if (type.hasStorageTypeBounds()) {
501 out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
502 << ">";
503 }
504}
505
506static void printQuantParams(double scale, int64_t zeroPoint,
507 DialectAsmPrinter &out) {
508 out << scale;
509 if (zeroPoint != 0) {
510 out << ":" << zeroPoint;
511 }
512}
513
514static void
515printBlockSizeInfo(ArrayRef<std::pair<int32_t, int64_t>> blockSizeInfo,
516 DialectAsmPrinter &out) {
517 out << "{";
518 llvm::interleaveComma(
519 llvm::seq<size_t>(0, blockSizeInfo.size()), out, [&](size_t index) {
520 out << blockSizeInfo[index].first << ":" << blockSizeInfo[index].second;
521 });
522 out << "}";
523}
524
525/// Helper that prints a AnyQuantizedType.
527 DialectAsmPrinter &out) {
528 out << "any<";
529 printStorageType(type, out);
530 if (Type expressedType = type.getExpressedType()) {
531 out << ":" << expressedType;
532 }
533 out << ">";
534}
535
536/// Helper that prints a UniformQuantizedType.
538 DialectAsmPrinter &out) {
539 out << "uniform<";
540 printStorageType(type, out);
541 out << ":" << type.getExpressedType() << ", ";
542
543 // scheme specific parameters
544 printQuantParams(type.getScale(), type.getZeroPoint(), out);
545 out << ">";
546}
547
548/// Helper that prints a UniformQuantizedPerAxisType.
550 DialectAsmPrinter &out) {
551 out << "uniform<";
552 printStorageType(type, out);
553 out << ":" << type.getExpressedType() << ":";
554 out << type.getQuantizedDimension();
555 out << ", ";
556
557 // scheme specific parameters
558 ArrayRef<double> scales = type.getScales();
559 ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
560 out << "{";
561 llvm::interleave(
562 llvm::seq<size_t>(0, scales.size()), out,
563 [&](size_t index) {
564 printQuantParams(scales[index], zeroPoints[index], out);
565 },
566 ",");
567 out << "}>";
568}
569
570/// Prints quantization parameters as a nested list of `scale`[:`zero_point`]
571/// elements. The nesting corresponds to the `shape` dimensions.
572///
573/// Elements are delimited by commas, and the inner dimensions are enclosed in
574/// braces. `zero_point` is only printed if it is non-zero. For example:
575///
576/// printDenseQuantizationParameters(scales=[1.0, 2.0, 3.0, 4.0],
577/// zeroPoints=[0, 0, 1, 9],
578/// shape=[2, 2])
579///
580/// would print:
581///
582/// {{1.0, 2.0}, {3.0:1, 4.0:9}}
584 ArrayRef<APInt> zeroPoints,
586 DialectAsmPrinter &out) {
587 int64_t rank = shape.size();
588 SmallVector<unsigned, 4> counter(rank, 0);
589 unsigned openBrackets = 0;
590
591 auto incrementCounterAndDelimit = [&]() {
592 ++counter[rank - 1];
593 for (unsigned i = rank - 1; i > 0; --i) {
594 if (counter[i] >= shape[i]) {
595 counter[i] = 0;
596 ++counter[i - 1];
597 --openBrackets;
598 out << '}';
599 }
600 }
601 };
602
603 for (unsigned idx = 0, e = scales.size(); idx < e; ++idx) {
604 if (idx != 0)
605 out << ", ";
606 while (openBrackets++ < rank)
607 out << '{';
608 openBrackets = rank;
609 out << scales[idx];
610 if (zeroPoints[idx] != 0) {
611 out << ":" << zeroPoints[idx];
612 }
613 incrementCounterAndDelimit();
614 }
615 while (openBrackets-- > 0)
616 out << '}';
617}
618
619/// Helper that prints a UniformQuantizedSubChannelType.
620static void
622 DialectAsmPrinter &out) {
623 out << "uniform<";
624 printStorageType(type, out);
625 out << ":" << type.getExpressedType() << ":";
627 out << ", ";
628
629 auto scalesItr = type.getScales().getValues<APFloat>();
630 auto zeroPointsItr = type.getZeroPoints().getValues<APInt>();
631 SmallVector<APFloat> scales(scalesItr.begin(), scalesItr.end());
632 SmallVector<APInt> zeroPoints(zeroPointsItr.begin(), zeroPointsItr.end());
633 printDenseQuantizationParameters(scales, zeroPoints,
634 type.getScales().getType().getShape(), out);
635 out << ">";
636}
637
638/// Helper that prints a CalibratedQuantizedType.
640 DialectAsmPrinter &out) {
641 out << "calibrated<" << type.getExpressedType();
642 out << "<" << type.getMin() << ":" << type.getMax() << ">";
643 out << ">";
644}
645
646/// Print a type registered to this dialect.
647void QuantDialect::printType(Type type, DialectAsmPrinter &os) const {
648 if (auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
649 printAnyQuantizedType(anyType, os);
650 else if (auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))
651 printUniformQuantizedType(uniformType, os);
652 else if (auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type))
653 printUniformQuantizedPerAxisType(perAxisType, os);
654 else if (auto perAxisType =
655 llvm::dyn_cast<UniformQuantizedSubChannelType>(type))
657 else if (auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type))
658 printCalibratedQuantizedType(calibratedType, os);
659 else
660 llvm_unreachable("Unhandled quantized type");
661}
return success()
static void printBlockSizeInfo(ArrayRef< std::pair< int32_t, int64_t > > blockSizeInfo, DialectAsmPrinter &out)
static ParseResult parseQuantParams(DialectAsmParser &parser, Type expressedType, double &scale, int64_t &zeroPoint)
Parses a quantization parameter, which is either a scale value (float) or a scale-zero point pair (fl...
static void printAnyQuantizedType(AnyQuantizedType type, DialectAsmPrinter &out)
Helper that prints a AnyQuantizedType.
static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, double &min, double &max)
static void printUniformQuantizedSubChannelType(UniformQuantizedSubChannelType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedSubChannelType.
static Type parseUniformType(DialectAsmParser &parser)
Parses a UniformQuantizedType.
static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned)
static Type parseAnyType(DialectAsmParser &parser)
Parses an AnyQuantizedType.
static void printStorageType(QuantizedType type, DialectAsmPrinter &out)
static ParseResult parseBlockSizeInfoUntilRBrace(DialectAsmParser &parser, SmallVectorImpl< int32_t > &quantizedDimensions, SmallVectorImpl< int64_t > &blockSizes)
Parses block size information for sub-channel quantization, assuming the leading '{' has already been...
static void printQuantParams(double scale, int64_t zeroPoint, DialectAsmPrinter &out)
static Type parseCalibratedType(DialectAsmParser &parser)
Parses an CalibratedQuantizedType.
static ParseResult parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType, SmallVectorImpl< double > &scales, SmallVectorImpl< int64_t > &zeroPoints, SmallVectorImpl< int64_t > &dims)
Parses a bracketed list of quantization parameters, returning the dimensions of the parsed sub-tensor...
static ParseResult parseStorageRange(DialectAsmParser &parser, IntegerType storageType, bool isSigned, int64_t &storageTypeMin, int64_t &storageTypeMax)
static void printCalibratedQuantizedType(CalibratedQuantizedType type, DialectAsmPrinter &out)
Helper that prints a CalibratedQuantizedType.
static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedPerAxisType.
static void printUniformQuantizedType(UniformQuantizedType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedType.
static LogicalResult isScaleInExpressedTypeRange(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double scale)
Checks if the given scale value is within the valid range of the expressed type.
static void printDenseQuantizationParameters(ArrayRef< APFloat > scales, ArrayRef< APInt > zeroPoints, ArrayRef< int64_t > shape, DialectAsmPrinter &out)
Prints quantization parameters as a nested list of scale[:zero_point] elements.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseLBrace()=0
Parse a { token.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual OptionalParseResult parseOptionalType(Type &result)=0
Parse an optional type.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseFloat(double &result)=0
Parse a floating point value from the stream.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a diagnostic that is inflight and set to be reported.
This class implements Optional functionality for ParseResult.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition QuantTypes.h:203
A quantized type that infers its range from given min/max values.
Definition QuantTypes.h:524
Base class for all quantized types known to this dialect.
Definition QuantTypes.h:50
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
Definition QuantTypes.h:103
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Represents per-axis (also known as per-channel quantization).
Definition QuantTypes.h:324
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to.
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation.
ArrayRef< double > getScales() const
Gets the quantization scales.
Represents sub-channel (also known as blockwise quantization).
Definition QuantTypes.h:409
DenseElementsAttr getZeroPoints() const
Gets the quantization zero-points.
const SmallVector< std::pair< int32_t, int64_t > > getBlockSizeInfo() const
Gets the block size information.
DenseElementsAttr getScales() const
Gets the quantization scales.
Represents a family of uniform, quantized types.
Definition QuantTypes.h:264
double getScale() const
Gets the scale term.
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144