MLIR  14.0.0git
TypeParser.cpp
Go to the documentation of this file.
4 //===- TypeParser.cpp - Quantization Type Parser ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/Location.h"
14 #include "mlir/IR/Types.h"
15 #include "llvm/ADT/APFloat.h"
16 #include "llvm/ADT/StringSwitch.h"
17 #include "llvm/Support/Format.h"
18 #include "llvm/Support/MathExtras.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "llvm/Support/raw_ostream.h"
21 
22 using namespace mlir;
23 using namespace quant;
24 
25 static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
26  auto typeLoc = parser.getCurrentLocation();
27  IntegerType type;
28 
29  // Parse storage type (alpha_ident, integer_literal).
30  StringRef identifier;
31  unsigned storageTypeWidth = 0;
32  OptionalParseResult result = parser.parseOptionalType(type);
33  if (result.hasValue()) {
34  if (!succeeded(*result)) {
35  parser.parseType(type);
36  return nullptr;
37  }
38  isSigned = !type.isUnsigned();
39  storageTypeWidth = type.getWidth();
40  } else if (succeeded(parser.parseKeyword(&identifier))) {
41  // Otherwise, this must be an unsigned integer (`u` integer-literal).
42  if (!identifier.consume_front("u")) {
43  parser.emitError(typeLoc, "illegal storage type prefix");
44  return nullptr;
45  }
46  if (identifier.getAsInteger(10, storageTypeWidth)) {
47  parser.emitError(typeLoc, "expected storage type width");
48  return nullptr;
49  }
50  isSigned = false;
51  type = parser.getBuilder().getIntegerType(storageTypeWidth);
52  } else {
53  return nullptr;
54  }
55 
56  if (storageTypeWidth == 0 ||
57  storageTypeWidth > QuantizedType::MaxStorageBits) {
58  parser.emitError(typeLoc, "illegal storage type size: ")
59  << storageTypeWidth;
60  return nullptr;
61  }
62 
63  return type;
64 }
65 
67  IntegerType storageType, bool isSigned,
68  int64_t &storageTypeMin,
69  int64_t &storageTypeMax) {
70  int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
71  isSigned, storageType.getWidth());
72  int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
73  isSigned, storageType.getWidth());
74  if (failed(parser.parseOptionalLess())) {
75  storageTypeMin = defaultIntegerMin;
76  storageTypeMax = defaultIntegerMax;
77  return success();
78  }
79 
80  // Explicit storage min and storage max.
81  llvm::SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
82  if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
83  parser.getCurrentLocation(&maxLoc) ||
84  parser.parseInteger(storageTypeMax) || parser.parseGreater())
85  return failure();
86  if (storageTypeMin < defaultIntegerMin) {
87  return parser.emitError(minLoc, "illegal storage type minimum: ")
88  << storageTypeMin;
89  }
90  if (storageTypeMax > defaultIntegerMax) {
91  return parser.emitError(maxLoc, "illegal storage type maximum: ")
92  << storageTypeMax;
93  }
94  return success();
95 }
96 
98  double &min, double &max) {
99  auto typeLoc = parser.getCurrentLocation();
100  FloatType type;
101 
102  if (failed(parser.parseType(type))) {
103  parser.emitError(typeLoc, "expecting float expressed type");
104  return nullptr;
105  }
106 
107  // Calibrated min and max values.
108  if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() ||
109  parser.parseFloat(max) || parser.parseGreater()) {
110  parser.emitError(typeLoc, "calibrated values must be present");
111  return nullptr;
112  }
113  return type;
114 }
115 
116 /// Parses an AnyQuantizedType.
117 ///
118 /// any ::= `any<` storage-spec (expressed-type-spec)?`>`
119 /// storage-spec ::= storage-type (`<` storage-range `>`)?
120 /// storage-range ::= integer-literal `:` integer-literal
121 /// storage-type ::= (`i` | `u`) integer-literal
122 /// expressed-type-spec ::= `:` `f` integer-literal
124  IntegerType storageType;
125  FloatType expressedType;
126  unsigned typeFlags = 0;
127  int64_t storageTypeMin;
128  int64_t storageTypeMax;
129 
130  // Type specification.
131  if (parser.parseLess())
132  return nullptr;
133 
134  // Storage type.
135  bool isSigned = false;
136  storageType = parseStorageType(parser, isSigned);
137  if (!storageType) {
138  return nullptr;
139  }
140  if (isSigned) {
141  typeFlags |= QuantizationFlags::Signed;
142  }
143 
144  // Storage type range.
145  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
146  storageTypeMax)) {
147  return nullptr;
148  }
149 
150  // Optional expressed type.
151  if (succeeded(parser.parseOptionalColon())) {
152  if (parser.parseType(expressedType)) {
153  return nullptr;
154  }
155  }
156 
157  if (parser.parseGreater()) {
158  return nullptr;
159  }
160 
161  return parser.getChecked<AnyQuantizedType>(
162  typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
163 }
164 
166  int64_t &zeroPoint) {
167  // scale[:zeroPoint]?
168  // scale.
169  if (parser.parseFloat(scale))
170  return failure();
171 
172  // zero point.
173  zeroPoint = 0;
174  if (failed(parser.parseOptionalColon())) {
175  // Default zero point.
176  return success();
177  }
178 
179  return parser.parseInteger(zeroPoint);
180 }
181 
182 /// Parses a UniformQuantizedType.
183 ///
184 /// uniform_type ::= uniform_per_layer
185 /// | uniform_per_axis
186 /// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
187 /// `,` scale-zero `>`
188 /// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
189 /// axis-spec `,` scale-zero-list `>`
190 /// storage-spec ::= storage-type (`<` storage-range `>`)?
191 /// storage-range ::= integer-literal `:` integer-literal
192 /// storage-type ::= (`i` | `u`) integer-literal
193 /// expressed-type-spec ::= `:` `f` integer-literal
194 /// axis-spec ::= `:` integer-literal
195 /// scale-zero ::= float-literal `:` integer-literal
196 /// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
198  IntegerType storageType;
199  FloatType expressedType;
200  unsigned typeFlags = 0;
201  int64_t storageTypeMin;
202  int64_t storageTypeMax;
203  bool isPerAxis = false;
204  int32_t quantizedDimension;
205  SmallVector<double, 1> scales;
206  SmallVector<int64_t, 1> zeroPoints;
207 
208  // Type specification.
209  if (parser.parseLess()) {
210  return nullptr;
211  }
212 
213  // Storage type.
214  bool isSigned = false;
215  storageType = parseStorageType(parser, isSigned);
216  if (!storageType) {
217  return nullptr;
218  }
219  if (isSigned) {
220  typeFlags |= QuantizationFlags::Signed;
221  }
222 
223  // Storage type range.
224  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
225  storageTypeMax)) {
226  return nullptr;
227  }
228 
229  // Expressed type.
230  if (parser.parseColon() || parser.parseType(expressedType)) {
231  return nullptr;
232  }
233 
234  // Optionally parse quantized dimension for per-axis quantization.
235  if (succeeded(parser.parseOptionalColon())) {
236  if (parser.parseInteger(quantizedDimension))
237  return nullptr;
238  isPerAxis = true;
239  }
240 
241  // Comma leading into range_spec.
242  if (parser.parseComma()) {
243  return nullptr;
244  }
245 
246  // Parameter specification.
247  // For per-axis, ranges are in a {} delimitted list.
248  if (isPerAxis) {
249  if (parser.parseLBrace()) {
250  return nullptr;
251  }
252  }
253 
254  // Parse scales/zeroPoints.
255  llvm::SMLoc scaleZPLoc = parser.getCurrentLocation();
256  do {
257  scales.resize(scales.size() + 1);
258  zeroPoints.resize(zeroPoints.size() + 1);
259  if (parseQuantParams(parser, scales.back(), zeroPoints.back())) {
260  return nullptr;
261  }
262  } while (isPerAxis && succeeded(parser.parseOptionalComma()));
263 
264  if (isPerAxis) {
265  if (parser.parseRBrace()) {
266  return nullptr;
267  }
268  }
269 
270  if (parser.parseGreater()) {
271  return nullptr;
272  }
273 
274  if (!isPerAxis && scales.size() > 1) {
275  return (parser.emitError(scaleZPLoc,
276  "multiple scales/zeroPoints provided, but "
277  "quantizedDimension wasn't specified"),
278  nullptr);
279  }
280 
281  if (isPerAxis) {
282  ArrayRef<double> scalesRef(scales.begin(), scales.end());
283  ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
285  typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
286  quantizedDimension, storageTypeMin, storageTypeMax);
287  }
288 
289  return parser.getChecked<UniformQuantizedType>(
290  typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
291  storageTypeMin, storageTypeMax);
292 }
293 
294 /// Parses an CalibratedQuantizedType.
295 ///
296 /// calibrated ::= `calibrated<` expressed-spec `>`
297 /// expressed-spec ::= expressed-type `<` calibrated-range `>`
298 /// expressed-type ::= `f` integer-literal
299 /// calibrated-range ::= float-literal `:` float-literal
301  FloatType expressedType;
302  double min;
303  double max;
304 
305  // Type specification.
306  if (parser.parseLess())
307  return nullptr;
308 
309  // Expressed type.
310  expressedType = parseExpressedTypeAndRange(parser, min, max);
311  if (!expressedType) {
312  return nullptr;
313  }
314 
315  if (parser.parseGreater()) {
316  return nullptr;
317  }
318 
319  return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
320 }
321 
322 /// Parse a type registered to this dialect.
/// Reads the leading keyword and dispatches to the matching helper:
/// `uniform` -> parseUniformType, `any` -> parseAnyType,
/// `calibrated` -> parseCalibratedType. Any other keyword is reported as
/// "unknown quantized type" at the type-name location and yields a null type.
/// NOTE(review): the signature line (listing line 323) was dropped from this
/// rendering; the body below expects a `DialectAsmParser &parser` in scope.
324  // All types start with an identifier that we switch on.
325  StringRef typeNameSpelling;
326  if (failed(parser.parseKeyword(&typeNameSpelling)))
327  return nullptr;
328 
329  if (typeNameSpelling == "uniform")
330  return parseUniformType(parser);
331  if (typeNameSpelling == "any")
332  return parseAnyType(parser);
333  if (typeNameSpelling == "calibrated")
334  return parseCalibratedType(parser);
335 
    // Unrecognized keyword: report it at the start of the type name.
