15#include "llvm/ADT/APFloat.h"
16#include "llvm/ADT/SmallVectorExtras.h"
27 unsigned storageTypeWidth = 0;
33 if (
auto quantStorageTypeInterface =
34 llvm::dyn_cast<QuantStorageTypeInterface>(type)) {
37 isSigned = quantStorageTypeInterface.shouldDefaultToSigned();
38 storageTypeWidth = quantStorageTypeInterface.getStorageWidth();
40 parser.
emitError(typeLoc,
"illegal storage type prefix");
43 }
else if (succeeded(parser.
parseKeyword(&identifier))) {
45 if (identifier.consume_front(
"u")) {
46 if (identifier.getAsInteger(10, storageTypeWidth)) {
47 parser.
emitError(typeLoc,
"expected storage type width");
53 parser.
emitError(typeLoc,
"illegal storage type prefix");
60 if (storageTypeWidth == 0 ||
61 storageTypeWidth > QuantizedType::MaxStorageBits) {
62 parser.
emitError(typeLoc,
"illegal storage type size: ")
71 bool isSigned,
int64_t &storageTypeMin,
73 auto quantStorageTypeInterface =
74 llvm::dyn_cast<QuantStorageTypeInterface>(storageType);
76 int64_t defaultMin = quantStorageTypeInterface.getDefaultMinimum(isSigned);
77 int64_t defaultMax = quantStorageTypeInterface.getDefaultMaximum(isSigned);
80 storageTypeMin = defaultMin;
81 storageTypeMax = defaultMax;
91 if (storageTypeMin < defaultMin) {
92 return parser.
emitError(minLoc,
"illegal storage type minimum: ")
95 if (storageTypeMax > defaultMax) {
96 return parser.
emitError(maxLoc,
"illegal storage type maximum: ")
103 double &
min,
double &
max) {
108 parser.
emitError(typeLoc,
"expecting float expressed type");
115 parser.
emitError(typeLoc,
"calibrated values must be present");
130 FloatType expressedType;
131 unsigned typeFlags = 0;
140 bool isSigned =
false;
167 typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
175 Type expressedType,
double scale) {
176 auto floatType = cast<FloatType>(expressedType);
178 APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
180 APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
181 if (scale < minScale || scale > maxScale)
182 return emitError() <<
"scale " << scale <<
" out of expressed type range ["
183 << minScale <<
", " << maxScale <<
"]";
192 Type expressedType,
double &scale,
201 expressedType, scale))) {
228 auto parseBlockSizeElements = [&]() -> ParseResult {
229 quantizedDimensions.resize(quantizedDimensions.size() + 1);
230 blockSizes.resize(blockSizes.size() + 1);
265 if (prevDims == newDims)
268 <<
"tensor literal is invalid; ranks are not consistent "
276 auto parseOneElement = [&]() -> ParseResult {
280 zeroPoints, thisDims))
283 zeroPoints.resize(zeroPoints.size() + 1);
284 scales.resize(scales.size() + 1);
286 zeroPoints.back())) {
292 return checkDims(newDims, thisDims);
304 dims.push_back(size);
305 dims.append(newDims.begin(), newDims.end());
339 FloatType expressedType;
340 unsigned typeFlags = 0;
343 bool isPerAxis =
false;
344 bool isSubChannel =
false;
356 bool isSigned =
false;
387 quantizedDimensions.resize(1);
400 bool isPerTensor = !isPerAxis && !isSubChannel;
403 zeroPoints.resize(zeroPoints.size() + 1);
404 scales.resize(scales.size() + 1);
406 zeroPoints.back())) {
424 typeFlags, storageType, expressedType, scales, zeroPoints,
425 quantizedDimensions[0], storageTypeMin, storageTypeMax);
429 llvm::map_to_vector(scales, [&](
double scale) -> APFloat {
430 APFloat apFloatScale(scale);
432 apFloatScale.convert(expressedType.getFloatSemantics(),
433 APFloat::rmNearestTiesToEven, &unused);
437 llvm::map_to_vector(zeroPoints, [&](
int64_t zeroPoint) -> APInt {
441 RankedTensorType::get(dims, expressedType), apFloatScales);
443 RankedTensorType::get(dims, storageType), apIntZeroPoints);
445 typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
446 quantizedDimensions, blockSizes, storageTypeMin, storageTypeMax);
450 typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
451 storageTypeMin, storageTypeMax);
461 FloatType expressedType;
471 if (!expressedType) {
485 StringRef typeNameSpelling;
489 if (typeNameSpelling ==
"uniform")
491 if (typeNameSpelling ==
"any")
493 if (typeNameSpelling ==
"calibrated")
497 "unknown quantized type " + typeNameSpelling);
503 auto quantStorageTypeInterface =
506 out << quantStorageTypeInterface.getStorageTypeName(type.
isSigned());
518 if (zeroPoint != 0) {
519 out <<
":" << zeroPoint;
527 llvm::interleaveComma(
528 llvm::seq<size_t>(0, blockSizeInfo.size()), out, [&](
size_t index) {
529 out << blockSizeInfo[index].first <<
":" << blockSizeInfo[index].second;
539 if (
Type expressedType = type.getExpressedType()) {
540 out <<
":" << expressedType;
550 out <<
":" << type.getExpressedType() <<
", ";
562 out <<
":" << type.getExpressedType() <<
":";
571 llvm::seq<size_t>(0, scales.size()), out,
573 printQuantParams(scales[index], zeroPoints[index], out);
598 unsigned openBrackets = 0;
600 auto incrementCounterAndDelimit = [&]() {
602 for (
unsigned i = rank - 1; i > 0; --i) {
603 if (counter[i] >=
shape[i]) {
612 for (
unsigned idx = 0, e = scales.size(); idx < e; ++idx) {
615 while (openBrackets++ < rank)
619 if (zeroPoints[idx] != 0) {
620 out <<
":" << zeroPoints[idx];
622 incrementCounterAndDelimit();
624 while (openBrackets-- > 0)
634 out <<
":" << type.getExpressedType() <<
":";
650 out <<
"calibrated<" << type.getExpressedType();
651 out <<
"<" << type.
getMin() <<
":" << type.
getMax() <<
">";
657 if (
auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
659 else if (
auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))
661 else if (
auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type))
663 else if (
auto perAxisType =
664 llvm::dyn_cast<UniformQuantizedSubChannelType>(type))
666 else if (
auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type))
669 llvm_unreachable(
"Unhandled quantized type");
static void printBlockSizeInfo(ArrayRef< std::pair< int32_t, int64_t > > blockSizeInfo, DialectAsmPrinter &out)
static ParseResult parseQuantParams(DialectAsmParser &parser, Type expressedType, double &scale, int64_t &zeroPoint)
Parses a quantization parameter, which is either a scale value (float) or a scale-zero point pair (fl...
static void printAnyQuantizedType(AnyQuantizedType type, DialectAsmPrinter &out)
Helper that prints an AnyQuantizedType.
static Type parseStorageType(DialectAsmParser &parser, bool &isSigned)
static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, double &min, double &max)
static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType, bool isSigned, int64_t &storageTypeMin, int64_t &storageTypeMax)
static void printUniformQuantizedSubChannelType(UniformQuantizedSubChannelType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedSubChannelType.
static Type parseUniformType(DialectAsmParser &parser)
Parses a UniformQuantizedType.
static Type parseAnyType(DialectAsmParser &parser)
Parses an AnyQuantizedType.
static void printStorageType(QuantizedType type, DialectAsmPrinter &out)
static ParseResult parseBlockSizeInfoUntilRBrace(DialectAsmParser &parser, SmallVectorImpl< int32_t > &quantizedDimensions, SmallVectorImpl< int64_t > &blockSizes)
Parses block size information for sub-channel quantization, assuming the leading '{' has already been...
static void printQuantParams(double scale, int64_t zeroPoint, DialectAsmPrinter &out)
static Type parseCalibratedType(DialectAsmParser &parser)
Parses a CalibratedQuantizedType.
static ParseResult parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType, SmallVectorImpl< double > &scales, SmallVectorImpl< int64_t > &zeroPoints, SmallVectorImpl< int64_t > &dims)
Parses a bracketed list of quantization parameters, returning the dimensions of the parsed sub-tensor...
static void printCalibratedQuantizedType(CalibratedQuantizedType type, DialectAsmPrinter &out)
Helper that prints a CalibratedQuantizedType.
static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedPerAxisType.
static void printUniformQuantizedType(UniformQuantizedType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedType.
static LogicalResult isScaleInExpressedTypeRange(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double scale)
Checks if the given scale value is within the valid range of the expressed type.
static void printDenseQuantizationParameters(ArrayRef< APFloat > scales, ArrayRef< APInt > zeroPoints, ArrayRef< int64_t > shape, DialectAsmPrinter &out)
Prints quantization parameters as a nested list of scale[:zero_point] elements.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseLBrace()=0
Parse a { token.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual OptionalParseResult parseOptionalType(Type &result)=0
Parse an optional type.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and return it.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseFloat(double &result)=0
Parse a floating point value from the stream.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
IntegerType getIntegerType(unsigned width)
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a diagnostic that is inflight and set to be reported.
This class implements Optional functionality for ParseResult.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
A quantized type that maps storage to/from expressed types in an unspecified way.
A quantized type that infers its range from given min/max values.
Base class for all quantized types known to this dialect.
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Type getStorageType() const
Gets the underlying type used to store values.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref