MLIR  16.0.0git
TypeParser.cpp
Go to the documentation of this file.
1 //===- TypeParser.cpp - Quantization Type Parser ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/Location.h"
14 #include "mlir/IR/Types.h"
15 #include "llvm/ADT/APFloat.h"
16 #include "llvm/ADT/StringSwitch.h"
17 #include "llvm/Support/Format.h"
18 #include "llvm/Support/MathExtras.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "llvm/Support/raw_ostream.h"
21 
22 using namespace mlir;
23 using namespace quant;
24 
25 static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
26  auto typeLoc = parser.getCurrentLocation();
27  IntegerType type;
28 
29  // Parse storage type (alpha_ident, integer_literal).
30  StringRef identifier;
31  unsigned storageTypeWidth = 0;
32  OptionalParseResult result = parser.parseOptionalType(type);
33  if (result.has_value()) {
34  if (!succeeded(*result))
35  return nullptr;
36  isSigned = !type.isUnsigned();
37  storageTypeWidth = type.getWidth();
38  } else if (succeeded(parser.parseKeyword(&identifier))) {
39  // Otherwise, this must be an unsigned integer (`u` integer-literal).
40  if (!identifier.consume_front("u")) {
41  parser.emitError(typeLoc, "illegal storage type prefix");
42  return nullptr;
43  }
44  if (identifier.getAsInteger(10, storageTypeWidth)) {
45  parser.emitError(typeLoc, "expected storage type width");
46  return nullptr;
47  }
48  isSigned = false;
49  type = parser.getBuilder().getIntegerType(storageTypeWidth);
50  } else {
51  return nullptr;
52  }
53 
54  if (storageTypeWidth == 0 ||
55  storageTypeWidth > QuantizedType::MaxStorageBits) {
56  parser.emitError(typeLoc, "illegal storage type size: ")
57  << storageTypeWidth;
58  return nullptr;
59  }
60 
61  return type;
62 }
63 
65  IntegerType storageType, bool isSigned,
66  int64_t &storageTypeMin,
67  int64_t &storageTypeMax) {
68  int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
69  isSigned, storageType.getWidth());
70  int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
71  isSigned, storageType.getWidth());
72  if (failed(parser.parseOptionalLess())) {
73  storageTypeMin = defaultIntegerMin;
74  storageTypeMax = defaultIntegerMax;
75  return success();
76  }
77 
78  // Explicit storage min and storage max.
79  SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
80  if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
81  parser.getCurrentLocation(&maxLoc) ||
82  parser.parseInteger(storageTypeMax) || parser.parseGreater())
83  return failure();
84  if (storageTypeMin < defaultIntegerMin) {
85  return parser.emitError(minLoc, "illegal storage type minimum: ")
86  << storageTypeMin;
87  }
88  if (storageTypeMax > defaultIntegerMax) {
89  return parser.emitError(maxLoc, "illegal storage type maximum: ")
90  << storageTypeMax;
91  }
92  return success();
93 }
94 
96  double &min, double &max) {
97  auto typeLoc = parser.getCurrentLocation();
98  FloatType type;
99 
100  if (failed(parser.parseType(type))) {
101  parser.emitError(typeLoc, "expecting float expressed type");
102  return nullptr;
103  }
104 
105  // Calibrated min and max values.
106  if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() ||
107  parser.parseFloat(max) || parser.parseGreater()) {
108  parser.emitError(typeLoc, "calibrated values must be present");
109  return nullptr;
110  }
111  return type;
112 }
113 
114 /// Parses an AnyQuantizedType.
115 ///
116 /// any ::= `any<` storage-spec (expressed-type-spec)?`>`
117 /// storage-spec ::= storage-type (`<` storage-range `>`)?
118 /// storage-range ::= integer-literal `:` integer-literal
119 /// storage-type ::= (`i` | `u`) integer-literal
120 /// expressed-type-spec ::= `:` `f` integer-literal
122  IntegerType storageType;
123  FloatType expressedType;
124  unsigned typeFlags = 0;
125  int64_t storageTypeMin;
126  int64_t storageTypeMax;
127 
128  // Type specification.
129  if (parser.parseLess())
130  return nullptr;
131 
132  // Storage type.
133  bool isSigned = false;
134  storageType = parseStorageType(parser, isSigned);
135  if (!storageType) {
136  return nullptr;
137  }
138  if (isSigned) {
139  typeFlags |= QuantizationFlags::Signed;
140  }
141 
142  // Storage type range.
143  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
144  storageTypeMax)) {
145  return nullptr;
146  }
147 
148  // Optional expressed type.
149  if (succeeded(parser.parseOptionalColon())) {
150  if (parser.parseType(expressedType)) {
151  return nullptr;
152  }
153  }
154 
155  if (parser.parseGreater()) {
156  return nullptr;
157  }
158 
159  return parser.getChecked<AnyQuantizedType>(
160  typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
161 }
162 
164  int64_t &zeroPoint) {
165  // scale[:zeroPoint]?
166  // scale.
167  if (parser.parseFloat(scale))
168  return failure();
169 
170  // zero point.
171  zeroPoint = 0;
172  if (failed(parser.parseOptionalColon())) {
173  // Default zero point.
174  return success();
175  }
176 
177  return parser.parseInteger(zeroPoint);
178 }
179 
180 /// Parses a UniformQuantizedType.
181 ///
182 /// uniform_type ::= uniform_per_layer
183 /// | uniform_per_axis
184 /// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
185 /// `,` scale-zero `>`
186 /// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
187 /// axis-spec `,` scale-zero-list `>`
188 /// storage-spec ::= storage-type (`<` storage-range `>`)?
189 /// storage-range ::= integer-literal `:` integer-literal
190 /// storage-type ::= (`i` | `u`) integer-literal
191 /// expressed-type-spec ::= `:` `f` integer-literal
192 /// axis-spec ::= `:` integer-literal
193 /// scale-zero ::= float-literal (`:` integer-literal)?
194 /// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
196  IntegerType storageType;
197  FloatType expressedType;
198  unsigned typeFlags = 0;
199  int64_t storageTypeMin;
200  int64_t storageTypeMax;
201  bool isPerAxis = false;
202  int32_t quantizedDimension;
203  SmallVector<double, 1> scales;
204  SmallVector<int64_t, 1> zeroPoints;
205 
206  // Type specification.
207  if (parser.parseLess()) {
208  return nullptr;
209  }
210 
211  // Storage type.
212  bool isSigned = false;
213  storageType = parseStorageType(parser, isSigned);
214  if (!storageType) {
215  return nullptr;
216  }
217  if (isSigned) {
218  typeFlags |= QuantizationFlags::Signed;
219  }
220 
221  // Storage type range.
222  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
223  storageTypeMax)) {
224  return nullptr;
225  }
226 
227  // Expressed type.
228  if (parser.parseColon() || parser.parseType(expressedType)) {
229  return nullptr;
230  }
231 
232  // Optionally parse quantized dimension for per-axis quantization.
233  if (succeeded(parser.parseOptionalColon())) {
234  if (parser.parseInteger(quantizedDimension))
235  return nullptr;
236  isPerAxis = true;
237  }
238 
239  // Comma leading into range_spec.
240  if (parser.parseComma()) {
241  return nullptr;
242  }
243 
244  // Parameter specification.
245  // For per-axis, ranges are in a {} delimited list.
246  if (isPerAxis) {
247  if (parser.parseLBrace()) {
248  return nullptr;
249  }
250  }
251 
252  // Parse scales/zeroPoints.
253  SMLoc scaleZPLoc = parser.getCurrentLocation();
254  do {
255  scales.resize(scales.size() + 1);
256  zeroPoints.resize(zeroPoints.size() + 1);
257  if (parseQuantParams(parser, scales.back(), zeroPoints.back())) {
258  return nullptr;
259  }
260  } while (isPerAxis && succeeded(parser.parseOptionalComma()));
261 
262  if (isPerAxis) {
263  if (parser.parseRBrace()) {
264  return nullptr;
265  }
266  }
267 
268  if (parser.parseGreater()) {
269  return nullptr;
270  }
271 
272  if (!isPerAxis && scales.size() > 1) {
273  return (parser.emitError(scaleZPLoc,
274  "multiple scales/zeroPoints provided, but "
275  "quantizedDimension wasn't specified"),
276  nullptr);
277  }
278 
279  if (isPerAxis) {
280  ArrayRef<double> scalesRef(scales.begin(), scales.end());
281  ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
283  typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
284  quantizedDimension, storageTypeMin, storageTypeMax);
285  }
286 
287  return parser.getChecked<UniformQuantizedType>(
288  typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
289  storageTypeMin, storageTypeMax);
290 }
291 
292 /// Parses a CalibratedQuantizedType.
293 ///
294 /// calibrated ::= `calibrated<` expressed-spec `>`
295 /// expressed-spec ::= expressed-type `<` calibrated-range `>`
296 /// expressed-type ::= `f` integer-literal
297 /// calibrated-range ::= float-literal `:` float-literal
299  FloatType expressedType;
300  double min;
301  double max;
302 
303  // Type specification.
304  if (parser.parseLess())
305  return nullptr;
306 
307  // Expressed type.
308  expressedType = parseExpressedTypeAndRange(parser, min, max);
309  if (!expressedType) {
310  return nullptr;
311  }
312 
313  if (parser.parseGreater()) {
314  return nullptr;
315  }
316 
317  return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
318 }
319 
320 /// Parse a type registered to this dialect.
322  // All types start with an identifier that we switch on.
323  StringRef typeNameSpelling;
324  if (failed(parser.parseKeyword(&typeNameSpelling)))
325  return nullptr;
326 
327  if (typeNameSpelling == "uniform")
328  return parseUniformType(parser);
329  if (typeNameSpelling == "any")
330  return parseAnyType(parser);
331  if (typeNameSpelling == "calibrated")
332  return parseCalibratedType(parser);
333 
334  parser.emitError(parser.getNameLoc(),
335  "unknown quantized type " + typeNameSpelling);
336  return nullptr;
337 }
338 
340  // storage type
341  unsigned storageWidth = type.getStorageTypeIntegralWidth();
342  bool isSigned = type.isSigned();
343  if (isSigned) {
344  out << "i" << storageWidth;
345  } else {
346  out << "u" << storageWidth;
347  }
348 
349  // storageTypeMin and storageTypeMax if not default.
350  int64_t defaultIntegerMin =
351  QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth);
352  int64_t defaultIntegerMax =
353  QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth);
354  if (defaultIntegerMin != type.getStorageTypeMin() ||
355  defaultIntegerMax != type.getStorageTypeMax()) {
356  out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
357  << ">";
358  }
359 }
360 
361 static void printQuantParams(double scale, int64_t zeroPoint,
362  DialectAsmPrinter &out) {
363  out << scale;
364  if (zeroPoint != 0) {
365  out << ":" << zeroPoint;
366  }
367 }
368 
369 /// Helper that prints an AnyQuantizedType.
371  DialectAsmPrinter &out) {
372  out << "any<";
373  printStorageType(type, out);
374  if (Type expressedType = type.getExpressedType()) {
375  out << ":" << expressedType;
376  }
377  out << ">";
378 }
379 
380 /// Helper that prints a UniformQuantizedType.
382  DialectAsmPrinter &out) {
383  out << "uniform<";
384  printStorageType(type, out);
385  out << ":" << type.getExpressedType() << ", ";
386 
387  // scheme specific parameters
388  printQuantParams(type.getScale(), type.getZeroPoint(), out);
389  out << ">";
390 }
391 
392 /// Helper that prints a UniformQuantizedPerAxisType.
394  DialectAsmPrinter &out) {
395  out << "uniform<";
396  printStorageType(type, out);
397  out << ":" << type.getExpressedType() << ":";
398  out << type.getQuantizedDimension();
399  out << ", ";
400 
401  // scheme specific parameters
402  ArrayRef<double> scales = type.getScales();
403  ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
404  out << "{";
405  llvm::interleave(
406  llvm::seq<size_t>(0, scales.size()), out,
407  [&](size_t index) {
408  printQuantParams(scales[index], zeroPoints[index], out);
409  },
410  ",");
411  out << "}>";
412 }
413 
414 /// Helper that prints a CalibratedQuantizedType.
416  DialectAsmPrinter &out) {
417  out << "calibrated<" << type.getExpressedType();
418  out << "<" << type.getMin() << ":" << type.getMax() << ">";
419  out << ">";
420 }
421 
422 /// Print a type registered to this dialect.
424  if (auto anyType = type.dyn_cast<AnyQuantizedType>())
425  printAnyQuantizedType(anyType, os);
426  else if (auto uniformType = type.dyn_cast<UniformQuantizedType>())
427  printUniformQuantizedType(uniformType, os);
428  else if (auto perAxisType = type.dyn_cast<UniformQuantizedPerAxisType>())
429  printUniformQuantizedPerAxisType(perAxisType, os);
430  else if (auto calibratedType = type.dyn_cast<CalibratedQuantizedType>())
431  printCalibratedQuantizedType(calibratedType, os);
432  else
433  llvm_unreachable("Unhandled quantized type");
434 }
Include the generated interface declarations.
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Definition: QuantTypes.cpp:67
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedPerAxisType.
Definition: TypeParser.cpp:393
static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, double &min, double &max)
Definition: TypeParser.cpp:95
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
static int64_t getDefaultMinimumForInteger(bool isSigned, unsigned integralWidth)
Gets the minimum possible value stored by a storageType.
Definition: QuantTypes.h:70
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
Represents a family of uniform, quantized types.
Definition: QuantTypes.h:256
virtual ParseResult parseLBrace()=0
Parse a { token.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
A quantized type that infers its range from given min/max values.
Definition: QuantTypes.h:383
static void printCalibratedQuantizedType(CalibratedQuantizedType type, DialectAsmPrinter &out)
Helper that prints a CalibratedQuantizedType.
Definition: TypeParser.cpp:415
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseComma()=0
Parse a , token.
static Type parseCalibratedType(DialectAsmParser &parser)
Parses a CalibratedQuantizedType.
Definition: TypeParser.cpp:298
virtual ParseResult parseFloat(double &result)=0
Parse a floating point value from the stream.
static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned)
Definition: TypeParser.cpp:25
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
static void printStorageType(QuantizedType type, DialectAsmPrinter &out)
Definition: TypeParser.cpp:339
virtual ParseResult parseColon()=0
Parse a : token.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static void printQuantParams(double scale, int64_t zeroPoint, DialectAsmPrinter &out)
Definition: TypeParser.cpp:361
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
Definition: QuantTypes.cpp:75
virtual ParseResult parseGreater()=0
Parse a '>' token.
A quantized type that maps storage to/from expressed types in an unspecified way. ...
Definition: QuantTypes.h:197
static Type parseAnyType(DialectAsmParser &parser)
Parses an AnyQuantizedType.
Definition: TypeParser.cpp:121
U dyn_cast() const
Definition: Types.h:270
static ParseResult parseStorageRange(DialectAsmParser &parser, IntegerType storageType, bool isSigned, int64_t &storageTypeMin, int64_t &storageTypeMax)
Definition: TypeParser.cpp:64
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:299
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
virtual ParseResult parseLess()=0
Parse a '<' token.
Represents per-axis (also known as per-channel quantization).
Definition: QuantTypes.h:314
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
Definition: QuantTypes.cpp:71
double getScale() const
Gets the scale term.
Definition: QuantTypes.cpp:297
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
Definition: QuantTypes.h:105
Type parseType(llvm::StringRef typeStr, MLIRContext *context)
This parses a single MLIR type to an MLIR context if it was valid.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseRBrace()=0
Parse a } token.
static Type parseUniformType(DialectAsmParser &parser)
Parses a UniformQuantizedType.
Definition: TypeParser.cpp:195
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
static void printAnyQuantizedType(AnyQuantizedType type, DialectAsmPrinter &out)
Helper that prints an AnyQuantizedType.
Definition: TypeParser.cpp:370
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:37
static void printUniformQuantizedType(UniformQuantizedType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedType.
Definition: TypeParser.cpp:381
static constexpr unsigned MaxStorageBits
The maximum number of bits supported for storage types.
Definition: QuantTypes.h:58
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Base class for all quantized types known to this dialect.
Definition: QuantTypes.h:52
virtual ParseResult parseType(Type &result)=0
Parse a type.
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:47
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to...
Definition: QuantTypes.cpp:366
ArrayRef< double > getScales() const
Gets the quantization scales.
Definition: QuantTypes.cpp:358
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation. ...
Definition: QuantTypes.cpp:362
This class represents success/failure for parsing-like operations that find it important to chain tog...
static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale, int64_t &zeroPoint)
Definition: TypeParser.cpp:163
virtual OptionalParseResult parseOptionalType(Type &result)=0
Parse an optional type.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static int64_t getDefaultMaximumForInteger(bool isSigned, unsigned integralWidth)
Gets the maximum possible value stored by a storageType.
Definition: QuantTypes.h:80