336  parser.emitError(parser.getNameLoc(),
337  "unknown quantized type " + typeNameSpelling);
338  return nullptr;
339 }
340 
342  // storage type
343  unsigned storageWidth = type.getStorageTypeIntegralWidth();
344  bool isSigned = type.isSigned();
345  if (isSigned) {
346  out << "i" << storageWidth;
347  } else {
348  out << "u" << storageWidth;
349  }
350 
351  // storageTypeMin and storageTypeMax if not default.
352  int64_t defaultIntegerMin =
353  QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth);
354  int64_t defaultIntegerMax =
355  QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth);
356  if (defaultIntegerMin != type.getStorageTypeMin() ||
357  defaultIntegerMax != type.getStorageTypeMax()) {
358  out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
359  << ">";
360  }
361 }
362 
363 static void printQuantParams(double scale, int64_t zeroPoint,
364  DialectAsmPrinter &out) {
365  out << scale;
366  if (zeroPoint != 0) {
367  out << ":" << zeroPoint;
368  }
369 }
370 
371 /// Helper that prints a AnyQuantizedType.
373  DialectAsmPrinter &out) {
374  out << "any<";
375  printStorageType(type, out);
376  if (Type expressedType = type.getExpressedType()) {
377  out << ":" << expressedType;
378  }
379  out << ">";
380 }
381 
382 /// Helper that prints a UniformQuantizedType.
384  DialectAsmPrinter &out) {
385  out << "uniform<";
386  printStorageType(type, out);
387  out << ":" << type.getExpressedType() << ", ";
388 
389  // scheme specific parameters
390  printQuantParams(type.getScale(), type.getZeroPoint(), out);
391  out << ">";
392 }
393 
394 /// Helper that prints a UniformQuantizedPerAxisType.
396  DialectAsmPrinter &out) {
397  out << "uniform<";
398  printStorageType(type, out);
399  out << ":" << type.getExpressedType() << ":";
400  out << type.getQuantizedDimension();
401  out << ", ";
402 
403  // scheme specific parameters
404  ArrayRef<double> scales = type.getScales();
405  ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
406  out << "{";
407  llvm::interleave(
408  llvm::seq<size_t>(0, scales.size()), out,
409  [&](size_t index) {
410  printQuantParams(scales[index], zeroPoints[index], out);
411  },
412  ",");
413  out << "}>";
414 }
415 
416 /// Helper that prints a CalibratedQuantizedType.
418  DialectAsmPrinter &out) {
419  out << "calibrated<" << type.getExpressedType();
420  out << "<" << type.getMin() << ":" << type.getMax() << ">";
421  out << ">";
422 }
423 
424 /// Print a type registered to this dialect.
/// Dispatches on the concrete quantized type (Any, Uniform, UniformPerAxis,
/// Calibrated) to the matching print helper; any other type reaches
/// llvm_unreachable.
/// NOTE(review): the signature line (listing line 425) was dropped from this
/// rendering; the body below expects `Type type` and a printer named `os`.
426  if (auto anyType = type.dyn_cast<AnyQuantizedType>())
427  printAnyQuantizedType(anyType, os);
428  else if (auto uniformType = type.dyn_cast<UniformQuantizedType>())
429  printUniformQuantizedType(uniformType, os);
430  else if (auto perAxisType = type.dyn_cast<UniformQuantizedPerAxisType>())
431  printUniformQuantizedPerAxisType(perAxisType, os);
432  else if (auto calibratedType = type.dyn_cast<CalibratedQuantizedType>())
433  printCalibratedQuantizedType(calibratedType, os);
434  else
435  llvm_unreachable("Unhandled quantized type");
436 }
Include the generated interface declarations.
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Definition: QuantTypes.cpp:67
static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedPerAxisType.
Definition: TypeParser.cpp:395
static Value min(ImplicitLocOpBuilder &builder, Value a, Value b)
static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, double &min, double &max)
Definition: TypeParser.cpp:97
static int64_t getDefaultMinimumForInteger(bool isSigned, unsigned integralWidth)
Gets the minimum possible value that can be stored by a storageType.
Definition: QuantTypes.h:70
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
Represents a family of uniform, quantized types.
Definition: QuantTypes.h:256
virtual ParseResult parseLBrace()=0
Parse a { token.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
A quantized type that infers its range from given min/max values.
Definition: QuantTypes.h:383
static void printCalibratedQuantizedType(CalibratedQuantizedType type, DialectAsmPrinter &out)
Helper that prints a CalibratedQuantizedType.
Definition: TypeParser.cpp:417
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseComma()=0
Parse a , token.
virtual llvm::SMLoc getNameLoc() const =0
Return the location of the original name token.
static Type parseCalibratedType(DialectAsmParser &parser)
Parses an CalibratedQuantizedType.
Definition: TypeParser.cpp:300
virtual ParseResult parseFloat(double &result)=0
Parse a floating point value from the stream.
static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned)
Definition: TypeParser.cpp:25
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
static void printStorageType(QuantizedType type, DialectAsmPrinter &out)
Definition: TypeParser.cpp:341
virtual ParseResult parseColon()=0
Parse a : token.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static void printQuantParams(double scale, int64_t zeroPoint, DialectAsmPrinter &out)
Definition: TypeParser.cpp:363
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
Definition: QuantTypes.cpp:75
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual llvm::SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
A quantized type that maps storage to/from expressed types in an unspecified way. ...
Definition: QuantTypes.h:197
static Type parseAnyType(DialectAsmParser &parser)
Parses an AnyQuantizedType.
Definition: TypeParser.cpp:123
U dyn_cast() const
Definition: Types.h:244
static ParseResult parseStorageRange(DialectAsmParser &parser, IntegerType storageType, bool isSigned, int64_t &storageTypeMin, int64_t &storageTypeMax)
Definition: TypeParser.cpp:66
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:299
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
ParseResult parseKeyword(StringRef keyword, const Twine &msg="")
Parse a given keyword.
virtual ParseResult parseLess()=0
Parse a '<' token.
Represents per-axis (also known as per-channel quantization).
Definition: QuantTypes.h:314
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
Definition: QuantTypes.cpp:71
double getScale() const
Gets the scale term.
Definition: QuantTypes.cpp:297
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
Definition: QuantTypes.h:105
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
virtual ParseResult parseRBrace()=0
Parse a } token.
static Type parseUniformType(DialectAsmParser &parser)
Parses a UniformQuantizedType.
Definition: TypeParser.cpp:197
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
static void printAnyQuantizedType(AnyQuantizedType type, DialectAsmPrinter &out)
Helper that prints a AnyQuantizedType.
Definition: TypeParser.cpp:372
T getChecked(llvm::SMLoc loc, ParamsT &&... params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:52
static void printUniformQuantizedType(UniformQuantizedType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedType.
Definition: TypeParser.cpp:383
virtual InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
static constexpr unsigned MaxStorageBits
The maximum number of bits supported for storage types.
Definition: QuantTypes.h:58
Type parseType(DialectAsmParser &parser)
Parses an LLVM dialect type.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Base class for all quantized types known to this dialect.
Definition: QuantTypes.h:52
virtual ParseResult parseType(Type &result)=0
Parse a type.
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to...
Definition: QuantTypes.cpp:366
ArrayRef< double > getScales() const
Gets the quantization scales.
Definition: QuantTypes.cpp:358
bool hasValue() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:62
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation. ...
Definition: QuantTypes.cpp:362
This class represents success/failure for operation parsing.
Definition: OpDefinition.h:36
static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale, int64_t &zeroPoint)
Definition: TypeParser.cpp:165
virtual OptionalParseResult parseOptionalType(Type &result)=0
Parse an optional type.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
static int64_t getDefaultMaximumForInteger(bool isSigned, unsigned integralWidth)
Gets the maximum possible value that can be stored by a storageType.
Definition: QuantTypes.h:80
static Value max(ImplicitLocOpBuilder &builder, Value a, Value b)