15 #include "llvm/ADT/APFloat.h"
16 #include "llvm/Support/Format.h"
17 #include "llvm/Support/MathExtras.h"
18 #include "llvm/Support/SourceMgr.h"
19 #include "llvm/Support/raw_ostream.h"
22 using namespace quant;
30 unsigned storageTypeWidth = 0;
33 if (!succeeded(*result))
35 isSigned = !type.isUnsigned();
36 storageTypeWidth = type.getWidth();
37 }
else if (succeeded(parser.
parseKeyword(&identifier))) {
39 if (!identifier.consume_front(
"u")) {
40 parser.
emitError(typeLoc,
"illegal storage type prefix");
43 if (identifier.getAsInteger(10, storageTypeWidth)) {
44 parser.
emitError(typeLoc,
"expected storage type width");
53 if (storageTypeWidth == 0 ||
55 parser.
emitError(typeLoc,
"illegal storage type size: ")
64 IntegerType storageType,
bool isSigned,
65 int64_t &storageTypeMin,
66 int64_t &storageTypeMax) {
68 isSigned, storageType.getWidth());
70 isSigned, storageType.getWidth());
72 storageTypeMin = defaultIntegerMin;
73 storageTypeMax = defaultIntegerMax;
83 if (storageTypeMin < defaultIntegerMin) {
84 return parser.
emitError(minLoc,
"illegal storage type minimum: ")
87 if (storageTypeMax > defaultIntegerMax) {
88 return parser.
emitError(maxLoc,
"illegal storage type maximum: ")
95 double &
min,
double &
max) {
100 parser.
emitError(typeLoc,
"expecting float expressed type");
107 parser.
emitError(typeLoc,
"calibrated values must be present");
121 IntegerType storageType;
122 FloatType expressedType;
123 unsigned typeFlags = 0;
124 int64_t storageTypeMin;
125 int64_t storageTypeMax;
132 bool isSigned =
false;
159 typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
167 Type expressedType,
double scale) {
168 auto floatType = cast<FloatType>(expressedType);
170 APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble();
172 APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
173 if (scale < minScale || scale > maxScale)
174 return emitError() <<
"scale " << scale <<
" out of expressed type range ["
175 << minScale <<
", " << maxScale <<
"]";
184 Type expressedType,
double &scale,
185 int64_t &zeroPoint) {
193 expressedType, scale))) {
220 auto parseBlockSizeElements = [&]() -> ParseResult {
221 quantizedDimensions.resize(quantizedDimensions.size() + 1);
222 blockSizes.resize(blockSizes.size() + 1);
257 if (prevDims == newDims)
260 <<
"tensor literal is invalid; ranks are not consistent "
268 auto parseOneElement = [&]() -> ParseResult {
272 zeroPoints, thisDims))
275 zeroPoints.resize(zeroPoints.size() + 1);
276 scales.resize(scales.size() + 1);
278 zeroPoints.back())) {
284 return checkDims(newDims, thisDims);
296 dims.push_back(size);
297 dims.append(newDims.begin(), newDims.end());
330 IntegerType storageType;
331 FloatType expressedType;
332 unsigned typeFlags = 0;
333 int64_t storageTypeMin;
334 int64_t storageTypeMax;
335 bool isPerAxis =
false;
336 bool isSubChannel =
false;
348 bool isSigned =
false;
379 quantizedDimensions.resize(1);
392 bool isPerTensor = !isPerAxis && !isSubChannel;
395 zeroPoints.resize(zeroPoints.size() + 1);
396 scales.resize(scales.size() + 1);
398 zeroPoints.back())) {
416 typeFlags, storageType, expressedType, scales, zeroPoints,
417 quantizedDimensions[0], storageTypeMin, storageTypeMax);
418 }
else if (isSubChannel) {
420 llvm::to_vector(llvm::map_range(scales, [&](
double scale) -> APFloat {
421 APFloat apFloatScale(scale);
423 apFloatScale.convert(expressedType.getFloatSemantics(),
424 APFloat::rmNearestTiesToEven, &unused);
428 llvm::map_range(zeroPoints, [&](int64_t zeroPoint) -> APInt {
429 return APInt(storageType.getIntOrFloatBitWidth(), zeroPoint);
436 typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
437 quantizedDimensions, blockSizes, storageTypeMin, storageTypeMax);
441 typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
442 storageTypeMin, storageTypeMax);
452 FloatType expressedType;
462 if (!expressedType) {
476 StringRef typeNameSpelling;
480 if (typeNameSpelling ==
"uniform")
482 if (typeNameSpelling ==
"any")
484 if (typeNameSpelling ==
"calibrated")
488 "unknown quantized type " + typeNameSpelling);
497 out <<
"i" << storageWidth;
499 out <<
"u" << storageWidth;
512 if (zeroPoint != 0) {
513 out <<
":" << zeroPoint;
521 llvm::interleaveComma(
522 llvm::seq<size_t>(0, blockSizeInfo.size()), out, [&](
size_t index) {
523 out << blockSizeInfo[index].first <<
":" << blockSizeInfo[index].second;
533 if (
Type expressedType = type.getExpressedType()) {
534 out <<
":" << expressedType;
544 out <<
":" << type.getExpressedType() <<
", ";
556 out <<
":" << type.getExpressedType() <<
":";
565 llvm::seq<size_t>(0, scales.size()), out,
567 printQuantParams(scales[index], zeroPoints[index], out);
590 int64_t rank = shape.size();
592 unsigned openBrackets = 0;
594 auto incrementCounterAndDelimit = [&]() {
596 for (
unsigned i = rank - 1; i > 0; --i) {
597 if (counter[i] >= shape[i]) {
606 for (
unsigned idx = 0, e = scales.size(); idx < e; ++idx) {
609 while (openBrackets++ < rank)
613 if (zeroPoints[idx] != 0) {
614 out <<
":" << zeroPoints[idx];
616 incrementCounterAndDelimit();
618 while (openBrackets-- > 0)
628 out <<
":" << type.getExpressedType() <<
":";
644 out <<
"calibrated<" << type.getExpressedType();
645 out <<
"<" << type.
getMin() <<
":" << type.
getMax() <<
">";
651 if (
auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
653 else if (
auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))
655 else if (
auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type))
657 else if (
auto perAxisType =
658 llvm::dyn_cast<UniformQuantizedSubChannelType>(type))
660 else if (
auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type))
663 llvm_unreachable(
"Unhandled quantized type");
static void printBlockSizeInfo(ArrayRef< std::pair< int32_t, int64_t >> blockSizeInfo, DialectAsmPrinter &out)
static ParseResult parseQuantParams(DialectAsmParser &parser, Type expressedType, double &scale, int64_t &zeroPoint)
Parses a quantization parameter, which is either a scale value (float) or a scale-zero point pair (fl...
void printDenseQuantizationParameters(ArrayRef< APFloat > scales, ArrayRef< APInt > zeroPoints, ArrayRef< int64_t > shape, DialectAsmPrinter &out)
Prints quantization parameters as a nested list of scale[:zero_point] elements.
static void printAnyQuantizedType(AnyQuantizedType type, DialectAsmPrinter &out)
Helper that prints an AnyQuantizedType.
LogicalResult isScaleInExpressedTypeRange(function_ref< InFlightDiagnostic()> emitError, Type expressedType, double scale)
Checks if the given scale value is within the valid range of the expressed type.
static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, double &min, double &max)
static void printUniformQuantizedSubChannelType(UniformQuantizedSubChannelType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedSubChannelType.
static Type parseUniformType(DialectAsmParser &parser)
Parses a UniformQuantizedType.
static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned)
static Type parseAnyType(DialectAsmParser &parser)
Parses an AnyQuantizedType.
static void printStorageType(QuantizedType type, DialectAsmPrinter &out)
static ParseResult parseBlockSizeInfoUntilRBrace(DialectAsmParser &parser, SmallVectorImpl< int32_t > &quantizedDimensions, SmallVectorImpl< int64_t > &blockSizes)
Parses block size information for sub-channel quantization, assuming the leading '{' has already been...
static void printQuantParams(double scale, int64_t zeroPoint, DialectAsmPrinter &out)
static Type parseCalibratedType(DialectAsmParser &parser)
Parses a CalibratedQuantizedType.
static ParseResult parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType, SmallVectorImpl< double > &scales, SmallVectorImpl< int64_t > &zeroPoints, SmallVectorImpl< int64_t > &dims)
Parses a bracketed list of quantization parameters, returning the dimensions of the parsed sub-tensor...
static ParseResult parseStorageRange(DialectAsmParser &parser, IntegerType storageType, bool isSigned, int64_t &storageTypeMin, int64_t &storageTypeMax)
static void printCalibratedQuantizedType(CalibratedQuantizedType type, DialectAsmPrinter &out)
Helper that prints a CalibratedQuantizedType.
static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedPerAxisType.
static void printUniformQuantizedType(UniformQuantizedType type, DialectAsmPrinter &out)
Helper that prints a UniformQuantizedType.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseLBrace()=0
Parse a { token.
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual OptionalParseResult parseOptionalType(Type &result)=0
Parse an optional type.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Return the location of the next token.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseFloat(double &result)=0
Parse a floating point value from the stream.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
IntegerType getIntegerType(unsigned width)
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a diagnostic that is inflight and set to be reported.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
A quantized type that maps storage to/from expressed types in an unspecified way.
A quantized type that infers its range from given min/max values.
Base class for all quantized types known to this dialect.
static constexpr unsigned MaxStorageBits
The maximum number of bits supported for storage types.
bool hasStorageTypeBounds() const
Return whether the storage type has explicit min or max boundaries different from the minimum and max...
bool isSigned() const
Whether the storage type should be interpreted as a signed quantity (true) or an unsigned value (fals...
int64_t getStorageTypeMax() const
The maximum value that storageType can take.
static int64_t getDefaultMaximumForInteger(bool isSigned, unsigned integralWidth)
Gets the maximum possible value stored by a storageType.
unsigned getStorageTypeIntegralWidth() const
Gets the integral bit width that the underlying storage type can exactly represent.
int64_t getStorageTypeMin() const
The minimum value that storageType can take.
static int64_t getDefaultMinimumForInteger(bool isSigned, unsigned integralWidth)
Gets the minimum possible value stored by a storageType.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.