14 #include "llvm/ADT/APFloat.h"
17 using namespace quant;
25 unsigned storageTypeWidth = 0;
28 if (!succeeded(*result))
30 isSigned = !type.isUnsigned();
31 storageTypeWidth = type.getWidth();
32 }
else if (succeeded(parser.
parseKeyword(&identifier))) {
34 if (!identifier.consume_front(
"u")) {
35 parser.
emitError(typeLoc,
"illegal storage type prefix");
38 if (identifier.getAsInteger(10, storageTypeWidth)) {
39 parser.
emitError(typeLoc,
"expected storage type width");
48 if (storageTypeWidth == 0 ||
50 parser.
emitError(typeLoc,
"illegal storage type size: ")
59 IntegerType storageType,
bool isSigned,
60 int64_t &storageTypeMin,
61 int64_t &storageTypeMax) {
63 isSigned, storageType.getWidth());
65 isSigned, storageType.getWidth());
67 storageTypeMin = defaultIntegerMin;
68 storageTypeMax = defaultIntegerMax;
78 if (storageTypeMin < defaultIntegerMin) {
79 return parser.
emitError(minLoc,
"illegal storage type minimum: ")
82 if (storageTypeMax > defaultIntegerMax) {
83 return parser.
emitError(maxLoc,
"illegal storage type maximum: ")
90 double &
min,
double &
max) {
95 parser.
emitError(typeLoc,
"expecting float expressed type");
102 parser.
emitError(typeLoc,
"calibrated values must be present");
116 IntegerType storageType;
117 FloatType expressedType;
118 unsigned typeFlags = 0;
119 int64_t storageTypeMin;
120 int64_t storageTypeMax;
127 bool isSigned =
false;
154 typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
162 Type expressedType,
double scale) {
163 auto floatType = cast<FloatType>(expressedType);
165 APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
167 APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
168 if (scale < minScale || scale > maxScale)
169 return emitError() <<
"scale " << scale <<
" out of expressed type range ["
170 << minScale <<
", " << maxScale <<
"]";
179 Type expressedType,
double &scale,
180 int64_t &zeroPoint) {
188 expressedType, scale))) {
215 auto parseBlockSizeElements = [&]() -> ParseResult {
216 quantizedDimensions.resize(quantizedDimensions.size() + 1);
217 blockSizes.resize(blockSizes.size() + 1);
252 if (prevDims == newDims)
255 <<
"tensor literal is invalid; ranks are not consistent "
263 auto parseOneElement = [&]() -> ParseResult {
267 zeroPoints, thisDims))
270 zeroPoints.resize(zeroPoints.size() + 1);
271 scales.resize(scales.size() + 1);
273 zeroPoints.back())) {
279 return checkDims(newDims, thisDims);
291 dims.push_back(size);
292 dims.append(newDims.begin(), newDims.end());
325 IntegerType storageType;
326 FloatType expressedType;
327 unsigned typeFlags = 0;
328 int64_t storageTypeMin;
329 int64_t storageTypeMax;
330 bool isPerAxis =
false;
331 bool isSubChannel =
false;
343 bool isSigned =
false;
374 quantizedDimensions.resize(1);
387 bool isPerTensor = !isPerAxis && !isSubChannel;
390 zeroPoints.resize(zeroPoints.size() + 1);
391 scales.resize(scales.size() + 1);
393 zeroPoints.back())) {
411 typeFlags, storageType, expressedType, scales, zeroPoints,
412 quantizedDimensions[0], storageTypeMin, storageTypeMax);
413 }
else if (isSubChannel) {
415 llvm::to_vector(llvm::map_range(scales, [&](
double scale) -> APFloat {
416 APFloat apFloatScale(scale);
418 apFloatScale.convert(expressedType.getFloatSemantics(),
419 APFloat::rmNearestTiesToEven, &unused);
423 llvm::map_range(zeroPoints, [&](int64_t zeroPoint) -> APInt {
424 return APInt(storageType.getIntOrFloatBitWidth(), zeroPoint);
431 typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
432 quantizedDimensions, blockSizes, storageTypeMin, storageTypeMax);
436 typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
437 storageTypeMin, storageTypeMax);
447 FloatType expressedType;
457 if (!expressedType) {
471 StringRef typeNameSpelling;
475 if (typeNameSpelling ==
"uniform")
477 if (typeNameSpelling ==
"any")
479 if (typeNameSpelling ==
"calibrated")
483 "unknown quantized type " + typeNameSpelling);
492 out <<
"i" << storageWidth;
494 out <<
"u" << storageWidth;
507 if (zeroPoint != 0) {
508 out <<
":" << zeroPoint;
516 llvm::interleaveComma(
517 llvm::seq<size_t>(0, blockSizeInfo.size()), out, [&](
size_t index) {
518 out << blockSizeInfo[index].first <<
":" << blockSizeInfo[index].second;
528 if (
Type expressedType = type.getExpressedType()) {
529 out <<
":" << expressedType;
539 out <<
":" << type.getExpressedType() <<
", ";
551 out <<
":" << type.getExpressedType() <<
":";
560 llvm::seq<size_t>(0, scales.size()), out,
562 printQuantParams(scales[index], zeroPoints[index], out);
585 int64_t rank = shape.size();
587 unsigned openBrackets = 0;
589 auto incrementCounterAndDelimit = [&]() {
591 for (
unsigned i = rank - 1; i > 0; --i) {
592 if (counter[i] >= shape[i]) {
601 for (
unsigned idx = 0, e = scales.size(); idx < e; ++idx) {
604 while (openBrackets++ < rank)
608 if (zeroPoints[idx] != 0) {
609 out <<
":" << zeroPoints[idx];
611 incrementCounterAndDelimit();
613 while (openBrackets-- > 0)
623 out <<
":" << type.getExpressedType() <<
":";
639 out <<
"calibrated<" << type.getExpressedType();
640 out <<
"<" << type.
getMin() <<
":" << type.
getMax() <<
">";
646 if (
auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
648 else if (
auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))
650 else if (
auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type))
652 else if (
auto perAxisType =
653 llvm::dyn_cast<UniformQuantizedSubChannelType>(type))
655 else if (
auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type))
658 llvm_unreachable(
"Unhandled quantized type");
static void printBlockSizeInfo(ArrayRef< std::pair< int32_t, int64_t >> blockSizeInfo, DialectAsmPrinter &out)
static ParseResult parseQuantParams(DialectAsmParser &parser, Type expressedType, double &scale, int64_t &zeroPoint)
Parses a quantization parameter, which is either a scale value (float) or a scale-zero point pair (fl...
void printDenseQuantizationParameters(ArrayRef< APFloat > scales, ArrayRef< APInt > zeroPoints, ArrayRef< int64_t > shape, DialectAsmPrinter &out)
Prints quantization parameters as a nested list of scale[:zero_point] elements.
static void printAnyQuantizedType(AnyQuantizedType type, DialectAsmPrinter &out)
Helper that prints an AnyQuantizedType.
LogicalResult isScaleInExpressedTypeRange(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double scale)
Checks if the given scale value is within the valid range of the expressed type.
static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, double &min, double &max)
static void printUniformQuantizedSubChannelType(UniformQuantizedSubChannelType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedSubChannelType.
static Type parseUniformType(DialectAsmParser &parser)
Parses a UniformQuantizedType.
static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned)
static Type parseAnyType(DialectAsmParser &parser)
Parses an AnyQuantizedType.
static void printStorageType(QuantizedType type, DialectAsmPrinter &out)
static ParseResult parseBlockSizeInfoUntilRBrace(DialectAsmParser &parser, SmallVectorImpl< int32_t > &quantizedDimensions, SmallVectorImpl< int64_t > &blockSizes)
Parses block size information for sub-channel quantization, assuming the leading '{' has already been...
static void printQuantParams(double scale, int64_t zeroPoint, DialectAsmPrinter &out)
static Type parseCalibratedType(DialectAsmParser &parser)
Parses a CalibratedQuantizedType.
static ParseResult parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType, SmallVectorImpl< double > &scales, SmallVectorImpl< int64_t > &zeroPoints, SmallVectorImpl< int64_t > &dims)
Parses a bracketed list of quantization parameters, returning the dimensions of the parsed sub-tensor...
static ParseResult parseStorageRange(DialectAsmParser &parser, IntegerType storageType, bool isSigned, int64_t &storageTypeMin, int64_t &storageTypeMax)
static void printCalibratedQuantizedType(CalibratedQuantizedType type, DialectAsmPrinter &out)
Helper that prints a CalibratedQuantizedType.
static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedPerAxisType.
static void printUniformQuantizedType(UniformQuantizedType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedType.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseLBrace()=0
Parse a { token.
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual OptionalParseResult parseOptionalType(Type &result)=0
Parse an optional type.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and return it.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseFloat(double &result)=0
Parse a floating point value from the stream.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
IntegerType getIntegerType(unsigned width)
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a diagnostic that is inflight and set to be reported.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
A quantized type that maps storage to/from expressed types in an unspecified way.
A quantized type that infers its range from given min/max values.
Base class for all quantized types known to this dialect.
static constexpr unsigned MaxStorageBits
The maximum number of bits supported for storage types.
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
static int64_t getDefaultMaximumForInteger(bool isSigned, unsigned integralWidth)
Gets the maximum possible value stored by a storageType.
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
static int64_t getDefaultMinimumForInteger(bool isSigned, unsigned integralWidth)
Gets the minimum possible value stored by a storageType.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.