22 #include "llvm/ADT/STLExtras.h"
36 case Token::kw_memref:
37 case Token::kw_tensor:
38 case Token::kw_complex:
40 case Token::kw_vector:
42 case Token::kw_f8E5M2:
43 case Token::kw_f8E4M3:
44 case Token::kw_f8E4M3FN:
45 case Token::kw_f8E5M2FNUZ:
46 case Token::kw_f8E4M3FNUZ:
47 case Token::kw_f8E4M3B11FNUZ:
57 case Token::exclamation_identifier:
88 elements.push_back(t);
98 auto parseElt = [&]() -> ParseResult {
100 elements.push_back(elt);
101 return elt ? success() : failure();
113 if (
parseToken(Token::l_paren,
"expected '('"))
134 if (
parseToken(Token::less,
"expected '<' in complex type"))
140 parseToken(Token::greater,
"expected '>' in complex type"))
142 if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
143 return emitError(elementTypeLoc,
"invalid element type for complex"),
154 assert(
getToken().is(Token::l_paren));
158 parseToken(Token::arrow,
"expected '->' in function type") ||
183 if (
parseToken(Token::less,
"expected '<' in memref type"))
209 return emitError(typeLoc,
"invalid memref element type"),
nullptr;
211 MemRefLayoutAttrInterface layout;
214 auto parseElt = [&]() -> ParseResult {
220 if (isa<MemRefLayoutAttrInterface>(attr)) {
221 layout = cast<MemRefLayoutAttrInterface>(attr);
222 }
else if (memorySpace) {
223 return emitError(
"multiple memory spaces specified in memref type");
230 return emitError(
"cannot have affine map for unranked memref type");
232 return emitError(
"expected memory space to be last in memref type");
240 if (
parseToken(Token::comma,
"expected ',' or '>' in memref type") ||
248 return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
250 return getChecked<MemRefType>(loc, dimensions, elementType, layout,
275 case Token::kw_memref:
277 case Token::kw_tensor:
279 case Token::kw_complex:
281 case Token::kw_tuple:
283 case Token::kw_vector:
286 case Token::inttype: {
288 if (!width.has_value())
289 return (
emitError(
"invalid integer width"),
nullptr);
290 if (*width > IntegerType::kMaxWidth) {
292 << IntegerType::kMaxWidth <<
" bits";
296 IntegerType::SignednessSemantics signSemantics = IntegerType::Signless;
297 if (std::optional<bool> signedness =
getToken().getIntTypeSignedness())
305 case Token::kw_f8E5M2:
308 case Token::kw_f8E4M3:
311 case Token::kw_f8E4M3FN:
314 case Token::kw_f8E5M2FNUZ:
317 case Token::kw_f8E4M3FNUZ:
320 case Token::kw_f8E4M3B11FNUZ:
346 case Token::kw_index:
356 case Token::exclamation_identifier:
360 case Token::code_complete:
361 if (
getToken().isCodeCompletionFor(Token::exclamation_identifier))
375 if (
parseToken(Token::less,
"expected '<' in tensor type"))
402 if (parseResult.has_value()) {
403 if (failed(parseResult.value()))
405 if (
auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
406 if (failed(v.verifyEncoding(dimensions, elementType,
407 [&] { return emitError(); })))
413 if (!elementType ||
parseToken(Token::greater,
"expected '>' in tensor type"))
416 return emitError(elementTypeLoc,
"invalid tensor element type"),
nullptr;
420 return emitError(
"cannot apply encoding to unranked tensor"),
nullptr;
434 if (
parseToken(Token::less,
"expected '<' in tuple type"))
444 parseToken(Token::greater,
"expected '>' in tuple type"))
459 if (
parseToken(Token::less,
"expected '<' in vector type"))
466 if (any_of(dimensions, [](int64_t i) {
return i <= 0; }))
468 "vector types must have positive constant sizes"),
474 if (!elementType ||
parseToken(Token::greater,
"expected '>' in vector type"))
477 if (!VectorType::isValidElementType(elementType))
478 return emitError(typeLoc,
"vector elements must be int/index/float type"),
499 bool scalable =
consumeIf(Token::l_square);
502 dimensions.push_back(value);
507 scalableDims.push_back(scalable);
531 bool allowDynamic,
bool withTrailingX) {
532 auto parseDim = [&]() -> LogicalResult {
536 return emitError(loc,
"expected static shape");
537 dimensions.push_back(ShapedType::kDynamic);
542 dimensions.push_back(value);
548 while (
getToken().isAny(Token::integer, Token::question)) {
555 if (
getToken().isAny(Token::integer, Token::question)) {
556 if (failed(parseDim()))
558 while (
getToken().is(Token::bare_identifier) &&
586 value = (int64_t)*dimension;
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
FloatType getFloat8E5M2Type()
FloatType getFloat8E4M3B11FNUZType()
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
FloatType getFloat8E4M3Type()
FloatType getFloat8E4M3FNType()
FloatType getFloat8E4M3FNUZType()
FloatType getFloat8E5M2FNUZType()
void resetPointer(const char *newPointer)
Change the position of the lexer cursor.
This class implements Optional functionality for ParseResult.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
static std::optional< uint64_t > getUInt64IntegerValue(StringRef spelling)
For an integer token, return its value as an uint64_t.
std::optional< unsigned > getIntTypeBitwidth() const
For an inttype token, return its bitwidth.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
ParseResult parseXInDimensionList()
Parse an 'x' token in a dimension list, handling the case where the x is juxtaposed with an element t...
OptionalParseResult parseOptionalType(Type &type)
Optionally parse a type.
ParseResult parseToken(Token::Kind expectedToken, const Twine &message)
Consume the specified token if present and return success.
ParseResult parseCommaSeparatedListUntil(Token::Kind rightToken, function_ref< ParseResult()> parseElement, bool allowEmptyList=true)
Parse a comma-separated list of elements up until the specified end token.
Type parseType()
Parse an arbitrary type.
ParseResult parseTypeListParens(SmallVectorImpl< Type > &elements)
Parse a parenthesized list of types.
ParseResult parseVectorDimensionList(SmallVectorImpl< int64_t > &dimensions, SmallVectorImpl< bool > &scalableDims)
Parse a dimension list in a vector type.
Type parseMemRefType()
Parse a memref type.
Type parseNonFunctionType()
Parse a non function type.
Type parseExtendedType()
Parse an extended type.
Type parseTupleType()
Parse a tuple type.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error and return failure.
ParserState & state
The Parser is subclassed and reinstantiated.
Attribute parseAttribute(Type type={})
Parse an arbitrary attribute with an optional type.
StringRef getTokenSpelling() const
void consumeToken()
Advance the current lexer onto the next token.
ParseResult parseIntegerInDimensionList(int64_t &value)
Type parseComplexType()
Parse a complex type.
ParseResult parseDimensionListRanked(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)
Parse a dimension list of a tensor or memref type.
ParseResult parseFunctionResultTypes(SmallVectorImpl< Type > &elements)
Parse a function result type.
MLIRContext * getContext() const
InFlightDiagnostic emitWrongTokenError(const Twine &message={})
Emit an error about a "wrong token".
ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())
Parse a list of comma-separated items with an optional delimiter.
VectorType parseVectorType()
Parse a vector type.
Type parseFunctionType()
Parse a function type.
OptionalParseResult parseOptionalAttribute(Attribute &attribute, Type type={})
Parse an optional attribute with the provided type.
ParseResult parseTypeListNoParens(SmallVectorImpl< Type > &elements)
Parse a list of types without an enclosing parenthesis.
const Token & getToken() const
Return the current token the parser is inspecting.
Type parseTensorType()
Parse a tensor type.
bool consumeIf(Token::Kind kind)
If the current token has the specified kind, consume it and return true.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Lexer lex
The lexer for the source file we're parsing.