24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/Support/Endian.h"
55 case Token::kw_affine_map: {
59 if (
parseToken(Token::less,
"expected '<' in affine map") ||
61 parseToken(Token::greater,
"expected '>' in affine map"))
65 case Token::kw_affine_set: {
69 if (
parseToken(Token::less,
"expected '<' in integer set") ||
71 parseToken(Token::greater,
"expected '>' in integer set"))
77 case Token::l_square: {
103 case Token::kw_dense_resource:
107 case Token::kw_array:
111 case Token::l_brace: {
119 case Token::hash_identifier:
123 case Token::floatliteral:
131 if (
getToken().is(Token::floatliteral))
135 "expected constant integer or floating point value"),
140 case Token::kw_loc: {
144 if (
parseToken(Token::l_paren,
"expected '(' in inline location") ||
146 parseToken(Token::r_paren,
"expected ')' in inline location"))
152 case Token::kw_sparse:
156 case Token::kw_strided:
160 case Token::kw_distinct:
164 case Token::string: {
176 case Token::at_identifier: {
181 referenceLocations.push_back(
getToken().getLocRange());
188 std::vector<FlatSymbolRefAttr> nestedRefs;
189 while (
getToken().is(Token::colon)) {
194 if (
getToken().isNot(Token::eof, Token::error)) {
202 if (
getToken().isNot(Token::at_identifier)) {
203 emitError(curLoc,
"expected nested symbol reference identifier");
210 referenceLocations.push_back(
getToken().getLocRange());
216 SymbolRefAttr symbolRefAttr =
222 return symbolRefAttr;
231 case Token::code_complete:
232 if (
getToken().isCodeCompletionFor(Token::hash_identifier))
251 case Token::at_identifier:
252 case Token::floatliteral:
254 case Token::hash_identifier:
255 case Token::kw_affine_map:
256 case Token::kw_affine_set:
257 case Token::kw_dense:
258 case Token::kw_dense_resource:
259 case Token::kw_false:
261 case Token::kw_sparse:
265 case Token::l_square:
269 return success(attribute !=
nullptr);
300 llvm::SmallDenseSet<StringAttr> seenKeys;
303 std::optional<StringAttr> nameId;
306 else if (
getToken().
isAny(Token::bare_identifier, Token::inttype) ||
313 return emitError(
"expected valid attribute name");
315 if (!seenKeys.insert(*nameId).second)
317 << nameId->getValue() <<
"' in dictionary attribute";
321 auto splitName = nameId->strref().split(
'.');
322 if (!splitName.second.empty())
340 " in attribute dictionary");
347 return (
emitError(
"floating point value too large for attribute"),
nullptr);
356 if (!isa<FloatType>(type))
357 return (
emitError(
"floating point value not valid for specified type"),
365 StringRef spelling) {
368 bool isHex = spelling.size() > 1 && spelling[1] ==
'x';
369 if (spelling.getAsInteger(isHex ? 0 : 10, result))
373 unsigned width = type.
isIndex() ? IndexType::kInternalStorageBitWidth
376 if (width > result.getBitWidth()) {
377 result = result.zext(width);
378 }
else if (width < result.getBitWidth()) {
381 if (result.countl_zero() < result.getBitWidth() - width)
384 result = result.trunc(width);
392 }
else if (isNegative) {
396 if (!result.isSignBitSet())
399 result.isSignBitSet()) {
424 if (
auto floatType = dyn_cast<FloatType>(type)) {
425 std::optional<APFloat> result;
427 floatType.getFloatSemantics(),
428 floatType.getWidth())))
433 if (!isa<IntegerType, IndexType>(type))
434 return emitError(loc,
"integer literal not valid for specified type"),
439 "negative integer literal not valid for unsigned integer type");
445 return emitError(loc,
"integer constant out of range for attribute"),
457 std::string &result) {
459 result = std::move(*value);
463 tok.
getLoc(),
"expected string containing hex digits starting with `0x`");
470 class TensorLiteralParser {
472 TensorLiteralParser(
Parser &p) : p(p) {}
487 std::vector<APInt> &intValues);
491 std::vector<APFloat> &floatValues);
522 std::vector<std::pair<bool, Token>> storage;
525 std::optional<Token> hexStorage;
533 if (allowHex && p.getToken().is(Token::string)) {
534 hexStorage = p.getToken();
535 p.consumeToken(Token::string);
539 if (p.getToken().is(Token::l_square))
540 return parseList(shape);
541 return parseElement();
547 Type eltType = type.getElementType();
552 return getHexAttr(loc, type);
556 if (!shape.empty() &&
getShape() != type.getShape()) {
557 p.emitError(loc) <<
"inferred shape of elements literal ([" <<
getShape()
558 <<
"]) does not match type ([" << type.getShape() <<
"])";
563 if (!hexStorage && storage.empty() && type.getNumElements()) {
564 p.emitError(loc) <<
"parsed zero elements, but type (" << type
565 <<
") expected at least 1";
570 bool isComplex =
false;
571 if (ComplexType complexTy = dyn_cast<ComplexType>(eltType)) {
572 eltType = complexTy.getElementType();
578 std::vector<APInt> intValues;
579 if (
failed(getIntAttrElements(loc, eltType, intValues)))
584 reinterpret_cast<std::complex<APInt> *
>(intValues.data()),
585 intValues.size() / 2);
591 if (
FloatType floatTy = dyn_cast<FloatType>(eltType)) {
592 std::vector<APFloat> floatValues;
593 if (
failed(getFloatAttrElements(loc, floatTy, floatValues)))
598 reinterpret_cast<std::complex<APFloat> *
>(floatValues.data()),
599 floatValues.size() / 2);
606 return getStringAttr(loc, type, type.getElementType());
611 TensorLiteralParser::getIntAttrElements(SMLoc loc,
Type eltTy,
612 std::vector<APInt> &intValues) {
613 intValues.reserve(storage.size());
615 for (
const auto &signAndToken : storage) {
616 bool isNegative = signAndToken.first;
617 const Token &token = signAndToken.second;
618 auto tokenLoc = token.
getLoc();
620 if (isNegative && isUintType) {
621 return p.emitError(tokenLoc)
622 <<
"expected unsigned integer elements, but parsed negative value";
626 if (token.
is(Token::floatliteral)) {
627 return p.emitError(tokenLoc)
628 <<
"expected integer elements, but parsed floating-point";
631 assert(token.
isAny(Token::integer, Token::kw_true, Token::kw_false) &&
632 "unexpected token type");
633 if (token.
isAny(Token::kw_true, Token::kw_false)) {
635 return p.emitError(tokenLoc)
636 <<
"expected i1 type for 'true' or 'false' values";
638 APInt apInt(1, token.
is(Token::kw_true),
false);
639 intValues.push_back(apInt);
644 std::optional<APInt> apInt =
647 return p.emitError(tokenLoc,
"integer constant out of range for type");
648 intValues.push_back(*apInt);
655 TensorLiteralParser::getFloatAttrElements(SMLoc loc,
FloatType eltTy,
656 std::vector<APFloat> &floatValues) {
657 floatValues.reserve(storage.size());
658 for (
const auto &signAndToken : storage) {
659 bool isNegative = signAndToken.first;
660 const Token &token = signAndToken.second;
663 if (token.
is(Token::integer) && token.
getSpelling().startswith(
"0x")) {
664 std::optional<APFloat> result;
665 if (
failed(p.parseFloatFromIntegerLiteral(result, token, isNegative,
670 floatValues.push_back(*result);
675 if (!token.
is(Token::floatliteral))
677 <<
"expected floating-point elements, but parsed integer";
682 return p.emitError(
"floating point value too large for attribute");
684 APFloat apVal(isNegative ? -*val : *val);
685 if (!eltTy.
isF64()) {
690 floatValues.push_back(apVal);
698 if (hexStorage.has_value()) {
699 auto stringValue = hexStorage->getStringValue();
703 std::vector<std::string> stringValues;
704 std::vector<StringRef> stringRefValues;
705 stringValues.reserve(storage.size());
706 stringRefValues.reserve(storage.size());
708 for (
auto val : storage) {
709 stringValues.push_back(val.second.getStringValue());
710 stringRefValues.emplace_back(stringValues.back());
718 Type elementType = type.getElementType();
721 <<
"expected floating-point, integer, or complex element type, got "
731 bool detectedSplat =
false;
733 p.emitError(loc) <<
"elements hex data size is invalid for provided type: "
738 if (llvm::endianness::native == llvm::endianness::big) {
745 DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
746 rawData, convRawData, type);
754 switch (p.getToken().getKind()) {
757 case Token::kw_false:
758 case Token::floatliteral:
760 storage.emplace_back(
false, p.getToken());
766 p.consumeToken(Token::minus);
767 if (!p.getToken().isAny(Token::floatliteral, Token::integer))
768 return p.emitError(
"expected integer or floating point literal");
769 storage.emplace_back(
true, p.getToken());
774 storage.emplace_back(
false, p.getToken());
780 p.consumeToken(Token::l_paren);
781 if (parseElement() ||
782 p.parseToken(Token::comma,
"expected ',' between complex elements") ||
784 p.parseToken(Token::r_paren,
"expected ')' after complex elements"))
789 return p.emitError(
"expected element literal of primitive type");
804 if (prevDims == newDims)
806 return p.emitError(
"tensor literal is invalid; ranks are not consistent "
815 if (p.getToken().getKind() == Token::l_square) {
816 if (parseList(thisDims))
818 }
else if (parseElement()) {
823 return checkDims(newDims, thisDims);
833 dims.push_back(size);
834 dims.append(newDims.begin(), newDims.end());
845 class DenseArrayElementParser {
847 explicit DenseArrayElementParser(
Type type) : type(type) {}
860 void append(
const APInt &data);
865 std::vector<char> rawData;
871 void DenseArrayElementParser::append(
const APInt &data) {
872 if (data.getBitWidth()) {
873 assert(data.getBitWidth() % 8 == 0);
874 unsigned byteSize = data.getBitWidth() / 8;
875 size_t offset = rawData.size();
876 rawData.insert(rawData.end(), byteSize, 0);
877 llvm::StoreIntToMemory(
878 data,
reinterpret_cast<uint8_t *
>(rawData.data() + offset), byteSize);
884 bool isNegative = p.
consumeIf(Token::minus);
887 std::optional<APInt> value;
890 if (!type.isInteger(1))
891 return p.
emitError(
"expected i1 type for 'true' or 'false' values");
892 value = APInt(8, p.
getToken().
is(Token::kw_true),
893 !type.isUnsignedInteger());
895 }
else if (p.
consumeIf(Token::integer)) {
898 return p.
emitError(
"integer constant out of range");
900 return p.
emitError(
"expected integer literal");
907 bool isNegative = p.
consumeIf(Token::minus);
910 std::optional<APFloat> result;
911 auto floatType = cast<FloatType>(type);
915 floatType.getFloatSemantics(),
916 floatType.getWidth()))
918 }
else if (p.
consumeIf(Token::floatliteral)) {
923 result = APFloat(isNegative ? -*val : *val);
926 result->convert(floatType.getFloatSemantics(),
927 APFloat::rmNearestTiesToEven, &unused);
930 return p.
emitError(
"expected integer or floating point literal");
933 append(result->bitcastToAPInt());
940 if (
parseToken(Token::less,
"expected '<' after 'array'"))
946 emitError(typeLoc,
"expected an integer or floating point type");
953 emitError(typeLoc,
"expected integer or float type, got: ") << eltType;
957 emitError(typeLoc,
"element type bitwidth must be a multiple of 8");
965 if (
parseToken(Token::colon,
"expected ':' after dense array type"))
968 DenseArrayElementParser eltParser(eltType);
971 [&] {
return eltParser.parseIntegerElement(*
this); }))
975 [&] {
return eltParser.parseFloatElement(*
this); }))
978 if (
parseToken(Token::greater,
"expected '>' to close an array attribute"))
980 return eltParser.getAttr();
987 if (
parseToken(Token::less,
"expected '<' after 'dense'"))
991 TensorLiteralParser literalParser(*
this);
993 if (literalParser.parse(
true) ||
1005 return literalParser.getAttr(loc, type);
1011 if (
parseToken(Token::less,
"expected '<' after 'dense_resource'"))
1020 auto *handle = dyn_cast<DenseResourceElementsHandle>(&*rawHandle);
1022 return emitError(loc,
"invalid `dense_resource` handle type"),
nullptr;
1025 SMLoc typeLoc = loc;
1032 ShapedType shapedType = dyn_cast<ShapedType>(attrType);
1034 emitError(typeLoc,
"`dense_resource` expected a shaped type");
1049 if (
parseToken(Token::colon,
"expected ':'"))
1055 auto sType = dyn_cast<ShapedType>(type);
1057 emitError(
"elements literal must be a shaped type");
1061 if (!sType.hasStaticShape())
1062 return (
emitError(
"elements literal type must have static shape"),
nullptr);
1071 if (
parseToken(Token::less,
"Expected '<' after 'sparse'"))
1085 ShapedType indicesType =
1088 return getChecked<SparseElementsAttr>(
1096 TensorLiteralParser indiceParser(*
this);
1097 if (indiceParser.parse(
false))
1100 if (
parseToken(Token::comma,
"expected ','"))
1105 TensorLiteralParser valuesParser(*
this);
1106 if (valuesParser.parse(
true))
1109 if (
parseToken(Token::greater,
"expected '>'"))
1121 ShapedType indicesType;
1122 if (indiceParser.getShape().empty()) {
1128 auto indices = indiceParser.getAttr(indicesLoc, indicesType);
1133 auto valuesEltType = type.getElementType();
1134 ShapedType valuesType =
1135 valuesParser.getShape().empty()
1138 auto values = valuesParser.getAttr(valuesLoc, valuesType);
1141 return getChecked<SparseElementsAttr>(loc, type, indices, values);
1147 auto errorEmitter = [&] {
return emitError(loc); };
1157 auto parseStrideOrOffset = [&]() -> std::optional<int64_t> {
1159 return ShapedType::kDynamic;
1163 emitError(loc,
"expected a 64-bit signed integer or '?'");
1164 return std::nullopt;
1167 bool negative =
consumeIf(Token::minus);
1169 if (
getToken().is(Token::integer)) {
1175 auto result =
static_cast<int64_t
>(*value);
1187 if (!
getToken().is(Token::r_square)) {
1189 std::optional<int64_t> stride = parseStrideOrOffset();
1192 strides.push_back(*stride);
1212 std::optional<int64_t> offset = parseStrideOrOffset();
1230 if (
parseToken(Token::l_square,
"expected '[' after 'distinct'"))
1235 if (
parseToken(Token::integer,
"expected distinct ID"))
1239 emitError(
"expected an unsigned 64-bit integer");
1244 if (
parseToken(Token::r_square,
"expected ']' to close distinct ID") ||
1245 parseToken(Token::less,
"expected '<' after distinct ID"))
1249 if (
getToken().is(Token::greater)) {
1254 if (!referencedAttr) {
1259 if (
parseToken(Token::greater,
"expected '>' to close distinct attribute"))
1268 auto it = distinctAttrs.find(*value);
1269 if (it == distinctAttrs.end()) {
1271 it = distinctAttrs.try_emplace(*value, distinctAttr).first;
1272 }
else if (it->getSecond().getReferencedAttr() != referencedAttr) {
1273 emitError(loc,
"referenced attribute does not match previous definition: ")
1274 << it->getSecond().getReferencedAttr();
1278 return it->getSecond();
static ParseResult parseElementAttrHexValues(Parser &parser, Token tok, std::string &result)
Parse elements values stored within a hex string.
static std::optional< APInt > buildAttributeAPInt(Type type, bool isNegative, StringRef spelling)
Construct an APint from a parsed value, a known attribute type and sign.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
void addUses(Value value, ArrayRef< SMLoc > locations)
Add a source uses of the given value.
@ Braces
{} brackets surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
Attributes are known-constant values of operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
An attribute that associates a referenced attribute with a unique identifier.
static DistinctAttr create(Attribute referencedAttr)
Creates a distinct attribute that associates a referenced attribute with a unique identifier.
This class provides support for representing a failure result, or a valid value of type T.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getWidth()
Return the bitwidth of this float type.
An integer set representing a conjunction of one or more affine equalities and inequalities.
void resetPointer(const char *newPointer)
Change the position of the lexer cursor.
Location objects represent source locations information in MLIR.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents success/failure for parsing-like operations that find it important to chain tog...
This represents a token in the MLIR syntax.
bool isKeyword() const
Return true if this is one of the keyword token kinds (e.g. kw_if).
std::string getStringValue() const
Given a token containing a string literal, return its value, including removing the quote characters ...
std::string getSymbolReference() const
Given a token containing a symbol reference, return the unescaped string value.
static std::optional< uint64_t > getUInt64IntegerValue(StringRef spelling)
For an integer token, return its value as an uint64_t.
std::optional< double > getFloatingPointValue() const
For a floatliteral token, return its value as a double.
bool isAny(Kind k1, Kind k2) const
StringRef getSpelling() const
std::optional< std::string > getHexStringValue() const
Given a token containing a hex string literal, return its value or std::nullopt if the token does not...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class implement support for parsing global entities like attributes and types.
Attribute parseDenseArrayAttr(Type type)
Parse a DenseArrayAttr.
Attribute parseStridedLayoutAttr()
Parse a strided layout attribute.
Attribute parseDecOrHexAttr(Type type, bool isNegative)
Parse a decimal or a hexadecimal literal, which can be either an integer or a float attribute.
OptionalParseResult parseOptionalType(Type &type)
Optionally parse a type.
ParseResult parseToken(Token::Kind expectedToken, const Twine &message)
Consume the specified token if present and return success.
ParseResult parseCommaSeparatedListUntil(Token::Kind rightToken, function_ref< ParseResult()> parseElement, bool allowEmptyList=true)
Parse a comma-separated list of elements up until the specified end token.
Type parseType()
Parse an arbitrary type.
Attribute parseDenseElementsAttr(Type attrType)
Parse a dense elements attribute.
Attribute parseDenseResourceElementsAttr(Type attrType)
Parse a dense resource elements attribute.
ParseResult parseAffineMapReference(AffineMap &map)
InFlightDiagnostic emitError(const Twine &message={})
Emit an error and return failure.
ParserState & state
The Parser is subclassed and reinstantiated.
Attribute parseAttribute(Type type={})
Parse an arbitrary attribute with an optional type.
StringRef getTokenSpelling() const
ParseResult parseLocationInstance(LocationAttr &loc)
Parse a raw location instance.
void consumeToken()
Advance the current lexer onto the next token.
Attribute codeCompleteAttribute()
ParseResult parseAttributeDict(NamedAttrList &attributes)
Parse an attribute dictionary.
ShapedType parseElementsLiteralType(Type type)
Shaped type for elements attribute.
MLIRContext * getContext() const
Attribute parseDistinctAttr(Type type)
Parse a distinct attribute.
InFlightDiagnostic emitWrongTokenError(const Twine &message={})
Emit an error about a "wrong token".
ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())
Parse a list of comma-separated items with an optional delimiter.
Attribute parseSparseElementsAttr(Type attrType)
Parse a sparse elements attribute.
OptionalParseResult parseOptionalAttribute(Attribute &attribute, Type type={})
Parse an optional attribute with the provided type.
ParseResult parseFloatFromIntegerLiteral(std::optional< APFloat > &result, const Token &tok, bool isNegative, const llvm::fltSemantics &semantics, size_t typeSizeInBits)
Parse a floating point value from an integer literal token.
Attribute parseFloatAttr(Type type, bool isNegative)
Parse a float attribute.
ParseResult parseIntegerSetReference(IntegerSet &set)
Attribute parseExtendedAttr(Type type)
Parse an extended attribute.
const Token & getToken() const
Return the current token the parser is inspecting.
FailureOr< AsmDialectResourceHandle > parseResourceHandle(const OpAsmDialectInterface *dialect, StringRef &name)
Parse a handle to a dialect resource within the assembly format.
bool consumeIf(Token::Kind kind)
If the current token has the specified kind, consume it and return true.
OptionalParseResult parseOptionalAttributeWithToken(Token::Kind kind, AttributeT &attr, Type type={})
Parse an optional attribute that is demarcated by a specific token.
Detect if any of the given parameter types has a sub-element handler.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
SymbolState & symbols
The current state for symbol parsing.
Lexer lex
The lexer for the source file we're parsing.
AsmParserState * asmState
An optional pointer to a struct containing high level parser state to be populated during parsing.
DenseMap< uint64_t, DistinctAttr > distinctAttributes
A map from unique integer identifier to DistinctAttr.