MLIR  20.0.0git
TypeParser.cpp
Go to the documentation of this file.
4 //===- TypeParser.cpp - Quantization Type Parser ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/Location.h"
14 #include "mlir/IR/Types.h"
15 #include "llvm/ADT/APFloat.h"
16 #include "llvm/Support/Format.h"
17 #include "llvm/Support/MathExtras.h"
18 #include "llvm/Support/SourceMgr.h"
19 #include "llvm/Support/raw_ostream.h"
20 
21 using namespace mlir;
22 using namespace quant;
23 
24 static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
25  auto typeLoc = parser.getCurrentLocation();
26  IntegerType type;
27 
28  // Parse storage type (alpha_ident, integer_literal).
29  StringRef identifier;
30  unsigned storageTypeWidth = 0;
31  OptionalParseResult result = parser.parseOptionalType(type);
32  if (result.has_value()) {
33  if (!succeeded(*result))
34  return nullptr;
35  isSigned = !type.isUnsigned();
36  storageTypeWidth = type.getWidth();
37  } else if (succeeded(parser.parseKeyword(&identifier))) {
38  // Otherwise, this must be an unsigned integer (`u` integer-literal).
39  if (!identifier.consume_front("u")) {
40  parser.emitError(typeLoc, "illegal storage type prefix");
41  return nullptr;
42  }
43  if (identifier.getAsInteger(10, storageTypeWidth)) {
44  parser.emitError(typeLoc, "expected storage type width");
45  return nullptr;
46  }
47  isSigned = false;
48  type = parser.getBuilder().getIntegerType(storageTypeWidth);
49  } else {
50  return nullptr;
51  }
52 
53  if (storageTypeWidth == 0 ||
54  storageTypeWidth > QuantizedType::MaxStorageBits) {
55  parser.emitError(typeLoc, "illegal storage type size: ")
56  << storageTypeWidth;
57  return nullptr;
58  }
59 
60  return type;
61 }
62 
63 static ParseResult parseStorageRange(DialectAsmParser &parser,
64  IntegerType storageType, bool isSigned,
65  int64_t &storageTypeMin,
66  int64_t &storageTypeMax) {
67  int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
68  isSigned, storageType.getWidth());
69  int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
70  isSigned, storageType.getWidth());
71  if (failed(parser.parseOptionalLess())) {
72  storageTypeMin = defaultIntegerMin;
73  storageTypeMax = defaultIntegerMax;
74  return success();
75  }
76 
77  // Explicit storage min and storage max.
78  SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
79  if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
80  parser.getCurrentLocation(&maxLoc) ||
81  parser.parseInteger(storageTypeMax) || parser.parseGreater())
82  return failure();
83  if (storageTypeMin < defaultIntegerMin) {
84  return parser.emitError(minLoc, "illegal storage type minimum: ")
85  << storageTypeMin;
86  }
87  if (storageTypeMax > defaultIntegerMax) {
88  return parser.emitError(maxLoc, "illegal storage type maximum: ")
89  << storageTypeMax;
90  }
91  return success();
92 }
93 
95  double &min, double &max) {
96  auto typeLoc = parser.getCurrentLocation();
97  FloatType type;
98 
99  if (failed(parser.parseType(type))) {
100  parser.emitError(typeLoc, "expecting float expressed type");
101  return nullptr;
102  }
103 
104  // Calibrated min and max values.
105  if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() ||
106  parser.parseFloat(max) || parser.parseGreater()) {
107  parser.emitError(typeLoc, "calibrated values must be present");
108  return nullptr;
109  }
110  return type;
111 }
112 
113 /// Parses an AnyQuantizedType.
114 ///
115 /// any ::= `any<` storage-spec (expressed-type-spec)?`>`
116 /// storage-spec ::= storage-type (`<` storage-range `>`)?
117 /// storage-range ::= integer-literal `:` integer-literal
118 /// storage-type ::= (`i` | `u`) integer-literal
119 /// expressed-type-spec ::= `:` `f` integer-literal
121  IntegerType storageType;
122  FloatType expressedType;
123  unsigned typeFlags = 0;
124  int64_t storageTypeMin;
125  int64_t storageTypeMax;
126 
127  // Type specification.
128  if (parser.parseLess())
129  return nullptr;
130 
131  // Storage type.
132  bool isSigned = false;
133  storageType = parseStorageType(parser, isSigned);
134  if (!storageType) {
135  return nullptr;
136  }
137  if (isSigned) {
138  typeFlags |= QuantizationFlags::Signed;
139  }
140 
141  // Storage type range.
142  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
143  storageTypeMax)) {
144  return nullptr;
145  }
146 
147  // Optional expressed type.
148  if (succeeded(parser.parseOptionalColon())) {
149  if (parser.parseType(expressedType)) {
150  return nullptr;
151  }
152  }
153 
154  if (parser.parseGreater()) {
155  return nullptr;
156  }
157 
158  return parser.getChecked<AnyQuantizedType>(
159  typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
160 }
161 
162 static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
163  int64_t &zeroPoint) {
164  // scale[:zeroPoint]?
165  // scale.
166  if (parser.parseFloat(scale))
167  return failure();
168 
169  // zero point.
170  zeroPoint = 0;
171  if (failed(parser.parseOptionalColon())) {
172  // Default zero point.
173  return success();
174  }
175 
176  return parser.parseInteger(zeroPoint);
177 }
178 
179 /// Parses a UniformQuantizedType.
180 ///
181 /// uniform_type ::= uniform_per_layer
182 /// | uniform_per_axis
183 /// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
184 /// `,` scale-zero `>`
185 /// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
186 /// axis-spec `,` scale-zero-list `>`
187 /// storage-spec ::= storage-type (`<` storage-range `>`)?
188 /// storage-range ::= integer-literal `:` integer-literal
189 /// storage-type ::= (`i` | `u`) integer-literal
190 /// expressed-type-spec ::= `:` `f` integer-literal
191 /// axis-spec ::= `:` integer-literal
192 /// scale-zero ::= float-literal `:` integer-literal
193 /// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
195  IntegerType storageType;
196  FloatType expressedType;
197  unsigned typeFlags = 0;
198  int64_t storageTypeMin;
199  int64_t storageTypeMax;
200  bool isPerAxis = false;
201  int32_t quantizedDimension;
202  SmallVector<double, 1> scales;
203  SmallVector<int64_t, 1> zeroPoints;
204 
205  // Type specification.
206  if (parser.parseLess()) {
207  return nullptr;
208  }
209 
210  // Storage type.
211  bool isSigned = false;
212  storageType = parseStorageType(parser, isSigned);
213  if (!storageType) {
214  return nullptr;
215  }
216  if (isSigned) {
217  typeFlags |= QuantizationFlags::Signed;
218  }
219 
220  // Storage type range.
221  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
222  storageTypeMax)) {
223  return nullptr;
224  }
225 
226  // Expressed type.
227  if (parser.parseColon() || parser.parseType(expressedType)) {
228  return nullptr;
229  }
230 
231  // Optionally parse quantized dimension for per-axis quantization.
232  if (succeeded(parser.parseOptionalColon())) {
233  if (parser.parseInteger(quantizedDimension))
234  return nullptr;
235  isPerAxis = true;
236  }
237 
238  // Comma leading into range_spec.
239  if (parser.parseComma()) {
240  return nullptr;
241  }
242 
243  // Parameter specification.
244  // For per-axis, ranges are in a {} delimitted list.
245  if (isPerAxis) {
246  if (parser.parseLBrace()) {
247  return nullptr;
248  }
249  }
250 
251  // Parse scales/zeroPoints.
252  SMLoc scaleZPLoc = parser.getCurrentLocation();
253  do {
254  scales.resize(scales.size() + 1);
255  zeroPoints.resize(zeroPoints.size() + 1);
256  if (parseQuantParams(parser, scales.back(), zeroPoints.back())) {
257  return nullptr;
258  }
259  } while (isPerAxis && succeeded(parser.parseOptionalComma()));
260 
261  if (isPerAxis) {
262  if (parser.parseRBrace()) {
263  return nullptr;
264  }
265  }
266 
267  if (parser.parseGreater()) {
268  return nullptr;
269  }
270 
271  if (!isPerAxis && scales.size() > 1) {
272  return (parser.emitError(scaleZPLoc,
273  "multiple scales/zeroPoints provided, but "
274  "quantizedDimension wasn't specified"),
275  nullptr);
276  }
277 
278  if (isPerAxis) {
279  ArrayRef<double> scalesRef(scales.begin(), scales.end());
280  ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
282  typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
283  quantizedDimension, storageTypeMin, storageTypeMax);
284  }
285 
286  return parser.getChecked<UniformQuantizedType>(
287  typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
288  storageTypeMin, storageTypeMax);
289 }
290 
291 /// Parses an CalibratedQuantizedType.
292 ///
293 /// calibrated ::= `calibrated<` expressed-spec `>`
294 /// expressed-spec ::= expressed-type `<` calibrated-range `>`
295 /// expressed-type ::= `f` integer-literal
296 /// calibrated-range ::= float-literal `:` float-literal
298  FloatType expressedType;
299  double min;
300  double max;
301 
302  // Type specification.
303  if (parser.parseLess())
304  return nullptr;
305 
306  // Expressed type.
307  expressedType = parseExpressedTypeAndRange(parser, min, max);
308  if (!expressedType) {
309  return nullptr;
310  }
311 
312  if (parser.parseGreater()) {
313  return nullptr;
314  }
315 
316  return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
317 }
318 
319 /// Parse a type registered to this dialect.
321  // All types start with an identifier that we switch on.
322  StringRef typeNameSpelling;
323  if (failed(parser.parseKeyword(&typeNameSpelling)))
324  return nullptr;
325 
326  if (typeNameSpelling == "uniform")
327  return parseUniformType(parser);
328  if (typeNameSpelling == "any")
329  return parseAnyType(parser);
330  if (typeNameSpelling == "calibrated")
331  return parseCalibratedType(parser);
332 
333  parser.emitError(parser.getNameLoc(),
334  "unknown quantized type " + typeNameSpelling);
335  return nullptr;
336 }
337 
339  // storage type
340  unsigned storageWidth = type.getStorageTypeIntegralWidth();
341  bool isSigned = type.isSigned();
342  if (isSigned) {
343  out << "i" << storageWidth;
344  } else {
345  out << "u" << storageWidth;
346  }
347 
348  // storageTypeMin and storageTypeMax if not default.
349  if (type.hasStorageTypeBounds()) {
350  out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
351  << ">";
352  }
353 }
354 
355 static void printQuantParams(double scale, int64_t zeroPoint,
356  DialectAsmPrinter &out) {
357  out << scale;
358  if (zeroPoint != 0) {
359  out << ":" << zeroPoint;
360  }
361 }
362 
363 /// Helper that prints a AnyQuantizedType.
365  DialectAsmPrinter &out) {
366  out << "any<";
367  printStorageType(type, out);
368  if (Type expressedType = type.getExpressedType()) {
369  out << ":" << expressedType;
370  }
371  out << ">";
372 }
373 
374 /// Helper that prints a UniformQuantizedType.
376  DialectAsmPrinter &out) {
377  out << "uniform<";
378  printStorageType(type, out);
379  out << ":" << type.getExpressedType() << ", ";
380 
381  // scheme specific parameters
382  printQuantParams(type.getScale(), type.getZeroPoint(), out);
383  out << ">";
384 }
385 
386 /// Helper that prints a UniformQuantizedPerAxisType.
388  DialectAsmPrinter &out) {
389  out << "uniform<";
390  printStorageType(type, out);
391  out << ":" << type.getExpressedType() << ":";
392  out << type.getQuantizedDimension();
393  out << ", ";
394 
395  // scheme specific parameters
396  ArrayRef<double> scales = type.getScales();
397  ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
398  out << "{";
399  llvm::interleave(
400  llvm::seq<size_t>(0, scales.size()), out,
401  [&](size_t index) {
402  printQuantParams(scales[index], zeroPoints[index], out);
403  },
404  ",");
405  out << "}>";
406 }
407 
408 /// Helper that prints a CalibratedQuantizedType.
410  DialectAsmPrinter &out) {
411  out << "calibrated<" << type.getExpressedType();
412  out << "<" << type.getMin() << ":" << type.getMax() << ">";
413  out << ">";
414 }
415 
416 /// Print a type registered to this dialect.
417 void QuantDialect::printType(Type type, DialectAsmPrinter &os) const {
418  if (auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
419  printAnyQuantizedType(anyType, os);
420  else if (auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))
421  printUniformQuantizedType(uniformType, os);
422  else if (auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type))
423  printUniformQuantizedPerAxisType(perAxisType, os);
424  else if (auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type))
425  printCalibratedQuantizedType(calibratedType, os);
426  else
427  llvm_unreachable("Unhandled quantized type");
428 }
static void printAnyQuantizedType(AnyQuantizedType type, DialectAsmPrinter &out)
Helper that prints an AnyQuantizedType.
Definition: TypeParser.cpp:364
static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, double &min, double &max)
Definition: TypeParser.cpp:94
static Type parseUniformType(DialectAsmParser &parser)
Parses a UniformQuantizedType.
Definition: TypeParser.cpp:194
static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned)
Definition: TypeParser.cpp:24
static Type parseAnyType(DialectAsmParser &parser)
Parses an AnyQuantizedType.
Definition: TypeParser.cpp:120
static void printStorageType(QuantizedType type, DialectAsmPrinter &out)
Definition: TypeParser.cpp:338
static void printQuantParams(double scale, int64_t zeroPoint, DialectAsmPrinter &out)
Definition: TypeParser.cpp:355
static Type parseCalibratedType(DialectAsmParser &parser)
Parses a CalibratedQuantizedType.
Definition: TypeParser.cpp:297
static ParseResult parseStorageRange(DialectAsmParser &parser, IntegerType storageType, bool isSigned, int64_t &storageTypeMin, int64_t &storageTypeMax)
Definition: TypeParser.cpp:63
static void printCalibratedQuantizedType(CalibratedQuantizedType type, DialectAsmPrinter &out)
Helper that prints a CalibratedQuantizedType.
Definition: TypeParser.cpp:409
static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale, int64_t &zeroPoint)
Definition: TypeParser.cpp:162
static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedPerAxisType.
Definition: TypeParser.cpp:387
static void printUniformQuantizedType(UniformQuantizedType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedType.
Definition: TypeParser.cpp:375
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseLBrace()=0
Parse a { token.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual OptionalParseResult parseOptionalType(Type &result)=0
Parse an optional type.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseFloat(double &result)=0
Parse a floating point value from the stream.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition: QuantTypes.h:200
A quantized type that infers its range from given min/max values.
Definition: QuantTypes.h:391
Base class for all quantized types known to this dialect.
Definition: QuantTypes.h:49
static constexpr unsigned MaxStorageBits
The maximum number of bits supported for storage types.
Definition: QuantTypes.h:55
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
Definition: QuantTypes.cpp:92
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
Definition: QuantTypes.h:102
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
Definition: QuantTypes.cpp:88
static int64_t getDefaultMaximumForInteger(bool isSigned, unsigned integralWidth)
Gets the maximum possible value that can be stored by a storageType.
Definition: QuantTypes.h:77
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
Definition: QuantTypes.cpp:103
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Definition: QuantTypes.cpp:84
static int64_t getDefaultMinimumForInteger(bool isSigned, unsigned integralWidth)
Gets the minimum possible value that can be stored by a storageType.
Definition: QuantTypes.h:67
Represents per-axis (also known as per-channel quantization).
Definition: QuantTypes.h:321
int32_t getQuantizedDimension() const
Specifies the dimension of the Tensor's shape that the scales and zero_points correspond to.
Definition: QuantTypes.cpp:409
ArrayRef< int64_t > getZeroPoints() const
Gets the storage values corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:405
ArrayRef< double > getScales() const
Gets the quantization scales.
Definition: QuantTypes.cpp:401
Represents a family of uniform, quantized types.
Definition: QuantTypes.h:261
double getScale() const
Gets the scale term.
Definition: QuantTypes.cpp:332
int64_t getZeroPoint() const
Gets the storage value corresponding to the real value 0 in the affine equation.
Definition: QuantTypes.cpp:334
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
Include the generated interface declarations.
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.