14#include "llvm/ADT/APFloat.h"
15#include "llvm/ADT/SmallVectorExtras.h"
26 unsigned storageTypeWidth = 0;
31 isSigned = !type.isUnsigned();
32 storageTypeWidth = type.getWidth();
33 }
else if (succeeded(parser.
parseKeyword(&identifier))) {
35 if (!identifier.consume_front(
"u")) {
36 parser.
emitError(typeLoc,
"illegal storage type prefix");
39 if (identifier.getAsInteger(10, storageTypeWidth)) {
40 parser.
emitError(typeLoc,
"expected storage type width");
49 if (storageTypeWidth == 0 ||
50 storageTypeWidth > QuantizedType::MaxStorageBits) {
51 parser.
emitError(typeLoc,
"illegal storage type size: ")
60 IntegerType storageType,
bool isSigned,
63 int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
64 isSigned, storageType.getWidth());
65 int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
66 isSigned, storageType.getWidth());
68 storageTypeMin = defaultIntegerMin;
69 storageTypeMax = defaultIntegerMax;
79 if (storageTypeMin < defaultIntegerMin) {
80 return parser.
emitError(minLoc,
"illegal storage type minimum: ")
83 if (storageTypeMax > defaultIntegerMax) {
84 return parser.
emitError(maxLoc,
"illegal storage type maximum: ")
91 double &
min,
double &
max) {
96 parser.
emitError(typeLoc,
"expecting float expressed type");
103 parser.
emitError(typeLoc,
"calibrated values must be present");
117 IntegerType storageType;
118 FloatType expressedType;
119 unsigned typeFlags = 0;
128 bool isSigned =
false;
155 typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
163 Type expressedType,
double scale) {
164 auto floatType = cast<FloatType>(expressedType);
166 APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
168 APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
169 if (scale < minScale || scale > maxScale)
170 return emitError() <<
"scale " << scale <<
" out of expressed type range ["
171 << minScale <<
", " << maxScale <<
"]";
180 Type expressedType,
double &scale,
189 expressedType, scale))) {
216 auto parseBlockSizeElements = [&]() -> ParseResult {
217 quantizedDimensions.resize(quantizedDimensions.size() + 1);
218 blockSizes.resize(blockSizes.size() + 1);
253 if (prevDims == newDims)
256 <<
"tensor literal is invalid; ranks are not consistent "
264 auto parseOneElement = [&]() -> ParseResult {
268 zeroPoints, thisDims))
271 zeroPoints.resize(zeroPoints.size() + 1);
272 scales.resize(scales.size() + 1);
274 zeroPoints.back())) {
280 return checkDims(newDims, thisDims);
292 dims.push_back(size);
293 dims.append(newDims.begin(), newDims.end());
326 IntegerType storageType;
327 FloatType expressedType;
328 unsigned typeFlags = 0;
331 bool isPerAxis =
false;
332 bool isSubChannel =
false;
344 bool isSigned =
false;
375 quantizedDimensions.resize(1);
388 bool isPerTensor = !isPerAxis && !isSubChannel;
391 zeroPoints.resize(zeroPoints.size() + 1);
392 scales.resize(scales.size() + 1);
394 zeroPoints.back())) {
412 typeFlags, storageType, expressedType, scales, zeroPoints,
413 quantizedDimensions[0], storageTypeMin, storageTypeMax);
417 llvm::map_to_vector(scales, [&](
double scale) -> APFloat {
418 APFloat apFloatScale(scale);
420 apFloatScale.convert(expressedType.getFloatSemantics(),
421 APFloat::rmNearestTiesToEven, &unused);
425 llvm::map_to_vector(zeroPoints, [&](
int64_t zeroPoint) -> APInt {
426 return APInt(storageType.getIntOrFloatBitWidth(), zeroPoint);
429 RankedTensorType::get(dims, expressedType), apFloatScales);
431 RankedTensorType::get(dims, storageType), apIntZeroPoints);
433 typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
434 quantizedDimensions, blockSizes, storageTypeMin, storageTypeMax);
438 typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
439 storageTypeMin, storageTypeMax);
449 FloatType expressedType;
459 if (!expressedType) {
473 StringRef typeNameSpelling;
477 if (typeNameSpelling ==
"uniform")
479 if (typeNameSpelling ==
"any")
481 if (typeNameSpelling ==
"calibrated")
485 "unknown quantized type " + typeNameSpelling);
494 out <<
"i" << storageWidth;
496 out <<
"u" << storageWidth;
509 if (zeroPoint != 0) {
510 out <<
":" << zeroPoint;
518 llvm::interleaveComma(
519 llvm::seq<size_t>(0, blockSizeInfo.size()), out, [&](
size_t index) {
520 out << blockSizeInfo[index].first <<
":" << blockSizeInfo[index].second;
530 if (
Type expressedType = type.getExpressedType()) {
531 out <<
":" << expressedType;
541 out <<
":" << type.getExpressedType() <<
", ";
553 out <<
":" << type.getExpressedType() <<
":";
562 llvm::seq<size_t>(0, scales.size()), out,
564 printQuantParams(scales[index], zeroPoints[index], out);
589 unsigned openBrackets = 0;
591 auto incrementCounterAndDelimit = [&]() {
593 for (
unsigned i = rank - 1; i > 0; --i) {
594 if (counter[i] >=
shape[i]) {
603 for (
unsigned idx = 0, e = scales.size(); idx < e; ++idx) {
606 while (openBrackets++ < rank)
610 if (zeroPoints[idx] != 0) {
611 out <<
":" << zeroPoints[idx];
613 incrementCounterAndDelimit();
615 while (openBrackets-- > 0)
625 out <<
":" << type.getExpressedType() <<
":";
641 out <<
"calibrated<" << type.getExpressedType();
642 out <<
"<" << type.
getMin() <<
":" << type.
getMax() <<
">";
648 if (
auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
650 else if (
auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))
652 else if (
auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type))
654 else if (
auto perAxisType =
655 llvm::dyn_cast<UniformQuantizedSubChannelType>(type))
657 else if (
auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type))
660 llvm_unreachable(
"Unhandled quantized type");
static void printBlockSizeInfo(ArrayRef< std::pair< int32_t, int64_t > > blockSizeInfo, DialectAsmPrinter &out)
static ParseResult parseQuantParams(DialectAsmParser &parser, Type expressedType, double &scale, int64_t &zeroPoint)
Parses a quantization parameter, which is either a scale value (float) or a scale-zero point pair (fl...
static void printAnyQuantizedType(AnyQuantizedType type, DialectAsmPrinter &out)
Helper that prints an AnyQuantizedType.
static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, double &min, double &max)
static void printUniformQuantizedSubChannelType(UniformQuantizedSubChannelType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedSubChannelType.
static Type parseUniformType(DialectAsmParser &parser)
Parses a UniformQuantizedType.
static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned)
static Type parseAnyType(DialectAsmParser &parser)
Parses an AnyQuantizedType.
static void printStorageType(QuantizedType type, DialectAsmPrinter &out)
static ParseResult parseBlockSizeInfoUntilRBrace(DialectAsmParser &parser, SmallVectorImpl< int32_t > &quantizedDimensions, SmallVectorImpl< int64_t > &blockSizes)
Parses block size information for sub-channel quantization, assuming the leading '{' has already been...
static void printQuantParams(double scale, int64_t zeroPoint, DialectAsmPrinter &out)
static Type parseCalibratedType(DialectAsmParser &parser)
Parses a CalibratedQuantizedType.
static ParseResult parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType, SmallVectorImpl< double > &scales, SmallVectorImpl< int64_t > &zeroPoints, SmallVectorImpl< int64_t > &dims)
Parses a bracketed list of quantization parameters, returning the dimensions of the parsed sub-tensor...
static ParseResult parseStorageRange(DialectAsmParser &parser, IntegerType storageType, bool isSigned, int64_t &storageTypeMin, int64_t &storageTypeMax)
static void printCalibratedQuantizedType(CalibratedQuantizedType type, DialectAsmPrinter &out)
Helper that prints a CalibratedQuantizedType.
static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedPerAxisType.
static void printUniformQuantizedType(UniformQuantizedType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedType.
static LogicalResult isScaleInExpressedTypeRange(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double scale)
Checks if the given scale value is within the valid range of the expressed type.
static void printDenseQuantizationParameters(ArrayRef< APFloat > scales, ArrayRef< APInt > zeroPoints, ArrayRef< int64_t > shape, DialectAsmPrinter &out)
Prints quantization parameters as a nested list of scale[:zero_point] elements.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseLBrace()=0
Parse a { token.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual OptionalParseResult parseOptionalType(Type &result)=0
Parse an optional type.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and return it.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseFloat(double &result)=0
Parse a floating point value from the stream.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
IntegerType getIntegerType(unsigned width)
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a diagnostic that is inflight and set to be reported.
This class implements Optional functionality for ParseResult.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
A quantized type that maps storage to/from expressed types in an unspecified way.
A quantized type that infers its range from given min/max values.
Base class for all quantized types known to this dialect.
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref