24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/Support/Endian.h"
55 case Token::kw_affine_map: {
59 if (
parseToken(Token::less,
"expected '<' in affine map") ||
61 parseToken(Token::greater,
"expected '>' in affine map"))
65 case Token::kw_affine_set: {
69 if (
parseToken(Token::less,
"expected '<' in integer set") ||
71 parseToken(Token::greater,
"expected '>' in integer set"))
77 case Token::l_square: {
80 auto parseElt = [&]() -> ParseResult {
82 return elements.back() ? success() : failure();
103 case Token::kw_dense_resource:
107 case Token::kw_array:
111 case Token::l_brace: {
119 case Token::hash_identifier:
123 case Token::floatliteral:
131 if (
getToken().is(Token::floatliteral))
135 "expected constant integer or floating point value"),
140 case Token::kw_loc: {
144 if (
parseToken(Token::l_paren,
"expected '(' in inline location") ||
146 parseToken(Token::r_paren,
"expected ')' in inline location"))
152 case Token::kw_sparse:
156 case Token::kw_strided:
160 case Token::kw_distinct:
164 case Token::string: {
176 case Token::at_identifier: {
181 referenceLocations.push_back(
getToken().getLocRange());
188 std::vector<FlatSymbolRefAttr> nestedRefs;
189 while (
getToken().is(Token::colon)) {
194 if (
getToken().isNot(Token::eof, Token::error)) {
202 if (
getToken().isNot(Token::at_identifier)) {
203 emitError(curLoc,
"expected nested symbol reference identifier");
210 referenceLocations.push_back(
getToken().getLocRange());
216 SymbolRefAttr symbolRefAttr =
222 return symbolRefAttr;
231 case Token::code_complete:
232 if (
getToken().isCodeCompletionFor(Token::hash_identifier))
251 case Token::at_identifier:
252 case Token::floatliteral:
254 case Token::hash_identifier:
255 case Token::kw_affine_map:
256 case Token::kw_affine_set:
257 case Token::kw_dense:
258 case Token::kw_dense_resource:
259 case Token::kw_false:
261 case Token::kw_sparse:
265 case Token::l_square:
269 return success(attribute !=
nullptr);
275 if (result.
has_value() && succeeded(*result))
300 llvm::SmallDenseSet<StringAttr> seenKeys;
301 auto parseElt = [&]() -> ParseResult {
303 std::optional<StringAttr> nameId;
306 else if (
getToken().
isAny(Token::bare_identifier, Token::inttype) ||
313 return emitError(
"expected valid attribute name");
315 if (!seenKeys.insert(*nameId).second)
317 << nameId->getValue() <<
"' in dictionary attribute";
321 auto splitName = nameId->strref().split(
'.');
322 if (!splitName.second.empty())
340 " in attribute dictionary");
347 return (
emitError(
"floating point value too large for attribute"),
nullptr);
356 if (!isa<FloatType>(type))
357 return (
emitError(
"floating point value not valid for specified type"),
365 StringRef spelling) {
368 bool isHex = spelling.size() > 1 && spelling[1] ==
'x';
369 if (spelling.getAsInteger(isHex ? 0 : 10, result))
373 unsigned width = type.
isIndex() ? IndexType::kInternalStorageBitWidth
376 if (width > result.getBitWidth()) {
377 result = result.zext(width);
378 }
else if (width < result.getBitWidth()) {
381 if (result.countl_zero() < result.getBitWidth() - width)
384 result = result.trunc(width);
392 }
else if (isNegative) {
396 if (!result.isSignBitSet())
399 result.isSignBitSet()) {
424 if (
auto floatType = dyn_cast<FloatType>(type)) {
425 std::optional<APFloat> result;
427 floatType.getFloatSemantics())))
432 if (!isa<IntegerType, IndexType>(type))
433 return emitError(loc,
"integer literal not valid for specified type"),
438 "negative integer literal not valid for unsigned integer type");
444 return emitError(loc,
"integer constant out of range for attribute"),
456 std::string &result) {
458 result = std::move(*value);
462 tok.
getLoc(),
"expected string containing hex digits starting with `0x`");
469 class TensorLiteralParser {
471 TensorLiteralParser(
Parser &p) : p(p) {}
475 ParseResult
parse(
bool allowHex);
485 ParseResult getIntAttrElements(SMLoc loc,
Type eltTy,
486 std::vector<APInt> &intValues);
489 ParseResult getFloatAttrElements(SMLoc loc,
FloatType eltTy,
490 std::vector<APFloat> &floatValues);
502 ParseResult parseElement();
513 ParseResult parseHexElements();
521 std::vector<std::pair<bool, Token>> storage;
524 std::optional<Token> hexStorage;
532 if (allowHex && p.getToken().is(Token::string)) {
533 hexStorage = p.getToken();
534 p.consumeToken(Token::string);
538 if (p.getToken().is(Token::l_square))
539 return parseList(shape);
540 return parseElement();
546 Type eltType = type.getElementType();
551 return getHexAttr(loc, type);
555 if (!shape.empty() &&
getShape() != type.getShape()) {
556 p.emitError(loc) <<
"inferred shape of elements literal ([" <<
getShape()
557 <<
"]) does not match type ([" << type.getShape() <<
"])";
562 if (!hexStorage && storage.empty() && type.getNumElements()) {
563 p.emitError(loc) <<
"parsed zero elements, but type (" << type
564 <<
") expected at least 1";
569 bool isComplex =
false;
570 if (ComplexType complexTy = dyn_cast<ComplexType>(eltType)) {
571 eltType = complexTy.getElementType();
577 std::vector<APInt> intValues;
578 if (failed(getIntAttrElements(loc, eltType, intValues)))
583 reinterpret_cast<std::complex<APInt> *
>(intValues.data()),
584 intValues.size() / 2);
590 if (
FloatType floatTy = dyn_cast<FloatType>(eltType)) {
591 std::vector<APFloat> floatValues;
592 if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
597 reinterpret_cast<std::complex<APFloat> *
>(floatValues.data()),
598 floatValues.size() / 2);
605 return getStringAttr(loc, type, type.getElementType());
610 TensorLiteralParser::getIntAttrElements(SMLoc loc,
Type eltTy,
611 std::vector<APInt> &intValues) {
612 intValues.reserve(storage.size());
614 for (
const auto &signAndToken : storage) {
615 bool isNegative = signAndToken.first;
616 const Token &token = signAndToken.second;
617 auto tokenLoc = token.
getLoc();
619 if (isNegative && isUintType) {
620 return p.emitError(tokenLoc)
621 <<
"expected unsigned integer elements, but parsed negative value";
625 if (token.
is(Token::floatliteral)) {
626 return p.emitError(tokenLoc)
627 <<
"expected integer elements, but parsed floating-point";
630 assert(token.
isAny(Token::integer, Token::kw_true, Token::kw_false) &&
631 "unexpected token type");
632 if (token.
isAny(Token::kw_true, Token::kw_false)) {
634 return p.emitError(tokenLoc)
635 <<
"expected i1 type for 'true' or 'false' values";
637 APInt apInt(1, token.
is(Token::kw_true),
false);
638 intValues.push_back(apInt);
643 std::optional<APInt> apInt =
646 return p.emitError(tokenLoc,
"integer constant out of range for type");
647 intValues.push_back(*apInt);
654 TensorLiteralParser::getFloatAttrElements(SMLoc loc,
FloatType eltTy,
655 std::vector<APFloat> &floatValues) {
656 floatValues.reserve(storage.size());
657 for (
const auto &signAndToken : storage) {
658 bool isNegative = signAndToken.first;
659 const Token &token = signAndToken.second;
660 std::optional<APFloat> result;
661 if (failed(p.parseFloatFromLiteral(result, token, isNegative,
664 floatValues.push_back(*result);
672 if (hexStorage.has_value()) {
673 auto stringValue = hexStorage->getStringValue();
677 std::vector<std::string> stringValues;
678 std::vector<StringRef> stringRefValues;
679 stringValues.reserve(storage.size());
680 stringRefValues.reserve(storage.size());
682 for (
auto val : storage) {
683 stringValues.push_back(val.second.getStringValue());
684 stringRefValues.emplace_back(stringValues.back());
692 Type elementType = type.getElementType();
695 <<
"expected floating-point, integer, or complex element type, got "
705 bool detectedSplat =
false;
707 p.emitError(loc) <<
"elements hex data size is invalid for provided type: "
712 if (llvm::endianness::native == llvm::endianness::big) {
719 DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
720 rawData, convRawData, type);
727 ParseResult TensorLiteralParser::parseElement() {
728 switch (p.getToken().getKind()) {
731 case Token::kw_false:
732 case Token::floatliteral:
734 storage.emplace_back(
false, p.getToken());
740 p.consumeToken(Token::minus);
741 if (!p.getToken().isAny(Token::floatliteral, Token::integer))
742 return p.emitError(
"expected integer or floating point literal");
743 storage.emplace_back(
true, p.getToken());
748 storage.emplace_back(
false, p.getToken());
754 p.consumeToken(Token::l_paren);
755 if (parseElement() ||
756 p.parseToken(Token::comma,
"expected ',' between complex elements") ||
758 p.parseToken(Token::r_paren,
"expected ')' after complex elements"))
763 return p.emitError(
"expected element literal of primitive type");
778 if (prevDims == newDims)
780 return p.emitError(
"tensor literal is invalid; ranks are not consistent "
787 auto parseOneElement = [&]() -> ParseResult {
789 if (p.getToken().getKind() == Token::l_square) {
790 if (parseList(thisDims))
792 }
else if (parseElement()) {
797 return checkDims(newDims, thisDims);
807 dims.push_back(size);
808 dims.append(newDims.begin(), newDims.end());
819 class DenseArrayElementParser {
821 explicit DenseArrayElementParser(
Type type) : type(type) {}
824 ParseResult parseIntegerElement(
Parser &p);
827 ParseResult parseFloatElement(
Parser &p);
834 void append(
const APInt &data);
839 std::vector<char> rawData;
845 void DenseArrayElementParser::append(
const APInt &data) {
846 if (data.getBitWidth()) {
847 assert(data.getBitWidth() % 8 == 0);
848 unsigned byteSize = data.getBitWidth() / 8;
849 size_t offset = rawData.size();
850 rawData.insert(rawData.end(), byteSize, 0);
851 llvm::StoreIntToMemory(
852 data,
reinterpret_cast<uint8_t *
>(rawData.data() + offset), byteSize);
857 ParseResult DenseArrayElementParser::parseIntegerElement(
Parser &p) {
858 bool isNegative = p.
consumeIf(Token::minus);
861 std::optional<APInt> value;
864 if (!type.isInteger(1))
865 return p.
emitError(
"expected i1 type for 'true' or 'false' values");
866 value = APInt(8, p.
getToken().
is(Token::kw_true),
867 !type.isUnsignedInteger());
869 }
else if (p.
consumeIf(Token::integer)) {
872 return p.
emitError(
"integer constant out of range");
874 return p.
emitError(
"expected integer literal");
880 ParseResult DenseArrayElementParser::parseFloatElement(
Parser &p) {
881 bool isNegative = p.
consumeIf(Token::minus);
883 std::optional<APFloat> fromIntLit;
886 cast<FloatType>(type).getFloatSemantics())))
889 append(fromIntLit->bitcastToAPInt());
896 if (
parseToken(Token::less,
"expected '<' after 'array'"))
902 emitError(typeLoc,
"expected an integer or floating point type");
909 emitError(typeLoc,
"expected integer or float type, got: ") << eltType;
913 emitError(typeLoc,
"element type bitwidth must be a multiple of 8");
921 if (
parseToken(Token::colon,
"expected ':' after dense array type"))
924 DenseArrayElementParser eltParser(eltType);
927 [&] {
return eltParser.parseIntegerElement(*
this); }))
931 [&] {
return eltParser.parseFloatElement(*
this); }))
934 if (
parseToken(Token::greater,
"expected '>' to close an array attribute"))
936 return eltParser.getAttr();
943 if (
parseToken(Token::less,
"expected '<' after 'dense'"))
947 TensorLiteralParser literalParser(*
this);
949 if (literalParser.parse(
true) ||
961 return literalParser.getAttr(loc, type);
967 if (
parseToken(Token::less,
"expected '<' after 'dense_resource'"))
971 FailureOr<AsmDialectResourceHandle> rawHandle =
973 if (failed(rawHandle) ||
parseToken(Token::greater,
"expected '>'"))
976 auto *handle = dyn_cast<DenseResourceElementsHandle>(&*rawHandle);
978 return emitError(loc,
"invalid `dense_resource` handle type"),
nullptr;
988 ShapedType shapedType = dyn_cast<ShapedType>(attrType);
990 emitError(typeLoc,
"`dense_resource` expected a shaped type");
1005 if (
parseToken(Token::colon,
"expected ':'"))
1011 auto sType = dyn_cast<ShapedType>(type);
1013 emitError(
"elements literal must be a shaped type");
1017 if (!sType.hasStaticShape())
1018 return (
emitError(
"elements literal type must have static shape"),
nullptr);
1027 if (
parseToken(Token::less,
"Expected '<' after 'sparse'"))
1041 ShapedType indicesType =
1044 return getChecked<SparseElementsAttr>(
1052 TensorLiteralParser indiceParser(*
this);
1053 if (indiceParser.parse(
false))
1056 if (
parseToken(Token::comma,
"expected ','"))
1061 TensorLiteralParser valuesParser(*
this);
1062 if (valuesParser.parse(
true))
1065 if (
parseToken(Token::greater,
"expected '>'"))
1077 ShapedType indicesType;
1078 if (indiceParser.getShape().empty()) {
1084 auto indices = indiceParser.getAttr(indicesLoc, indicesType);
1089 auto valuesEltType = type.getElementType();
1090 ShapedType valuesType =
1091 valuesParser.getShape().empty()
1094 auto values = valuesParser.getAttr(valuesLoc, valuesType);
1097 return getChecked<SparseElementsAttr>(loc, type, indices, values);
1103 auto errorEmitter = [&] {
return emitError(loc); };
1106 if (failed(
parseToken(Token::less,
"expected '<' after 'strided'")) ||
1107 failed(
parseToken(Token::l_square,
"expected '['")))
1113 auto parseStrideOrOffset = [&]() -> std::optional<int64_t> {
1115 return ShapedType::kDynamic;
1119 emitError(loc,
"expected a 64-bit signed integer or '?'");
1120 return std::nullopt;
1123 bool negative =
consumeIf(Token::minus);
1125 if (
getToken().is(Token::integer)) {
1131 auto result =
static_cast<int64_t
>(*value);
1143 if (!
getToken().is(Token::r_square)) {
1145 std::optional<int64_t> stride = parseStrideOrOffset();
1148 strides.push_back(*stride);
1152 if (failed(
parseToken(Token::r_square,
"expected ']'")))
1163 if (failed(
parseToken(Token::comma,
"expected ','")) ||
1164 failed(
parseToken(Token::kw_offset,
"expected 'offset' after comma")) ||
1165 failed(
parseToken(Token::colon,
"expected ':' after 'offset'")))
1168 std::optional<int64_t> offset = parseStrideOrOffset();
1169 if (!offset || failed(
parseToken(Token::greater,
"expected '>'")))
1186 if (
parseToken(Token::l_square,
"expected '[' after 'distinct'"))
1191 if (
parseToken(Token::integer,
"expected distinct ID"))
1195 emitError(
"expected an unsigned 64-bit integer");
1200 if (
parseToken(Token::r_square,
"expected ']' to close distinct ID") ||
1201 parseToken(Token::less,
"expected '<' after distinct ID"))
1205 if (
getToken().is(Token::greater)) {
1210 if (!referencedAttr) {
1215 if (
parseToken(Token::greater,
"expected '>' to close distinct attribute"))
1224 auto it = distinctAttrs.find(*value);
1225 if (it == distinctAttrs.end()) {
1227 it = distinctAttrs.try_emplace(*value, distinctAttr).first;
1228 }
else if (it->getSecond().getReferencedAttr() != referencedAttr) {
1229 emitError(loc,
"referenced attribute does not match previous definition: ")
1230 << it->getSecond().getReferencedAttr();
1234 return it->getSecond();
static ParseResult parseElementAttrHexValues(Parser &parser, Token tok, std::string &result)
Parse elements values stored within a hex string.
static std::optional< APInt > buildAttributeAPInt(Type type, bool isNegative, StringRef spelling)
Construct an APint from a parsed value, a known attribute type and sign.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
void addUses(Value value, ArrayRef< SMLoc > locations)
Add a source uses of the given value.
@ Braces
{} brackets surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
Attributes are known-constant values of operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
An attribute that associates a referenced attribute with a unique identifier.
static DistinctAttr create(Attribute referencedAttr)
Creates a distinct attribute that associates a referenced attribute with a unique identifier.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
An integer set representing a conjunction of one or more affine equalities and inequalities.
void resetPointer(const char *newPointer)
Change the position of the lexer cursor.
Location objects represent source locations information in MLIR.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This represents a token in the MLIR syntax.
bool isKeyword() const
Return true if this is one of the keyword token kinds (e.g. kw_if).
std::string getStringValue() const
Given a token containing a string literal, return its value, including removing the quote characters ...
std::string getSymbolReference() const
Given a token containing a symbol reference, return the unescaped string value.
static std::optional< uint64_t > getUInt64IntegerValue(StringRef spelling)
For an integer token, return its value as an uint64_t.
std::optional< double > getFloatingPointValue() const
For a floatliteral token, return its value as a double.
bool isAny(Kind k1, Kind k2) const
StringRef getSpelling() const
std::optional< std::string > getHexStringValue() const
Given a token containing a hex string literal, return its value or std::nullopt if the token does not...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class implement support for parsing global entities like attributes and types.
ParseResult parseFloatFromLiteral(std::optional< APFloat > &result, const Token &tok, bool isNegative, const llvm::fltSemantics &semantics)
Parse a floating point value from a literal.
Attribute parseDenseArrayAttr(Type type)
Parse a DenseArrayAttr.
Attribute parseStridedLayoutAttr()
Parse a strided layout attribute.
Attribute parseDecOrHexAttr(Type type, bool isNegative)
Parse a decimal or a hexadecimal literal, which can be either an integer or a float attribute.
OptionalParseResult parseOptionalType(Type &type)
Optionally parse a type.
ParseResult parseToken(Token::Kind expectedToken, const Twine &message)
Consume the specified token if present and return success.
ParseResult parseCommaSeparatedListUntil(Token::Kind rightToken, function_ref< ParseResult()> parseElement, bool allowEmptyList=true)
Parse a comma-separated list of elements up until the specified end token.
Type parseType()
Parse an arbitrary type.
Attribute parseDenseElementsAttr(Type attrType)
Parse a dense elements attribute.
Attribute parseDenseResourceElementsAttr(Type attrType)
Parse a dense resource elements attribute.
ParseResult parseAffineMapReference(AffineMap &map)
InFlightDiagnostic emitError(const Twine &message={})
Emit an error and return failure.
ParserState & state
The Parser is subclassed and reinstantiated.
Attribute parseAttribute(Type type={})
Parse an arbitrary attribute with an optional type.
StringRef getTokenSpelling() const
ParseResult parseLocationInstance(LocationAttr &loc)
Parse a raw location instance.
void consumeToken()
Advance the current lexer onto the next token.
Attribute codeCompleteAttribute()
ParseResult parseAttributeDict(NamedAttrList &attributes)
Parse an attribute dictionary.
ShapedType parseElementsLiteralType(Type type)
Shaped type for elements attribute.
MLIRContext * getContext() const
Attribute parseDistinctAttr(Type type)
Parse a distinct attribute.
InFlightDiagnostic emitWrongTokenError(const Twine &message={})
Emit an error about a "wrong token".
ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())
Parse a list of comma-separated items with an optional delimiter.
Attribute parseSparseElementsAttr(Type attrType)
Parse a sparse elements attribute.
OptionalParseResult parseOptionalAttribute(Attribute &attribute, Type type={})
Parse an optional attribute with the provided type.
Attribute parseFloatAttr(Type type, bool isNegative)
Parse a float attribute.
ParseResult parseFloatFromIntegerLiteral(std::optional< APFloat > &result, const Token &tok, bool isNegative, const llvm::fltSemantics &semantics)
Parse a floating point value from an integer literal token.
ParseResult parseIntegerSetReference(IntegerSet &set)
Attribute parseExtendedAttr(Type type)
Parse an extended attribute.
const Token & getToken() const
Return the current token the parser is inspecting.
FailureOr< AsmDialectResourceHandle > parseResourceHandle(const OpAsmDialectInterface *dialect, StringRef &name)
Parse a handle to a dialect resource within the assembly format.
bool consumeIf(Token::Kind kind)
If the current token has the specified kind, consume it and return true.
OptionalParseResult parseOptionalAttributeWithToken(Token::Kind kind, AttributeT &attr, Type type={})
Parse an optional attribute that is demarcated by a specific token.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SymbolState & symbols
The current state for symbol parsing.
Lexer lex
The lexer for the source file we're parsing.
AsmParserState * asmState
An optional pointer to a struct containing high level parser state to be populated during parsing.
DenseMap< uint64_t, DistinctAttr > distinctAttributes
A map from unique integer identifier to DistinctAttr.