23 #include "llvm/ADT/STLExtras.h"
37 case Token::kw_memref:
38 case Token::kw_tensor:
39 case Token::kw_complex:
41 case Token::kw_vector:
43 case Token::kw_f8E5M2:
44 case Token::kw_f8E4M3FN:
45 case Token::kw_f8E5M2FNUZ:
46 case Token::kw_f8E4M3FNUZ:
47 case Token::kw_f8E4M3B11FNUZ:
57 case Token::exclamation_identifier:
88 elements.push_back(t);
100 elements.push_back(elt);
113 if (
parseToken(Token::l_paren,
"expected '('"))
134 if (
parseToken(Token::less,
"expected '<' in complex type"))
140 parseToken(Token::greater,
"expected '>' in complex type"))
142 if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
143 return emitError(elementTypeLoc,
"invalid element type for complex"),
154 assert(
getToken().is(Token::l_paren));
158 parseToken(Token::arrow,
"expected '->' in function type") ||
183 if (
parseToken(Token::less,
"expected '<' in memref type"))
209 return emitError(typeLoc,
"invalid memref element type"),
nullptr;
211 MemRefLayoutAttrInterface layout;
220 if (isa<MemRefLayoutAttrInterface>(attr)) {
221 layout = cast<MemRefLayoutAttrInterface>(attr);
222 }
else if (memorySpace) {
223 return emitError(
"multiple memory spaces specified in memref type");
230 return emitError(
"cannot have affine map for unranked memref type");
232 return emitError(
"expected memory space to be last in memref type");
240 if (
parseToken(Token::comma,
"expected ',' or '>' in memref type") ||
248 return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
250 return getChecked<MemRefType>(loc, dimensions, elementType, layout,
275 case Token::kw_memref:
277 case Token::kw_tensor:
279 case Token::kw_complex:
281 case Token::kw_tuple:
283 case Token::kw_vector:
286 case Token::inttype: {
288 if (!width.has_value())
289 return (
emitError(
"invalid integer width"),
nullptr);
290 if (*width > IntegerType::kMaxWidth) {
292 << IntegerType::kMaxWidth <<
" bits";
296 IntegerType::SignednessSemantics signSemantics = IntegerType::Signless;
297 if (std::optional<bool> signedness =
getToken().getIntTypeSignedness())
305 case Token::kw_f8E5M2:
308 case Token::kw_f8E4M3FN:
311 case Token::kw_f8E5M2FNUZ:
314 case Token::kw_f8E4M3FNUZ:
317 case Token::kw_f8E4M3B11FNUZ:
343 case Token::kw_index:
353 case Token::exclamation_identifier:
357 case Token::code_complete:
358 if (
getToken().isCodeCompletionFor(Token::exclamation_identifier))
372 if (
parseToken(Token::less,
"expected '<' in tensor type"))
399 if (parseResult.has_value()) {
400 if (
failed(parseResult.value()))
402 if (
auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
403 if (
failed(v.verifyEncoding(dimensions, elementType,
404 [&] { return emitError(); })))
410 if (!elementType ||
parseToken(Token::greater,
"expected '>' in tensor type"))
413 return emitError(elementTypeLoc,
"invalid tensor element type"),
nullptr;
417 return emitError(
"cannot apply encoding to unranked tensor"),
nullptr;
431 if (
parseToken(Token::less,
"expected '<' in tuple type"))
441 parseToken(Token::greater,
"expected '>' in tuple type"))
456 if (
parseToken(Token::less,
"expected '<' in vector type"))
463 if (any_of(dimensions, [](int64_t i) {
return i <= 0; }))
465 "vector types must have positive constant sizes"),
471 if (!elementType ||
parseToken(Token::greater,
"expected '>' in vector type"))
474 if (!VectorType::isValidElementType(elementType))
475 return emitError(typeLoc,
"vector elements must be int/index/float type"),
496 bool scalable =
consumeIf(Token::l_square);
499 dimensions.push_back(value);
504 scalableDims.push_back(scalable);
528 bool allowDynamic,
bool withTrailingX) {
533 return emitError(loc,
"expected static shape");
534 dimensions.push_back(ShapedType::kDynamic);
539 dimensions.push_back(value);
545 while (
getToken().isAny(Token::integer, Token::question)) {
552 if (
getToken().isAny(Token::integer, Token::question)) {
555 while (
getToken().is(Token::bare_identifier) &&
583 value = (int64_t)*dimension;
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
FloatType getFloat8E5M2Type()
FloatType getFloat8E4M3B11FNUZType()
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
FloatType getFloat8E4M3FNType()
FloatType getFloat8E4M3FNUZType()
FloatType getFloat8E5M2FNUZType()
void resetPointer(const char *newPointer)
Change the position of the lexer cursor.
This class implements Optional functionality for ParseResult.
This class represents success/failure for parsing-like operations that find it important to chain together failable operations with `||`.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
static std::optional< uint64_t > getUInt64IntegerValue(StringRef spelling)
For an integer token, return its value as a uint64_t.
std::optional< unsigned > getIntTypeBitwidth() const
For an inttype token, return its bitwidth.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
ParseResult parseXInDimensionList()
Parse an 'x' token in a dimension list, handling the case where the x is juxtaposed with an element type, as in "xf32", leaving the "f32" as the next token in the stream.
OptionalParseResult parseOptionalType(Type &type)
Optionally parse a type.
ParseResult parseToken(Token::Kind expectedToken, const Twine &message)
Consume the specified token if present and return success.
ParseResult parseCommaSeparatedListUntil(Token::Kind rightToken, function_ref< ParseResult()> parseElement, bool allowEmptyList=true)
Parse a comma-separated list of elements up until the specified end token.
Type parseType()
Parse an arbitrary type.
ParseResult parseTypeListParens(SmallVectorImpl< Type > &elements)
Parse a parenthesized list of types.
ParseResult parseVectorDimensionList(SmallVectorImpl< int64_t > &dimensions, SmallVectorImpl< bool > &scalableDims)
Parse a dimension list in a vector type.
Type parseMemRefType()
Parse a memref type.
Type parseNonFunctionType()
Parse a non function type.
Type parseExtendedType()
Parse an extended type.
Type parseTupleType()
Parse a tuple type.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error and return failure.
ParserState & state
The Parser is subclassed and reinstantiated.
Attribute parseAttribute(Type type={})
Parse an arbitrary attribute with an optional type.
StringRef getTokenSpelling() const
void consumeToken()
Advance the current lexer onto the next token.
ParseResult parseIntegerInDimensionList(int64_t &value)
Type parseComplexType()
Parse a complex type.
ParseResult parseDimensionListRanked(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)
Parse a dimension list of a tensor or memref type.
ParseResult parseFunctionResultTypes(SmallVectorImpl< Type > &elements)
Parse a function result type.
MLIRContext * getContext() const
InFlightDiagnostic emitWrongTokenError(const Twine &message={})
Emit an error about a "wrong token".
ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())
Parse a list of comma-separated items with an optional delimiter.
VectorType parseVectorType()
Parse a vector type.
Type parseFunctionType()
Parse a function type.
OptionalParseResult parseOptionalAttribute(Attribute &attribute, Type type={})
Parse an optional attribute with the provided type.
ParseResult parseTypeListNoParens(SmallVectorImpl< Type > &elements)
Parse a list of types without an enclosing parenthesis.
const Token & getToken() const
Return the current token the parser is inspecting.
Type parseTensorType()
Parse a tensor type.
bool consumeIf(Token::Kind kind)
If the current token has the specified kind, consume it and return true.
Detect if any of the given parameter types has a sub-element handler.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Lexer lex
The lexer for the source file we're parsing.