MLIR  21.0.0git
TypeParser.cpp
Go to the documentation of this file.
1 //===- TypeParser.cpp - Quantization Type Parser ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/Location.h"
14 #include "mlir/IR/Types.h"
15 #include "llvm/ADT/APFloat.h"
16 #include "llvm/Support/Format.h"
17 #include "llvm/Support/MathExtras.h"
18 #include "llvm/Support/SourceMgr.h"
19 #include "llvm/Support/raw_ostream.h"
20 
21 using namespace mlir;
22 using namespace quant;
23 
24 static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
25  auto typeLoc = parser.getCurrentLocation();
26  IntegerType type;
27 
28  // Parse storage type (alpha_ident, integer_literal).
29  StringRef identifier;
30  unsigned storageTypeWidth = 0;
31  OptionalParseResult result = parser.parseOptionalType(type);
32  if (result.has_value()) {
33  if (!succeeded(*result))
34  return nullptr;
35  isSigned = !type.isUnsigned();
36  storageTypeWidth = type.getWidth();
37  } else if (succeeded(parser.parseKeyword(&identifier))) {
38  // Otherwise, this must be an unsigned integer (`u` integer-literal).
39  if (!identifier.consume_front("u")) {
40  parser.emitError(typeLoc, "illegal storage type prefix");
41  return nullptr;
42  }
43  if (identifier.getAsInteger(10, storageTypeWidth)) {
44  parser.emitError(typeLoc, "expected storage type width");
45  return nullptr;
46  }
47  isSigned = false;
48  type = parser.getBuilder().getIntegerType(storageTypeWidth);
49  } else {
50  return nullptr;
51  }
52 
53  if (storageTypeWidth == 0 ||
54  storageTypeWidth > QuantizedType::MaxStorageBits) {
55  parser.emitError(typeLoc, "illegal storage type size: ")
56  << storageTypeWidth;
57  return nullptr;
58  }
59 
60  return type;
61 }
62 
63 static ParseResult parseStorageRange(DialectAsmParser &parser,
64  IntegerType storageType, bool isSigned,
65  int64_t &storageTypeMin,
66  int64_t &storageTypeMax) {
67  int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
68  isSigned, storageType.getWidth());
69  int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
70  isSigned, storageType.getWidth());
71  if (failed(parser.parseOptionalLess())) {
72  storageTypeMin = defaultIntegerMin;
73  storageTypeMax = defaultIntegerMax;
74  return success();
75  }
76 
77  // Explicit storage min and storage max.
78  SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
79  if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
80  parser.getCurrentLocation(&maxLoc) ||
81  parser.parseInteger(storageTypeMax) || parser.parseGreater())
82  return failure();
83  if (storageTypeMin < defaultIntegerMin) {
84  return parser.emitError(minLoc, "illegal storage type minimum: ")
85  << storageTypeMin;
86  }
87  if (storageTypeMax > defaultIntegerMax) {
88  return parser.emitError(maxLoc, "illegal storage type maximum: ")
89  << storageTypeMax;
90  }
91  return success();
92 }
93 
95  double &min, double &max) {
96  auto typeLoc = parser.getCurrentLocation();
97  FloatType type;
98 
99  if (failed(parser.parseType(type))) {
100  parser.emitError(typeLoc, "expecting float expressed type");
101  return nullptr;
102  }
103 
104  // Calibrated min and max values.
105  if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() ||
106  parser.parseFloat(max) || parser.parseGreater()) {
107  parser.emitError(typeLoc, "calibrated values must be present");
108  return nullptr;
109  }
110  return type;
111 }
112 
113 /// Parses an AnyQuantizedType.
114 ///
115 /// any ::= `any<` storage-spec (expressed-type-spec)?`>`
116 /// storage-spec ::= storage-type (`<` storage-range `>`)?
117 /// storage-range ::= integer-literal `:` integer-literal
118 /// storage-type ::= (`i` | `u`) integer-literal
119 /// expressed-type-spec ::= `:` `f` integer-literal
121  IntegerType storageType;
122  FloatType expressedType;
123  unsigned typeFlags = 0;
124  int64_t storageTypeMin;
125  int64_t storageTypeMax;
126 
127  // Type specification.
128  if (parser.parseLess())
129  return nullptr;
130 
131  // Storage type.
132  bool isSigned = false;
133  storageType = parseStorageType(parser, isSigned);
134  if (!storageType) {
135  return nullptr;
136  }
137  if (isSigned) {
138  typeFlags |= QuantizationFlags::Signed;
139  }
140 
141  // Storage type range.
142  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
143  storageTypeMax)) {
144  return nullptr;
145  }
146 
147  // Optional expressed type.
148  if (succeeded(parser.parseOptionalColon())) {
149  if (parser.parseType(expressedType)) {
150  return nullptr;
151  }
152  }
153 
154  if (parser.parseGreater()) {
155  return nullptr;
156  }
157 
158  return parser.getChecked<AnyQuantizedType>(
159  typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
160 }
161 
162 /// Checks if the given scale value is within the valid range of the expressed
163 /// type. The `expressedType` argument is the floating-point type used for
164 /// expressing the quantized values, and `scale` is the double value to check.
165 LogicalResult
167  Type expressedType, double scale) {
168  auto floatType = cast<FloatType>(expressedType);
169  double minScale =
170  APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
171  double maxScale =
172  APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
173  if (scale < minScale || scale > maxScale)
174  return emitError() << "scale " << scale << " out of expressed type range ["
175  << minScale << ", " << maxScale << "]";
176  return success();
177 }
178 
179 /// Parses a quantization parameter, which is either a scale value (float) or a
180 /// scale-zero point pair (float:integer). `expressedType`, expressing the type
181 /// of scale values, is used to validate the scale. The parsed scale and zero
182 /// point (if any) are stored in `scale` and `zeroPoint`.
183 static ParseResult parseQuantParams(DialectAsmParser &parser,
184  Type expressedType, double &scale,
185  int64_t &zeroPoint) {
186 
187  if (parser.parseFloat(scale)) {
188  return failure();
189  }
190 
191  if (failed(isScaleInExpressedTypeRange(
192  [&]() { return parser.emitError(parser.getCurrentLocation()); },
193  expressedType, scale))) {
194  return failure();
195  }
196 
197  zeroPoint = 0;
198  if (failed(parser.parseOptionalColon())) {
199  return success();
200  }
201 
202  return parser.parseInteger(zeroPoint);
203 }
204 
205 /// Parses block size information for sub-channel quantization, assuming the
206 /// leading '{' has already been parsed. The block size information is provided
207 /// as a comma-separated list of "Axis:BlockSize" pairs, terminated by a '}'.
208 ///
209 /// The parsed axis indices are stored in `quantizedDimensions`, and the
210 /// corresponding block sizes are stored in `blockSizes`.
211 static ParseResult
213  SmallVectorImpl<int32_t> &quantizedDimensions,
214  SmallVectorImpl<int64_t> &blockSizes) {
215  // Empty block-sizes info.
216  if (succeeded(parser.parseOptionalRBrace())) {
217  return success();
218  }
219 
220  auto parseBlockSizeElements = [&]() -> ParseResult {
221  quantizedDimensions.resize(quantizedDimensions.size() + 1);
222  blockSizes.resize(blockSizes.size() + 1);
223  if (parser.parseInteger(quantizedDimensions.back()) ||
224  parser.parseColon() || parser.parseInteger(blockSizes.back()))
225  return failure();
226  return success();
227  };
228 
229  if (parser.parseCommaSeparatedList(parseBlockSizeElements) ||
230  parser.parseRBrace()) {
231  return failure();
232  }
233 
234  return success();
235 }
236 
237 /// Parses a bracketed list of quantization parameters, returning the dimensions
238 /// of the parsed sub-tensors in `dims`. The dimension of the list is prepended
239 /// to the dimensions of the sub-tensors. This function assumes that the initial
240 /// left brace has already been parsed. For example:
241 ///
242 /// parseQuantParamListUntilRBrace(1.0:1, 2.0:4, 3.0:4}) -> Success,
243 /// dims = [3], scales = [1.0, 2.0, 3.0], zeroPoints = [1, 4, 4]
244 ///
245 /// parseQuantParamListUntilRBrace({1.0, 2.0}, {3.0:1, 4.0:9}}) -> Success,
246 /// dims = [2, 2], scales = [1.0, 2.0, 3.0, 4.0], zeroPoints = [0, 0, 1,
247 /// 9]
248 ///
249 /// This function expects all sub-tensors to have the same rank.
250 static ParseResult
252  SmallVectorImpl<double> &scales,
253  SmallVectorImpl<int64_t> &zeroPoints,
254  SmallVectorImpl<int64_t> &dims) {
255  auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
256  const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
257  if (prevDims == newDims)
258  return success();
259  return parser.emitError(parser.getCurrentLocation())
260  << "tensor literal is invalid; ranks are not consistent "
261  "between elements";
262  };
263 
264  bool first = true;
265  SmallVector<int64_t, 4> newDims;
266  unsigned size = 0;
267 
268  auto parseOneElement = [&]() -> ParseResult {
269  SmallVector<int64_t, 4> thisDims;
270  if (succeeded(parser.parseOptionalLBrace())) {
271  if (parseQuantParamListUntilRBrace(parser, expressedType, scales,
272  zeroPoints, thisDims))
273  return failure();
274  } else {
275  zeroPoints.resize(zeroPoints.size() + 1);
276  scales.resize(scales.size() + 1);
277  if (parseQuantParams(parser, expressedType, scales.back(),
278  zeroPoints.back())) {
279  return failure();
280  }
281  }
282  ++size;
283  if (!first)
284  return checkDims(newDims, thisDims);
285  newDims = thisDims;
286  first = false;
287  return success();
288  };
289 
290  if (parser.parseCommaSeparatedList(parseOneElement) || parser.parseRBrace()) {
291  return failure();
292  }
293 
294  // Return the sublists' dimensions with 'size' prepended.
295  dims.clear();
296  dims.push_back(size);
297  dims.append(newDims.begin(), newDims.end());
298 
299  return success();
300 }
301 
302 /// Parses a UniformQuantizedType.
303 ///
304 /// uniform_type ::= uniform_per_layer
305 /// | uniform_per_axis
306 /// | uniform_sub_channel
307 /// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
308 /// `,` scale-zero `>`
309 /// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
310 /// axis-spec `,` `{` scale-zero-list `}` `>`
311 /// uniform_sub_channel ::= `uniform<` storage-spec expressed-type-spec
312 /// block-size-info `,` scale-zero-tensor `>`
313 /// storage-spec ::= storage-type (`<` storage-range `>`)?
314 /// storage-range ::= integer-literal `:` integer-literal
315 /// storage-type ::= (`i` | `u`) integer-literal
316 /// expressed-type-spec ::= `:` `f` integer-literal
317 /// axis-spec ::= `:` integer-literal
318 /// scale-zero ::= scale (`:` zero-point)?
319 /// scale ::= float-literal
320 /// zero-point ::= integer-literal
321 /// scale-zero-list ::= scale-zero (`,` scale-zero)*
322 /// block-size-info ::= `{` `}` | `{` axis-block `:` (`,` axis-block)* `}`
323 /// axis-block ::= axis-spec `:` block-size-spec
324 /// block-size-spec ::= integer-literal
325 /// scale-zero-tensor ::= scale-zero-dense-exp | scale-zero-list
326 /// scale-zero-dense-exp ::= `{`
327 /// scale-zero-tensor (`,` scale-zero-tensor)*
328 /// `}`
330  IntegerType storageType;
331  FloatType expressedType;
332  unsigned typeFlags = 0;
333  int64_t storageTypeMin;
334  int64_t storageTypeMax;
335  bool isPerAxis = false;
336  bool isSubChannel = false;
337  SmallVector<int32_t, 1> quantizedDimensions;
338  SmallVector<int64_t, 1> blockSizes;
339  SmallVector<double, 1> scales;
340  SmallVector<int64_t, 1> zeroPoints;
341 
342  // Type specification.
343  if (parser.parseLess()) {
344  return nullptr;
345  }
346 
347  // Storage type.
348  bool isSigned = false;
349  storageType = parseStorageType(parser, isSigned);
350  if (!storageType) {
351  return nullptr;
352  }
353  if (isSigned) {
354  typeFlags |= QuantizationFlags::Signed;
355  }
356 
357  // Storage type range.
358  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
359  storageTypeMax)) {
360  return nullptr;
361  }
362 
363  // Expressed type.
364  if (parser.parseColon() || parser.parseType(expressedType)) {
365  return nullptr;
366  }
367 
368  // Optionally parse quantized dimension for per-axis or sub-channel
369  // quantization.
370  if (succeeded(parser.parseOptionalColon())) {
371  if (succeeded(parser.parseOptionalLBrace())) {
372  isSubChannel = true;
373  if (parseBlockSizeInfoUntilRBrace(parser, quantizedDimensions,
374  blockSizes)) {
375  return nullptr;
376  }
377  } else {
378  isPerAxis = true;
379  quantizedDimensions.resize(1);
380  if (parser.parseInteger(quantizedDimensions.back())) {
381  return nullptr;
382  }
383  }
384  }
385 
386  // Comma leading into range_spec.
387  if (parser.parseComma()) {
388  return nullptr;
389  }
390 
391  // Quantization parameter (scales/zeroPoints) specification.
392  bool isPerTensor = !isPerAxis && !isSubChannel;
394  if (isPerTensor) {
395  zeroPoints.resize(zeroPoints.size() + 1);
396  scales.resize(scales.size() + 1);
397  if (parseQuantParams(parser, expressedType, scales.back(),
398  zeroPoints.back())) {
399  return nullptr;
400  }
401 
402  } else {
403  if (parser.parseLBrace() ||
404  parseQuantParamListUntilRBrace(parser, expressedType, scales,
405  zeroPoints, dims)) {
406  return nullptr;
407  }
408  }
409 
410  if (parser.parseGreater()) {
411  return nullptr;
412  }
413 
414  if (isPerAxis) {
416  typeFlags, storageType, expressedType, scales, zeroPoints,
417  quantizedDimensions[0], storageTypeMin, storageTypeMax);
418  } else if (isSubChannel) {
419  SmallVector<APFloat> apFloatScales =
420  llvm::to_vector(llvm::map_range(scales, [&](double scale) -> APFloat {
421  APFloat apFloatScale(scale);
422  bool unused;
423  apFloatScale.convert(expressedType.getFloatSemantics(),
424  APFloat::rmNearestTiesToEven, &unused);
425  return apFloatScale;
426  }));
427  SmallVector<APInt> apIntZeroPoints = llvm::to_vector(
428  llvm::map_range(zeroPoints, [&](int64_t zeroPoint) -> APInt {
429  return APInt(storageType.getIntOrFloatBitWidth(), zeroPoint);
430  }));
431  auto scalesRef = mlir::DenseElementsAttr::get(
432  RankedTensorType::get(dims, expressedType), apFloatScales);
433  auto zeroPointsRef = mlir::DenseElementsAttr::get(
434  RankedTensorType::get(dims, storageType), apIntZeroPoints);
436  typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
437  quantizedDimensions, blockSizes, storageTypeMin, storageTypeMax);
438  }
439 
440  return parser.getChecked<UniformQuantizedType>(
441  typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
442  storageTypeMin, storageTypeMax);
443 }
444 
445 /// Parses an CalibratedQuantizedType.
446 ///
447 /// calibrated ::= `calibrated<` expressed-spec `>`
448 /// expressed-spec ::= expressed-type `<` calibrated-range `>`
449 /// expressed-type ::= `f` integer-literal
450 /// calibrated-range ::= float-literal `:` float-literal
452  FloatType expressedType;
453  double min;
454  double max;
455 
456  // Type specification.
457  if (parser.parseLess())
458  return nullptr;
459 
460  // Expressed type.
461  expressedType = parseExpressedTypeAndRange(parser, min, max);
462  if (!expressedType) {
463  return nullptr;
464  }
465 
466  if (parser.parseGreater()) {
467  return nullptr;
468  }
469 
470  return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
471 }
472 
473 /// Parse a type registered to this dialect.
475  // All types start with an identifier that we switch on.
476  StringRef typeNameSpelling;
477  if (failed(parser.parseKeyword(&typeNameSpelling)))
478  return nullptr;
479 
480  if (typeNameSpelling == "uniform")
481  return parseUniformType(parser);
482  if (typeNameSpelling == "any")
483  return parseAnyType(parser);
484  if (typeNameSpelling == "calibrated")
485  return parseCalibratedType(parser);
486 
487  parser.emitError(parser.getNameLoc(),
488  "unknown quantized type " + typeNameSpelling);
489  return nullptr;
490 }
491 
493  // storage type
494  unsigned storageWidth = type.getStorageTypeIntegralWidth();
495  bool isSigned = type.isSigned();
496  if (isSigned) {
497  out << "i" << storageWidth;
498  } else {
499  out << "u" << storageWidth;
500  }
501 
502  // storageTypeMin and storageTypeMax if not default.
503  if (type.hasStorageTypeBounds()) {
504  out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
505  << ">";
506  }
507 }
508 
509 static void printQuantParams(double scale, int64_t zeroPoint,
510  DialectAsmPrinter &out) {
511  out << scale;
512  if (zeroPoint != 0) {
513  out << ":" << zeroPoint;
514  }
515 }
516 
517 static void
518 printBlockSizeInfo(ArrayRef<std::pair<int32_t, int64_t>> blockSizeInfo,
519  DialectAsmPrinter &out) {
520  out << "{";
521  llvm::interleaveComma(
522  llvm::seq<size_t>(0, blockSizeInfo.size()), out, [&](size_t index) {
523  out << blockSizeInfo[index].first << ":" << blockSizeInfo[index].second;
524  });
525  out << "}";
526 }
527 
528 /// Helper that prints a AnyQuantizedType.
530  DialectAsmPrinter &out) {
531  out << "any<";
532  printStorageType(type, out);
533  if (Type expressedType = type.getExpressedType()) {
534  out << ":" << expressedType;
535  }
536  out << ">";
537 }
538 
539 /// Helper that prints a UniformQuantizedType.
541  DialectAsmPrinter &out) {
542  out << "uniform<";
543  printStorageType(type, out);
544  out << ":" << type.getExpressedType() << ", ";
545 
546  // scheme specific parameters
547  printQuantParams(type.getScale(), type.getZeroPoint(), out);
548  out << ">";
549 }
550 
551 /// Helper that prints a UniformQuantizedPerAxisType.
553  DialectAsmPrinter &out) {
554  out << "uniform<";
555  printStorageType(type, out);
556  out << ":" << type.getExpressedType() << ":";
557  out << type.getQuantizedDimension();
558  out << ", ";
559 
560  // scheme specific parameters
561  ArrayRef<double> scales = type.getScales();
562  ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
563  out << "{";
564  llvm::interleave(
565  llvm::seq<size_t>(0, scales.size()), out,
566  [&](size_t index) {
567  printQuantParams(scales[index], zeroPoints[index], out);
568  },
569  ",");
570  out << "}>";
571 }
572 
573 /// Prints quantization parameters as a nested list of `scale`[:`zero_point`]
574 /// elements. The nesting corresponds to the `shape` dimensions.
575 ///
576 /// Elements are delimited by commas, and the inner dimensions are enclosed in
577 /// braces. `zero_point` is only printed if it is non-zero. For example:
578 ///
579 /// printDenseQuantizationParameters(scales=[1.0, 2.0, 3.0, 4.0],
580 /// zeroPoints=[0, 0, 1, 9],
581 /// shape=[2, 2])
582 ///
583 /// would print:
584 ///
585 /// {{1.0, 2.0}, {3.0:1, 4.0:9}}
587  ArrayRef<APInt> zeroPoints,
588  ArrayRef<int64_t> shape,
589  DialectAsmPrinter &out) {
590  int64_t rank = shape.size();
591  SmallVector<unsigned, 4> counter(rank, 0);
592  unsigned openBrackets = 0;
593 
594  auto incrementCounterAndDelimit = [&]() {
595  ++counter[rank - 1];
596  for (unsigned i = rank - 1; i > 0; --i) {
597  if (counter[i] >= shape[i]) {
598  counter[i] = 0;
599  ++counter[i - 1];
600  --openBrackets;
601  out << '}';
602  }
603  }
604  };
605 
606  for (unsigned idx = 0, e = scales.size(); idx < e; ++idx) {
607  if (idx != 0)
608  out << ", ";
609  while (openBrackets++ < rank)
610  out << '{';
611  openBrackets = rank;
612  out << scales[idx];
613  if (zeroPoints[idx] != 0) {
614  out << ":" << zeroPoints[idx];
615  }
616  incrementCounterAndDelimit();
617  }
618  while (openBrackets-- > 0)
619  out << '}';
620 }
621 
622 /// Helper that prints a UniformQuantizedSubChannelType.
623 static void
625  DialectAsmPrinter &out) {
626  out << "uniform<";
627  printStorageType(type, out);
628  out << ":" << type.getExpressedType() << ":";
630  out << ", ";
631 
632  auto scalesItr = type.getScales().getValues<APFloat>();
633  auto zeroPointsItr = type.getZeroPoints().getValues<APInt>();
634  SmallVector<APFloat> scales(scalesItr.begin(), scalesItr.end());
635  SmallVector<APInt> zeroPoints(zeroPointsItr.begin(), zeroPointsItr.end());
636  printDenseQuantizationParameters(scales, zeroPoints,
637  type.getScales().getType().getShape(), out);
638  out << ">";
639 }
640 
641 /// Helper that prints a CalibratedQuantizedType.
643  DialectAsmPrinter &out) {
644  out << "calibrated<" << type.getExpressedType();
645  out << "<" << type.getMin() << ":" << type.getMax() << ">";
646  out << ">";
647 }
648 
649 /// Print a type registered to this dialect.
650 void QuantDialect::printType(Type type, DialectAsmPrinter &os) const {
651  if (auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
652  printAnyQuantizedType(anyType, os);
653  else if (auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))
654  printUniformQuantizedType(uniformType, os);
655  else if (auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type))
656  printUniformQuantizedPerAxisType(perAxisType, os);
657  else if (auto perAxisType =
658  llvm::dyn_cast<UniformQuantizedSubChannelType>(type))
659  printUniformQuantizedSubChannelType(perAxisType, os);
660  else if (auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type))
661  printCalibratedQuantizedType(calibratedType, os);
662  else
663  llvm_unreachable("Unhandled quantized type");
664 }
static void printBlockSizeInfo(ArrayRef< std::pair< int32_t, int64_t >> blockSizeInfo, DialectAsmPrinter &out)
Definition: TypeParser.cpp:518
static ParseResult parseQuantParams(DialectAsmParser &parser, Type expressedType, double &scale, int64_t &zeroPoint)
Parses a quantization parameter, which is either a scale value (float) or a scale-zero point pair (fl...
Definition: TypeParser.cpp:183
void printDenseQuantizationParameters(ArrayRef< APFloat > scales, ArrayRef< APInt > zeroPoints, ArrayRef< int64_t > shape, DialectAsmPrinter &out)
Prints quantization parameters as a nested list of scale[:zero_point] elements.
Definition: TypeParser.cpp:586
static void printAnyQuantizedType(AnyQuantizedType type, DialectAsmPrinter &out)
Helper that prints a AnyQuantizedType.
Definition: TypeParser.cpp:529
LogicalResult isScaleInExpressedTypeRange(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double scale)
Checks if the given scale value is within the valid range of the expressed type.
Definition: TypeParser.cpp:166
static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, double &min, double &max)
Definition: TypeParser.cpp:94
static void printUniformQuantizedSubChannelType(UniformQuantizedSubChannelType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedSubChannelType.
Definition: TypeParser.cpp:624
static Type parseUniformType(DialectAsmParser &parser)
Parses a UniformQuantizedType.
Definition: TypeParser.cpp:329
static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned)
Definition: TypeParser.cpp:24
static Type parseAnyType(DialectAsmParser &parser)
Parses an AnyQuantizedType.
Definition: TypeParser.cpp:120
static void printStorageType(QuantizedType type, DialectAsmPrinter &out)
Definition: TypeParser.cpp:492
static ParseResult parseBlockSizeInfoUntilRBrace(DialectAsmParser &parser, SmallVectorImpl< int32_t > &quantizedDimensions, SmallVectorImpl< int64_t > &blockSizes)
Parses block size information for sub-channel quantization, assuming the leading '{' has already been...
Definition: TypeParser.cpp:212
static void printQuantParams(double scale, int64_t zeroPoint, DialectAsmPrinter &out)
Definition: TypeParser.cpp:509
static Type parseCalibratedType(DialectAsmParser &parser)
Parses an CalibratedQuantizedType.
Definition: TypeParser.cpp:451
static ParseResult parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType, SmallVectorImpl< double > &scales, SmallVectorImpl< int64_t > &zeroPoints, SmallVectorImpl< int64_t > &dims)
Parses a bracketed list of quantization parameters, returning the dimensions of the parsed sub-tensor...
Definition: TypeParser.cpp:251
static ParseResult parseStorageRange(DialectAsmParser &parser, IntegerType storageType, bool isSigned, int64_t &storageTypeMin, int64_t &storageTypeMax)
Definition: TypeParser.cpp:63
static void printCalibratedQuantizedType(CalibratedQuantizedType type, DialectAsmPrinter &out)
Helper that prints a CalibratedQuantizedType.
Definition: TypeParser.cpp:642
static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedPerAxisType.
Definition: TypeParser.cpp:552
static void printUniformQuantizedType(UniformQuantizedType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedType.
Definition: TypeParser.cpp:540
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseLBrace()=0
Parse a { token.
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual OptionalParseResult parseOptionalType(Type &result)=0
Parse an optional type.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseFloat(double &result)=0
Parse a floating point value from the stream.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition: QuantTypes.h:201
A quantized type that infers its range from given min/max values.
Definition: QuantTypes.h:522
Base class for all quantized types known to this dialect.
Definition: QuantTypes.h:50
static constexpr unsigned MaxStorageBits
The maximum number of bits supported for storage types.
Definition: QuantTypes.h:56
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
Definition: QuantTypes.cpp:92
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
Definition: QuantTypes.h:103
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
Definition: QuantTypes.cpp:88
static int64_t getDefaultMaximumForInteger(bool isSigned, unsigned integralWidth)
Gets the maximum possible stored by a storageType.
Definition: QuantTypes.h:78
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
Definition: QuantTypes.cpp:103
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Definition: QuantTypes.cpp:84
static int64_t getDefaultMinimumForInteger(bool isSigned, unsigned integralWidth)
Gets the minimum possible stored by a storageType.
Definition: QuantTypes.h:68
Represents per-axis (also known as per-channel quantization).
Definition: QuantTypes.h:322
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to.
Definition: QuantTypes.cpp:409
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:405
ArrayRef< double > getScales() const
Gets the quantization scales.
Definition: QuantTypes.cpp:401
Represents sub-channel (also known as blockwise quantization).
Definition: QuantTypes.h:407
DenseElementsAttr getZeroPoints() const
Gets the quantization zero-points.
Definition: QuantTypes.cpp:504
const SmallVector< std::pair< int32_t, int64_t > > getBlockSizeInfo() const
Gets the block size information.
Definition: QuantTypes.cpp:518
DenseElementsAttr getScales() const
Gets the quantization scales.
Definition: QuantTypes.cpp:500
Represents a family of uniform, quantized types.
Definition: QuantTypes.h:262
double getScale() const
Gets the scale term.
Definition: QuantTypes.cpp:332
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:334
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.