15#include "llvm/ADT/APFloat.h"
16#include "llvm/ADT/SmallVectorExtras.h"
27 unsigned storageTypeWidth = 0;
33 if (
auto quantStorageTypeInterface =
34 llvm::dyn_cast<QuantStorageTypeInterface>(type)) {
37 isSigned = quantStorageTypeInterface.shouldDefaultToSigned();
38 storageTypeWidth = quantStorageTypeInterface.getStorageWidth();
40 parser.
emitError(typeLoc,
"illegal storage type prefix");
43 }
else if (succeeded(parser.
parseKeyword(&identifier))) {
45 if (identifier.consume_front(
"u")) {
46 if (identifier.getAsInteger(10, storageTypeWidth)) {
47 parser.
emitError(typeLoc,
"expected storage type width");
53 parser.
emitError(typeLoc,
"illegal storage type prefix");
60 if (storageTypeWidth == 0 ||
61 storageTypeWidth > QuantizedType::MaxStorageBits) {
62 parser.
emitError(typeLoc,
"illegal storage type size: ")
71 bool isSigned,
int64_t &storageTypeMin,
73 auto quantStorageTypeInterface =
74 llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
76 int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum(isSigned);
77 int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum(isSigned);
80 storageTypeMin = defaultMin;
81 storageTypeMax = defaultMax;
91 if (storageTypeMin < defaultMin) {
92 return parser.
emitError(minLoc,
"illegal storage type minimum: ")
95 if (storageTypeMax > defaultMax) {
96 return parser.
emitError(maxLoc,
"illegal storage type maximum: ")
103 double &
min,
double &
max) {
108 parser.
emitError(typeLoc,
"expecting float expressed type");
115 parser.
emitError(typeLoc,
"calibrated values must be present");
130 FloatType expressedType;
131 unsigned typeFlags = 0;
140 bool isSigned =
false;
167 typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
175 Type expressedType,
double scale) {
176 auto floatType = cast<FloatType>(expressedType);
178 APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
180 APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
181 if (scale < minScale || scale > maxScale)
182 return emitError() <<
"scale " << scale <<
" out of expressed type range ["
183 << minScale <<
", " << maxScale <<
"]";
192 Type expressedType,
double &scale,
201 expressedType, scale))) {
228 auto parseBlockSizeElements = [&]() -> ParseResult {
229 quantizedDimensions.resize(quantizedDimensions.size() + 1);
230 blockSizes.resize(blockSizes.size() + 1);
265 if (prevDims == newDims)
268 <<
"tensor literal is invalid; ranks are not consistent "
276 auto parseOneElement = [&]() -> ParseResult {
280 zeroPoints, thisDims))
283 zeroPoints.resize(zeroPoints.size() + 1);
284 scales.resize(scales.size() + 1);
286 zeroPoints.back())) {
292 return checkDims(newDims, thisDims);
304 dims.push_back(size);
305 dims.append(newDims.begin(), newDims.end());
340 FloatType expressedType;
341 unsigned typeFlags = 0;
344 bool isPerAxis =
false;
345 bool isSubChannel =
false;
357 bool isSigned =
false;
388 quantizedDimensions.resize(1);
401 bool isPerTensor = !isPerAxis && !isSubChannel;
404 zeroPoints.resize(zeroPoints.size() + 1);
405 scales.resize(scales.size() + 1);
407 zeroPoints.back())) {
425 typeFlags, storageType, expressedType, scales, zeroPoints,
426 quantizedDimensions[0], storageTypeMin, storageTypeMax);
430 llvm::map_to_vector(scales, [&](
double scale) -> APFloat {
431 APFloat apFloatScale(scale);
433 apFloatScale.convert(expressedType.getFloatSemantics(),
434 APFloat::rmNearestTiesToEven, &unused);
438 llvm::map_to_vector(zeroPoints, [&](
int64_t zeroPoint) -> APInt {
442 RankedTensorType::get(dims, expressedType), apFloatScales);
444 RankedTensorType::get(dims, storageType), apIntZeroPoints);
446 typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
447 quantizedDimensions, blockSizes, storageTypeMin, storageTypeMax);
451 typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
452 storageTypeMin, storageTypeMax);
462 FloatType expressedType;
472 if (!expressedType) {
504 quantiles.emplace_back();
515 std::optional<int64_t> storageMin, storageMax;
534 quantiles, storageMin, storageMax);
540 StringRef typeNameSpelling;
544 if (typeNameSpelling ==
"uniform")
546 if (typeNameSpelling ==
"any")
548 if (typeNameSpelling ==
"calibrated")
550 if (typeNameSpelling ==
"quantile")
554 "unknown quantized type " + typeNameSpelling);
560 auto quantStorageTypeInterface =
563 out << quantStorageTypeInterface.getStorageTypeName(type.
isSigned());
575 if (zeroPoint != 0) {
576 out <<
":" << zeroPoint;
584 llvm::interleaveComma(
585 llvm::seq<size_t>(0, blockSizeInfo.size()), out, [&](
size_t index) {
586 out << blockSizeInfo[index].first <<
":" << blockSizeInfo[index].second;
596 if (
Type expressedType = type.getExpressedType()) {
597 out <<
":" << expressedType;
607 out <<
":" << type.getExpressedType() <<
", ";
619 out <<
":" << type.getExpressedType() <<
":";
628 llvm::seq<size_t>(0, scales.size()), out,
630 printQuantParams(scales[index], zeroPoints[index], out);
655 unsigned openBrackets = 0;
657 auto incrementCounterAndDelimit = [&]() {
659 for (
unsigned i = rank - 1; i > 0; --i) {
660 if (counter[i] >=
shape[i]) {
669 for (
unsigned idx = 0, e = scales.size(); idx < e; ++idx) {
672 while (openBrackets++ < rank)
676 if (zeroPoints[idx] != 0) {
677 out <<
":" << zeroPoints[idx];
679 incrementCounterAndDelimit();
681 while (openBrackets-- > 0)
691 out <<
":" << type.getExpressedType() <<
":";
707 out <<
"calibrated<" << type.getExpressedType();
708 out <<
"<" << type.
getMin() <<
":" << type.
getMax() <<
">";
720 llvm::seq<size_t>(0, quantiles.size()), out,
721 [&](
size_t index) { out << quantiles[index]; },
",");
725 out <<
", <" << *minVal <<
":" << *maxVal <<
">";
731 if (
auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
733 else if (
auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))
735 else if (
auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type))
737 else if (
auto perAxisType =
738 llvm::dyn_cast<UniformQuantizedSubChannelType>(type))
740 else if (
auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type))
742 else if (
auto quantileType = llvm::dyn_cast<QuantileType>(type))
745 llvm_unreachable(
"Unhandled quantized type");
static void printBlockSizeInfo(ArrayRef< std::pair< int32_t, int64_t > > blockSizeInfo, DialectAsmPrinter &out)
static ParseResult parseQuantParams(DialectAsmParser &parser, Type expressedType, double &scale, int64_t &zeroPoint)
Parses a quantization parameter, which is either a scale value (float) or a scale-zero point pair (fl...
static void printAnyQuantizedType(AnyQuantizedType type, DialectAsmPrinter &out)
Helper that prints an AnyQuantizedType.
static Type parseStorageType(DialectAsmParser &parser, bool &isSigned)
static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, double &min, double &max)
static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType, bool isSigned, int64_t &storageTypeMin, int64_t &storageTypeMax)
static void printUniformQuantizedSubChannelType(UniformQuantizedSubChannelType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedSubChannelType.
static Type parseUniformType(DialectAsmParser &parser)
Parses a UniformQuantizedType.
static void printQuantileType(QuantileType type, DialectAsmPrinter &out)
static Type parseAnyType(DialectAsmParser &parser)
Parses an AnyQuantizedType.
static void printStorageType(QuantizedType type, DialectAsmPrinter &out)
static ParseResult parseBlockSizeInfoUntilRBrace(DialectAsmParser &parser, SmallVectorImpl< int32_t > &quantizedDimensions, SmallVectorImpl< int64_t > &blockSizes)
Parses block size information for sub-channel quantization, assuming the leading '{' has already been...
static void printQuantParams(double scale, int64_t zeroPoint, DialectAsmPrinter &out)
static Type parseCalibratedType(DialectAsmParser &parser)
Parses a CalibratedQuantizedType.
static ParseResult parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType, SmallVectorImpl< double > &scales, SmallVectorImpl< int64_t > &zeroPoints, SmallVectorImpl< int64_t > &dims)
Parses a bracketed list of quantization parameters, returning the dimensions of the parsed sub-tensor...
static void printCalibratedQuantizedType(CalibratedQuantizedType type, DialectAsmPrinter &out)
Helper that prints a CalibratedQuantizedType.
static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedPerAxisType.
static void printUniformQuantizedType(UniformQuantizedType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedType.
static LogicalResult isScaleInExpressedTypeRange(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double scale)
Checks if the given scale value is within the valid range of the expressed type.
static Type parseQuantileType(DialectAsmParser &parser)
static void printDenseQuantizationParameters(ArrayRef< APFloat > scales, ArrayRef< APInt > zeroPoints, ArrayRef< int64_t > shape, DialectAsmPrinter &out)
Prints quantization parameters as a nested list of scale[:zero_point] elements.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseLBrace()=0
Parse a { token.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual OptionalParseResult parseOptionalType(Type &result)=0
Parse an optional type.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and return it.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseFloat(double &result)=0
Parse a floating point value from the stream.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
IntegerType getIntegerType(unsigned width)
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
This class implements Optional functionality for ParseResult.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
A quantized type that maps storage to/from expressed types in an unspecified way.
A quantized type that infers its range from given min/max values.
std::optional< int64_t > getStorageMin() const
Return the explicit storage minimum, if set.
Type getQuantileType() const
ArrayRef< double > getQuantiles() const
Return the quantile table of this float type.
std::optional< int64_t > getStorageMax() const
Return the explicit storage maximum, if set.
Type getStorageType() const
Base class for all quantized types known to this dialect.
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Type getStorageType() const
Gets the underlying type used to store values.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref