MLIR  22.0.0git
TypeParser.cpp
Go to the documentation of this file.
1 //===- TypeParser.cpp - Quantization Type Parser ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/Types.h"
14 #include "llvm/ADT/APFloat.h"
15 
16 using namespace mlir;
17 using namespace quant;
18 
19 static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
20  auto typeLoc = parser.getCurrentLocation();
21  IntegerType type;
22 
23  // Parse storage type (alpha_ident, integer_literal).
24  StringRef identifier;
25  unsigned storageTypeWidth = 0;
26  OptionalParseResult result = parser.parseOptionalType(type);
27  if (result.has_value()) {
28  if (!succeeded(*result))
29  return nullptr;
30  isSigned = !type.isUnsigned();
31  storageTypeWidth = type.getWidth();
32  } else if (succeeded(parser.parseKeyword(&identifier))) {
33  // Otherwise, this must be an unsigned integer (`u` integer-literal).
34  if (!identifier.consume_front("u")) {
35  parser.emitError(typeLoc, "illegal storage type prefix");
36  return nullptr;
37  }
38  if (identifier.getAsInteger(10, storageTypeWidth)) {
39  parser.emitError(typeLoc, "expected storage type width");
40  return nullptr;
41  }
42  isSigned = false;
43  type = parser.getBuilder().getIntegerType(storageTypeWidth);
44  } else {
45  return nullptr;
46  }
47 
48  if (storageTypeWidth == 0 ||
49  storageTypeWidth > QuantizedType::MaxStorageBits) {
50  parser.emitError(typeLoc, "illegal storage type size: ")
51  << storageTypeWidth;
52  return nullptr;
53  }
54 
55  return type;
56 }
57 
58 static ParseResult parseStorageRange(DialectAsmParser &parser,
59  IntegerType storageType, bool isSigned,
60  int64_t &storageTypeMin,
61  int64_t &storageTypeMax) {
62  int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
63  isSigned, storageType.getWidth());
64  int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
65  isSigned, storageType.getWidth());
66  if (failed(parser.parseOptionalLess())) {
67  storageTypeMin = defaultIntegerMin;
68  storageTypeMax = defaultIntegerMax;
69  return success();
70  }
71 
72  // Explicit storage min and storage max.
73  SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
74  if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
75  parser.getCurrentLocation(&maxLoc) ||
76  parser.parseInteger(storageTypeMax) || parser.parseGreater())
77  return failure();
78  if (storageTypeMin < defaultIntegerMin) {
79  return parser.emitError(minLoc, "illegal storage type minimum: ")
80  << storageTypeMin;
81  }
82  if (storageTypeMax > defaultIntegerMax) {
83  return parser.emitError(maxLoc, "illegal storage type maximum: ")
84  << storageTypeMax;
85  }
86  return success();
87 }
88 
90  double &min, double &max) {
91  auto typeLoc = parser.getCurrentLocation();
92  FloatType type;
93 
94  if (failed(parser.parseType(type))) {
95  parser.emitError(typeLoc, "expecting float expressed type");
96  return nullptr;
97  }
98 
99  // Calibrated min and max values.
100  if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() ||
101  parser.parseFloat(max) || parser.parseGreater()) {
102  parser.emitError(typeLoc, "calibrated values must be present");
103  return nullptr;
104  }
105  return type;
106 }
107 
108 /// Parses an AnyQuantizedType.
109 ///
110 /// any ::= `any<` storage-spec (expressed-type-spec)?`>`
111 /// storage-spec ::= storage-type (`<` storage-range `>`)?
112 /// storage-range ::= integer-literal `:` integer-literal
113 /// storage-type ::= (`i` | `u`) integer-literal
114 /// expressed-type-spec ::= `:` `f` integer-literal
116  IntegerType storageType;
117  FloatType expressedType;
118  unsigned typeFlags = 0;
119  int64_t storageTypeMin;
120  int64_t storageTypeMax;
121 
122  // Type specification.
123  if (parser.parseLess())
124  return nullptr;
125 
126  // Storage type.
127  bool isSigned = false;
128  storageType = parseStorageType(parser, isSigned);
129  if (!storageType) {
130  return nullptr;
131  }
132  if (isSigned) {
133  typeFlags |= QuantizationFlags::Signed;
134  }
135 
136  // Storage type range.
137  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
138  storageTypeMax)) {
139  return nullptr;
140  }
141 
142  // Optional expressed type.
143  if (succeeded(parser.parseOptionalColon())) {
144  if (parser.parseType(expressedType)) {
145  return nullptr;
146  }
147  }
148 
149  if (parser.parseGreater()) {
150  return nullptr;
151  }
152 
153  return parser.getChecked<AnyQuantizedType>(
154  typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
155 }
156 
157 /// Checks if the given scale value is within the valid range of the expressed
158 /// type. The `expressedType` argument is the floating-point type used for
159 /// expressing the quantized values, and `scale` is the double value to check.
160 LogicalResult
162  Type expressedType, double scale) {
163  auto floatType = cast<FloatType>(expressedType);
164  double minScale =
165  APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
166  double maxScale =
167  APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
168  if (scale < minScale || scale > maxScale)
169  return emitError() << "scale " << scale << " out of expressed type range ["
170  << minScale << ", " << maxScale << "]";
171  return success();
172 }
173 
174 /// Parses a quantization parameter, which is either a scale value (float) or a
175 /// scale-zero point pair (float:integer). `expressedType`, expressing the type
176 /// of scale values, is used to validate the scale. The parsed scale and zero
177 /// point (if any) are stored in `scale` and `zeroPoint`.
178 static ParseResult parseQuantParams(DialectAsmParser &parser,
179  Type expressedType, double &scale,
180  int64_t &zeroPoint) {
181 
182  if (parser.parseFloat(scale)) {
183  return failure();
184  }
185 
186  if (failed(isScaleInExpressedTypeRange(
187  [&]() { return parser.emitError(parser.getCurrentLocation()); },
188  expressedType, scale))) {
189  return failure();
190  }
191 
192  zeroPoint = 0;
193  if (failed(parser.parseOptionalColon())) {
194  return success();
195  }
196 
197  return parser.parseInteger(zeroPoint);
198 }
199 
200 /// Parses block size information for sub-channel quantization, assuming the
201 /// leading '{' has already been parsed. The block size information is provided
202 /// as a comma-separated list of "Axis:BlockSize" pairs, terminated by a '}'.
203 ///
204 /// The parsed axis indices are stored in `quantizedDimensions`, and the
205 /// corresponding block sizes are stored in `blockSizes`.
206 static ParseResult
208  SmallVectorImpl<int32_t> &quantizedDimensions,
209  SmallVectorImpl<int64_t> &blockSizes) {
210  // Empty block-sizes info.
211  if (succeeded(parser.parseOptionalRBrace())) {
212  return success();
213  }
214 
215  auto parseBlockSizeElements = [&]() -> ParseResult {
216  quantizedDimensions.resize(quantizedDimensions.size() + 1);
217  blockSizes.resize(blockSizes.size() + 1);
218  if (parser.parseInteger(quantizedDimensions.back()) ||
219  parser.parseColon() || parser.parseInteger(blockSizes.back()))
220  return failure();
221  return success();
222  };
223 
224  if (parser.parseCommaSeparatedList(parseBlockSizeElements) ||
225  parser.parseRBrace()) {
226  return failure();
227  }
228 
229  return success();
230 }
231 
232 /// Parses a bracketed list of quantization parameters, returning the dimensions
233 /// of the parsed sub-tensors in `dims`. The dimension of the list is prepended
234 /// to the dimensions of the sub-tensors. This function assumes that the initial
235 /// left brace has already been parsed. For example:
236 ///
237 /// parseQuantParamListUntilRBrace(1.0:1, 2.0:4, 3.0:4}) -> Success,
238 /// dims = [3], scales = [1.0, 2.0, 3.0], zeroPoints = [1, 4, 4]
239 ///
240 /// parseQuantParamListUntilRBrace({1.0, 2.0}, {3.0:1, 4.0:9}}) -> Success,
241 /// dims = [2, 2], scales = [1.0, 2.0, 3.0, 4.0], zeroPoints = [0, 0, 1,
242 /// 9]
243 ///
244 /// This function expects all sub-tensors to have the same rank.
245 static ParseResult
247  SmallVectorImpl<double> &scales,
248  SmallVectorImpl<int64_t> &zeroPoints,
249  SmallVectorImpl<int64_t> &dims) {
250  auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
251  const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
252  if (prevDims == newDims)
253  return success();
254  return parser.emitError(parser.getCurrentLocation())
255  << "tensor literal is invalid; ranks are not consistent "
256  "between elements";
257  };
258 
259  bool first = true;
260  SmallVector<int64_t, 4> newDims;
261  unsigned size = 0;
262 
263  auto parseOneElement = [&]() -> ParseResult {
264  SmallVector<int64_t, 4> thisDims;
265  if (succeeded(parser.parseOptionalLBrace())) {
266  if (parseQuantParamListUntilRBrace(parser, expressedType, scales,
267  zeroPoints, thisDims))
268  return failure();
269  } else {
270  zeroPoints.resize(zeroPoints.size() + 1);
271  scales.resize(scales.size() + 1);
272  if (parseQuantParams(parser, expressedType, scales.back(),
273  zeroPoints.back())) {
274  return failure();
275  }
276  }
277  ++size;
278  if (!first)
279  return checkDims(newDims, thisDims);
280  newDims = thisDims;
281  first = false;
282  return success();
283  };
284 
285  if (parser.parseCommaSeparatedList(parseOneElement) || parser.parseRBrace()) {
286  return failure();
287  }
288 
289  // Return the sublists' dimensions with 'size' prepended.
290  dims.clear();
291  dims.push_back(size);
292  dims.append(newDims.begin(), newDims.end());
293 
294  return success();
295 }
296 
297 /// Parses a UniformQuantizedType.
298 ///
299 /// uniform_type ::= uniform_per_layer
300 /// | uniform_per_axis
301 /// | uniform_sub_channel
302 /// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
303 /// `,` scale-zero `>`
304 /// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
305 /// axis-spec `,` `{` scale-zero-list `}` `>`
306 /// uniform_sub_channel ::= `uniform<` storage-spec expressed-type-spec
307 /// block-size-info `,` scale-zero-tensor `>`
308 /// storage-spec ::= storage-type (`<` storage-range `>`)?
309 /// storage-range ::= integer-literal `:` integer-literal
310 /// storage-type ::= (`i` | `u`) integer-literal
311 /// expressed-type-spec ::= `:` `f` integer-literal
312 /// axis-spec ::= `:` integer-literal
313 /// scale-zero ::= scale (`:` zero-point)?
314 /// scale ::= float-literal
315 /// zero-point ::= integer-literal
316 /// scale-zero-list ::= scale-zero (`,` scale-zero)*
317 /// block-size-info ::= `{` `}` | `{` axis-block `:` (`,` axis-block)* `}`
318 /// axis-block ::= axis-spec `:` block-size-spec
319 /// block-size-spec ::= integer-literal
320 /// scale-zero-tensor ::= scale-zero-dense-exp | scale-zero-list
321 /// scale-zero-dense-exp ::= `{`
322 /// scale-zero-tensor (`,` scale-zero-tensor)*
323 /// `}`
325  IntegerType storageType;
326  FloatType expressedType;
327  unsigned typeFlags = 0;
328  int64_t storageTypeMin;
329  int64_t storageTypeMax;
330  bool isPerAxis = false;
331  bool isSubChannel = false;
332  SmallVector<int32_t, 1> quantizedDimensions;
333  SmallVector<int64_t, 1> blockSizes;
334  SmallVector<double, 1> scales;
335  SmallVector<int64_t, 1> zeroPoints;
336 
337  // Type specification.
338  if (parser.parseLess()) {
339  return nullptr;
340  }
341 
342  // Storage type.
343  bool isSigned = false;
344  storageType = parseStorageType(parser, isSigned);
345  if (!storageType) {
346  return nullptr;
347  }
348  if (isSigned) {
349  typeFlags |= QuantizationFlags::Signed;
350  }
351 
352  // Storage type range.
353  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
354  storageTypeMax)) {
355  return nullptr;
356  }
357 
358  // Expressed type.
359  if (parser.parseColon() || parser.parseType(expressedType)) {
360  return nullptr;
361  }
362 
363  // Optionally parse quantized dimension for per-axis or sub-channel
364  // quantization.
365  if (succeeded(parser.parseOptionalColon())) {
366  if (succeeded(parser.parseOptionalLBrace())) {
367  isSubChannel = true;
368  if (parseBlockSizeInfoUntilRBrace(parser, quantizedDimensions,
369  blockSizes)) {
370  return nullptr;
371  }
372  } else {
373  isPerAxis = true;
374  quantizedDimensions.resize(1);
375  if (parser.parseInteger(quantizedDimensions.back())) {
376  return nullptr;
377  }
378  }
379  }
380 
381  // Comma leading into range_spec.
382  if (parser.parseComma()) {
383  return nullptr;
384  }
385 
386  // Quantization parameter (scales/zeroPoints) specification.
387  bool isPerTensor = !isPerAxis && !isSubChannel;
389  if (isPerTensor) {
390  zeroPoints.resize(zeroPoints.size() + 1);
391  scales.resize(scales.size() + 1);
392  if (parseQuantParams(parser, expressedType, scales.back(),
393  zeroPoints.back())) {
394  return nullptr;
395  }
396 
397  } else {
398  if (parser.parseLBrace() ||
399  parseQuantParamListUntilRBrace(parser, expressedType, scales,
400  zeroPoints, dims)) {
401  return nullptr;
402  }
403  }
404 
405  if (parser.parseGreater()) {
406  return nullptr;
407  }
408 
409  if (isPerAxis) {
411  typeFlags, storageType, expressedType, scales, zeroPoints,
412  quantizedDimensions[0], storageTypeMin, storageTypeMax);
413  } else if (isSubChannel) {
414  SmallVector<APFloat> apFloatScales =
415  llvm::to_vector(llvm::map_range(scales, [&](double scale) -> APFloat {
416  APFloat apFloatScale(scale);
417  bool unused;
418  apFloatScale.convert(expressedType.getFloatSemantics(),
419  APFloat::rmNearestTiesToEven, &unused);
420  return apFloatScale;
421  }));
422  SmallVector<APInt> apIntZeroPoints = llvm::to_vector(
423  llvm::map_range(zeroPoints, [&](int64_t zeroPoint) -> APInt {
424  return APInt(storageType.getIntOrFloatBitWidth(), zeroPoint);
425  }));
426  auto scalesRef = mlir::DenseElementsAttr::get(
427  RankedTensorType::get(dims, expressedType), apFloatScales);
428  auto zeroPointsRef = mlir::DenseElementsAttr::get(
429  RankedTensorType::get(dims, storageType), apIntZeroPoints);
431  typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
432  quantizedDimensions, blockSizes, storageTypeMin, storageTypeMax);
433  }
434 
435  return parser.getChecked<UniformQuantizedType>(
436  typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
437  storageTypeMin, storageTypeMax);
438 }
439 
440 /// Parses an CalibratedQuantizedType.
441 ///
442 /// calibrated ::= `calibrated<` expressed-spec `>`
443 /// expressed-spec ::= expressed-type `<` calibrated-range `>`
444 /// expressed-type ::= `f` integer-literal
445 /// calibrated-range ::= float-literal `:` float-literal
447  FloatType expressedType;
448  double min;
449  double max;
450 
451  // Type specification.
452  if (parser.parseLess())
453  return nullptr;
454 
455  // Expressed type.
456  expressedType = parseExpressedTypeAndRange(parser, min, max);
457  if (!expressedType) {
458  return nullptr;
459  }
460 
461  if (parser.parseGreater()) {
462  return nullptr;
463  }
464 
465  return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
466 }
467 
468 /// Parse a type registered to this dialect.
470  // All types start with an identifier that we switch on.
471  StringRef typeNameSpelling;
472  if (failed(parser.parseKeyword(&typeNameSpelling)))
473  return nullptr;
474 
475  if (typeNameSpelling == "uniform")
476  return parseUniformType(parser);
477  if (typeNameSpelling == "any")
478  return parseAnyType(parser);
479  if (typeNameSpelling == "calibrated")
480  return parseCalibratedType(parser);
481 
482  parser.emitError(parser.getNameLoc(),
483  "unknown quantized type " + typeNameSpelling);
484  return nullptr;
485 }
486 
488  // storage type
489  unsigned storageWidth = type.getStorageTypeIntegralWidth();
490  bool isSigned = type.isSigned();
491  if (isSigned) {
492  out << "i" << storageWidth;
493  } else {
494  out << "u" << storageWidth;
495  }
496 
497  // storageTypeMin and storageTypeMax if not default.
498  if (type.hasStorageTypeBounds()) {
499  out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
500  << ">";
501  }
502 }
503 
504 static void printQuantParams(double scale, int64_t zeroPoint,
505  DialectAsmPrinter &out) {
506  out << scale;
507  if (zeroPoint != 0) {
508  out << ":" << zeroPoint;
509  }
510 }
511 
512 static void
513 printBlockSizeInfo(ArrayRef<std::pair<int32_t, int64_t>> blockSizeInfo,
514  DialectAsmPrinter &out) {
515  out << "{";
516  llvm::interleaveComma(
517  llvm::seq<size_t>(0, blockSizeInfo.size()), out, [&](size_t index) {
518  out << blockSizeInfo[index].first << ":" << blockSizeInfo[index].second;
519  });
520  out << "}";
521 }
522 
523 /// Helper that prints a AnyQuantizedType.
525  DialectAsmPrinter &out) {
526  out << "any<";
527  printStorageType(type, out);
528  if (Type expressedType = type.getExpressedType()) {
529  out << ":" << expressedType;
530  }
531  out << ">";
532 }
533 
534 /// Helper that prints a UniformQuantizedType.
536  DialectAsmPrinter &out) {
537  out << "uniform<";
538  printStorageType(type, out);
539  out << ":" << type.getExpressedType() << ", ";
540 
541  // scheme specific parameters
542  printQuantParams(type.getScale(), type.getZeroPoint(), out);
543  out << ">";
544 }
545 
546 /// Helper that prints a UniformQuantizedPerAxisType.
548  DialectAsmPrinter &out) {
549  out << "uniform<";
550  printStorageType(type, out);
551  out << ":" << type.getExpressedType() << ":";
552  out << type.getQuantizedDimension();
553  out << ", ";
554 
555  // scheme specific parameters
556  ArrayRef<double> scales = type.getScales();
557  ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
558  out << "{";
559  llvm::interleave(
560  llvm::seq<size_t>(0, scales.size()), out,
561  [&](size_t index) {
562  printQuantParams(scales[index], zeroPoints[index], out);
563  },
564  ",");
565  out << "}>";
566 }
567 
568 /// Prints quantization parameters as a nested list of `scale`[:`zero_point`]
569 /// elements. The nesting corresponds to the `shape` dimensions.
570 ///
571 /// Elements are delimited by commas, and the inner dimensions are enclosed in
572 /// braces. `zero_point` is only printed if it is non-zero. For example:
573 ///
574 /// printDenseQuantizationParameters(scales=[1.0, 2.0, 3.0, 4.0],
575 /// zeroPoints=[0, 0, 1, 9],
576 /// shape=[2, 2])
577 ///
578 /// would print:
579 ///
580 /// {{1.0, 2.0}, {3.0:1, 4.0:9}}
582  ArrayRef<APInt> zeroPoints,
583  ArrayRef<int64_t> shape,
584  DialectAsmPrinter &out) {
585  int64_t rank = shape.size();
586  SmallVector<unsigned, 4> counter(rank, 0);
587  unsigned openBrackets = 0;
588 
589  auto incrementCounterAndDelimit = [&]() {
590  ++counter[rank - 1];
591  for (unsigned i = rank - 1; i > 0; --i) {
592  if (counter[i] >= shape[i]) {
593  counter[i] = 0;
594  ++counter[i - 1];
595  --openBrackets;
596  out << '}';
597  }
598  }
599  };
600 
601  for (unsigned idx = 0, e = scales.size(); idx < e; ++idx) {
602  if (idx != 0)
603  out << ", ";
604  while (openBrackets++ < rank)
605  out << '{';
606  openBrackets = rank;
607  out << scales[idx];
608  if (zeroPoints[idx] != 0) {
609  out << ":" << zeroPoints[idx];
610  }
611  incrementCounterAndDelimit();
612  }
613  while (openBrackets-- > 0)
614  out << '}';
615 }
616 
617 /// Helper that prints a UniformQuantizedSubChannelType.
618 static void
620  DialectAsmPrinter &out) {
621  out << "uniform<";
622  printStorageType(type, out);
623  out << ":" << type.getExpressedType() << ":";
625  out << ", ";
626 
627  auto scalesItr = type.getScales().getValues<APFloat>();
628  auto zeroPointsItr = type.getZeroPoints().getValues<APInt>();
629  SmallVector<APFloat> scales(scalesItr.begin(), scalesItr.end());
630  SmallVector<APInt> zeroPoints(zeroPointsItr.begin(), zeroPointsItr.end());
631  printDenseQuantizationParameters(scales, zeroPoints,
632  type.getScales().getType().getShape(), out);
633  out << ">";
634 }
635 
636 /// Helper that prints a CalibratedQuantizedType.
638  DialectAsmPrinter &out) {
639  out << "calibrated<" << type.getExpressedType();
640  out << "<" << type.getMin() << ":" << type.getMax() << ">";
641  out << ">";
642 }
643 
644 /// Print a type registered to this dialect.
645 void QuantDialect::printType(Type type, DialectAsmPrinter &os) const {
646  if (auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
647  printAnyQuantizedType(anyType, os);
648  else if (auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))
649  printUniformQuantizedType(uniformType, os);
650  else if (auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type))
651  printUniformQuantizedPerAxisType(perAxisType, os);
652  else if (auto perAxisType =
653  llvm::dyn_cast<UniformQuantizedSubChannelType>(type))
654  printUniformQuantizedSubChannelType(perAxisType, os);
655  else if (auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type))
656  printCalibratedQuantizedType(calibratedType, os);
657  else
658  llvm_unreachable("Unhandled quantized type");
659 }
static void printBlockSizeInfo(ArrayRef< std::pair< int32_t, int64_t >> blockSizeInfo, DialectAsmPrinter &out)
Definition: TypeParser.cpp:513
static ParseResult parseQuantParams(DialectAsmParser &parser, Type expressedType, double &scale, int64_t &zeroPoint)
Parses a quantization parameter, which is either a scale value (float) or a scale-zero point pair (fl...
Definition: TypeParser.cpp:178
void printDenseQuantizationParameters(ArrayRef< APFloat > scales, ArrayRef< APInt > zeroPoints, ArrayRef< int64_t > shape, DialectAsmPrinter &out)
Prints quantization parameters as a nested list of scale[:zero_point] elements.
Definition: TypeParser.cpp:581
static void printAnyQuantizedType(AnyQuantizedType type, DialectAsmPrinter &out)
Helper that prints an AnyQuantizedType.
Definition: TypeParser.cpp:524
LogicalResult isScaleInExpressedTypeRange(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double scale)
Checks if the given scale value is within the valid range of the expressed type.
Definition: TypeParser.cpp:161
static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, double &min, double &max)
Definition: TypeParser.cpp:89
static void printUniformQuantizedSubChannelType(UniformQuantizedSubChannelType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedSubChannelType.
Definition: TypeParser.cpp:619
static Type parseUniformType(DialectAsmParser &parser)
Parses a UniformQuantizedType.
Definition: TypeParser.cpp:324
static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned)
Definition: TypeParser.cpp:19
static Type parseAnyType(DialectAsmParser &parser)
Parses an AnyQuantizedType.
Definition: TypeParser.cpp:115
static void printStorageType(QuantizedType type, DialectAsmPrinter &out)
Definition: TypeParser.cpp:487
static ParseResult parseBlockSizeInfoUntilRBrace(DialectAsmParser &parser, SmallVectorImpl< int32_t > &quantizedDimensions, SmallVectorImpl< int64_t > &blockSizes)
Parses block size information for sub-channel quantization, assuming the leading '{' has already been...
Definition: TypeParser.cpp:207
static void printQuantParams(double scale, int64_t zeroPoint, DialectAsmPrinter &out)
Definition: TypeParser.cpp:504
static Type parseCalibratedType(DialectAsmParser &parser)
Parses a CalibratedQuantizedType.
Definition: TypeParser.cpp:446
static ParseResult parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType, SmallVectorImpl< double > &scales, SmallVectorImpl< int64_t > &zeroPoints, SmallVectorImpl< int64_t > &dims)
Parses a bracketed list of quantization parameters, returning the dimensions of the parsed sub-tensor...
Definition: TypeParser.cpp:246
static ParseResult parseStorageRange(DialectAsmParser &parser, IntegerType storageType, bool isSigned, int64_t &storageTypeMin, int64_t &storageTypeMax)
Definition: TypeParser.cpp:58
static void printCalibratedQuantizedType(CalibratedQuantizedType type, DialectAsmPrinter &out)
Helper that prints a CalibratedQuantizedType.
Definition: TypeParser.cpp:637
static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedPerAxisType.
Definition: TypeParser.cpp:547
static void printUniformQuantizedType(UniformQuantizedType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedType.
Definition: TypeParser.cpp:535
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseLBrace()=0
Parse a { token.
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual OptionalParseResult parseOptionalType(Type &result)=0
Parse an optional type.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseFloat(double &result)=0
Parse a floating point value from the stream.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition: QuantTypes.h:201
A quantized type that infers its range from given min/max values.
Definition: QuantTypes.h:522
Base class for all quantized types known to this dialect.
Definition: QuantTypes.h:50
static constexpr unsigned MaxStorageBits
The maximum number of bits supported for storage types.
Definition: QuantTypes.h:56
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
Definition: QuantTypes.cpp:89
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
Definition: QuantTypes.h:103
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
Definition: QuantTypes.cpp:85
static int64_t getDefaultMaximumForInteger(bool isSigned, unsigned integralWidth)
Gets the maximum possible stored by a storageType.
Definition: QuantTypes.h:78
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
Definition: QuantTypes.cpp:100
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Definition: QuantTypes.cpp:81
static int64_t getDefaultMinimumForInteger(bool isSigned, unsigned integralWidth)
Gets the minimum possible stored by a storageType.
Definition: QuantTypes.h:68
Represents per-axis (also known as per-channel quantization).
Definition: QuantTypes.h:322
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to.
Definition: QuantTypes.cpp:406
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:402
ArrayRef< double > getScales() const
Gets the quantization scales.
Definition: QuantTypes.cpp:398
Represents sub-channel (also known as blockwise quantization).
Definition: QuantTypes.h:407
DenseElementsAttr getZeroPoints() const
Gets the quantization zero-points.
Definition: QuantTypes.cpp:501
const SmallVector< std::pair< int32_t, int64_t > > getBlockSizeInfo() const
Gets the block size information.
Definition: QuantTypes.cpp:515
DenseElementsAttr getScales() const
Gets the quantization scales.
Definition: QuantTypes.cpp:497
Represents a family of uniform, quantized types.
Definition: QuantTypes.h:262
double getScale() const
Gets the scale term.
Definition: QuantTypes.cpp:329
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:331
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.