15 #include "llvm/ADT/APFloat.h"
16 #include "llvm/Support/Format.h"
17 #include "llvm/Support/MathExtras.h"
18 #include "llvm/Support/SourceMgr.h"
19 #include "llvm/Support/raw_ostream.h"
22 using namespace quant;
30 unsigned storageTypeWidth = 0;
33 if (!succeeded(*result))
35 isSigned = !type.isUnsigned();
36 storageTypeWidth = type.getWidth();
37 }
else if (succeeded(parser.
parseKeyword(&identifier))) {
39 if (!identifier.consume_front(
"u")) {
40 parser.
emitError(typeLoc,
"illegal storage type prefix");
43 if (identifier.getAsInteger(10, storageTypeWidth)) {
44 parser.
emitError(typeLoc,
"expected storage type width");
53 if (storageTypeWidth == 0 ||
55 parser.
emitError(typeLoc,
"illegal storage type size: ")
64 IntegerType storageType,
bool isSigned,
65 int64_t &storageTypeMin,
66 int64_t &storageTypeMax) {
68 isSigned, storageType.getWidth());
70 isSigned, storageType.getWidth());
72 storageTypeMin = defaultIntegerMin;
73 storageTypeMax = defaultIntegerMax;
83 if (storageTypeMin < defaultIntegerMin) {
84 return parser.
emitError(minLoc,
"illegal storage type minimum: ")
87 if (storageTypeMax > defaultIntegerMax) {
88 return parser.
emitError(maxLoc,
"illegal storage type maximum: ")
95 double &
min,
double &
max) {
100 parser.
emitError(typeLoc,
"expecting float expressed type");
107 parser.
emitError(typeLoc,
"calibrated values must be present");
121 IntegerType storageType;
123 unsigned typeFlags = 0;
124 int64_t storageTypeMin;
125 int64_t storageTypeMax;
132 bool isSigned =
false;
159 typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
163 int64_t &zeroPoint) {
195 IntegerType storageType;
197 unsigned typeFlags = 0;
198 int64_t storageTypeMin;
199 int64_t storageTypeMax;
200 bool isPerAxis =
false;
201 int32_t quantizedDimension;
211 bool isSigned =
false;
254 scales.resize(scales.size() + 1);
255 zeroPoints.resize(zeroPoints.size() + 1);
271 if (!isPerAxis && scales.size() > 1) {
273 "multiple scales/zeroPoints provided, but "
274 "quantizedDimension wasn't specified"),
282 typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
283 quantizedDimension, storageTypeMin, storageTypeMax);
287 typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
288 storageTypeMin, storageTypeMax);
308 if (!expressedType) {
322 StringRef typeNameSpelling;
326 if (typeNameSpelling ==
"uniform")
328 if (typeNameSpelling ==
"any")
330 if (typeNameSpelling ==
"calibrated")
334 "unknown quantized type " + typeNameSpelling);
343 out <<
"i" << storageWidth;
345 out <<
"u" << storageWidth;
358 if (zeroPoint != 0) {
359 out <<
":" << zeroPoint;
368 if (
Type expressedType = type.getExpressedType()) {
369 out <<
":" << expressedType;
379 out <<
":" << type.getExpressedType() <<
", ";
391 out <<
":" << type.getExpressedType() <<
":";
400 llvm::seq<size_t>(0, scales.size()), out,
402 printQuantParams(scales[index], zeroPoints[index], out);
411 out <<
"calibrated<" << type.getExpressedType();
412 out <<
"<" << type.
getMin() <<
":" << type.
getMax() <<
">";
418 if (
auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
420 else if (
auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))
422 else if (
auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type))
424 else if (
auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type))
427 llvm_unreachable(
"Unhandled quantized type");
static void printAnyQuantizedType(AnyQuantizedType type, DialectAsmPrinter &out)
Helper that prints a AnyQuantizedType.
static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, double &min, double &max)
static Type parseUniformType(DialectAsmParser &parser)
Parses a UniformQuantizedType.
static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned)
static Type parseAnyType(DialectAsmParser &parser)
Parses an AnyQuantizedType.
static void printStorageType(QuantizedType type, DialectAsmPrinter &out)
static void printQuantParams(double scale, int64_t zeroPoint, DialectAsmPrinter &out)
static Type parseCalibratedType(DialectAsmParser &parser)
Parses an CalibratedQuantizedType.
static ParseResult parseStorageRange(DialectAsmParser &parser, IntegerType storageType, bool isSigned, int64_t &storageTypeMin, int64_t &storageTypeMax)
static void printCalibratedQuantizedType(CalibratedQuantizedType type, DialectAsmPrinter &out)
Helper that prints a CalibratedQuantizedType.
static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale, int64_t &zeroPoint)
static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedPerAxisType.
static void printUniformQuantizedType(UniformQuantizedType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedType.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseLBrace()=0
Parse a { token.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual OptionalParseResult parseOptionalType(Type &result)=0
Parse an optional type.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseFloat(double &result)=0
Parse a floating point value from the stream.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
IntegerType getIntegerType(unsigned width)
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
A quantized type that maps storage to/from expressed types in an unspecified way.
A quantized type that infers its range from given min/max values.
Base class for all quantized types known to this dialect.
static constexpr unsigned MaxStorageBits
The maximum number of bits supported for storage types.
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
static int64_t getDefaultMaximumForInteger(bool isSigned, unsigned integralWidth)
Gets the maximum possible stored by a storageType.
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
static int64_t getDefaultMinimumForInteger(bool isSigned, unsigned integralWidth)
Gets the minimum possible stored by a storageType.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
Include the generated interface declarations.
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.