22 #include "llvm/ADT/STLExtras.h"
36 case Token::kw_memref:
37 case Token::kw_tensor:
38 case Token::kw_complex:
40 case Token::kw_vector:
42 case Token::kw_f4E2M1FN:
43 case Token::kw_f6E2M3FN:
44 case Token::kw_f6E3M2FN:
45 case Token::kw_f8E5M2:
46 case Token::kw_f8E4M3:
47 case Token::kw_f8E4M3FN:
48 case Token::kw_f8E5M2FNUZ:
49 case Token::kw_f8E4M3FNUZ:
50 case Token::kw_f8E4M3B11FNUZ:
51 case Token::kw_f8E3M4:
52 case Token::kw_f8E8M0FNU:
62 case Token::exclamation_identifier:
93 elements.push_back(t);
103 auto parseElt = [&]() -> ParseResult {
105 elements.push_back(elt);
106 return elt ? success() : failure();
118 if (
parseToken(Token::l_paren,
"expected '('"))
139 if (
parseToken(Token::less,
"expected '<' in complex type"))
145 parseToken(Token::greater,
"expected '>' in complex type"))
147 if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
148 return emitError(elementTypeLoc,
"invalid element type for complex"),
159 assert(
getToken().is(Token::l_paren));
163 parseToken(Token::arrow,
"expected '->' in function type") ||
188 if (
parseToken(Token::less,
"expected '<' in memref type"))
214 return emitError(typeLoc,
"invalid memref element type"),
nullptr;
216 MemRefLayoutAttrInterface layout;
219 auto parseElt = [&]() -> ParseResult {
225 if (isa<MemRefLayoutAttrInterface>(attr)) {
226 layout = cast<MemRefLayoutAttrInterface>(attr);
227 }
else if (memorySpace) {
228 return emitError(
"multiple memory spaces specified in memref type");
235 return emitError(
"cannot have affine map for unranked memref type");
237 return emitError(
"expected memory space to be last in memref type");
245 if (
parseToken(Token::comma,
"expected ',' or '>' in memref type") ||
253 return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
255 return getChecked<MemRefType>(loc, dimensions, elementType, layout,
280 case Token::kw_memref:
282 case Token::kw_tensor:
284 case Token::kw_complex:
286 case Token::kw_tuple:
288 case Token::kw_vector:
291 case Token::inttype: {
293 if (!width.has_value())
294 return (
emitError(
"invalid integer width"),
nullptr);
295 if (*width > IntegerType::kMaxWidth) {
297 << IntegerType::kMaxWidth <<
" bits";
301 IntegerType::SignednessSemantics signSemantics = IntegerType::Signless;
302 if (std::optional<bool> signedness =
getToken().getIntTypeSignedness())
310 case Token::kw_f4E2M1FN:
313 case Token::kw_f6E2M3FN:
316 case Token::kw_f6E3M2FN:
319 case Token::kw_f8E5M2:
322 case Token::kw_f8E4M3:
325 case Token::kw_f8E4M3FN:
328 case Token::kw_f8E5M2FNUZ:
331 case Token::kw_f8E4M3FNUZ:
334 case Token::kw_f8E4M3B11FNUZ:
337 case Token::kw_f8E3M4:
340 case Token::kw_f8E8M0FNU:
366 case Token::kw_index:
376 case Token::exclamation_identifier:
380 case Token::code_complete:
381 if (
getToken().isCodeCompletionFor(Token::exclamation_identifier))
395 if (
parseToken(Token::less,
"expected '<' in tensor type"))
422 if (parseResult.has_value()) {
423 if (failed(parseResult.value()))
425 if (
auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
426 if (failed(v.verifyEncoding(dimensions, elementType,
427 [&] { return emitError(); })))
433 if (!elementType ||
parseToken(Token::greater,
"expected '>' in tensor type"))
436 return emitError(elementTypeLoc,
"invalid tensor element type"),
nullptr;
440 return emitError(
"cannot apply encoding to unranked tensor"),
nullptr;
454 if (
parseToken(Token::less,
"expected '<' in tuple type"))
464 parseToken(Token::greater,
"expected '>' in tuple type"))
480 if (
parseToken(Token::less,
"expected '<' in vector type"))
491 if (!elementType ||
parseToken(Token::greater,
"expected '>' in vector type"))
494 return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
512 bool scalable =
consumeIf(Token::l_square);
515 dimensions.push_back(value);
520 scalableDims.push_back(scalable);
544 bool allowDynamic,
bool withTrailingX) {
545 auto parseDim = [&]() -> LogicalResult {
549 return emitError(loc,
"expected static shape");
550 dimensions.push_back(ShapedType::kDynamic);
555 dimensions.push_back(value);
561 while (
getToken().isAny(Token::integer, Token::question)) {
568 if (
getToken().isAny(Token::integer, Token::question)) {
569 if (failed(parseDim()))
571 while (
getToken().is(Token::bare_identifier) &&
599 value = (int64_t)*dimension;
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
FloatType getFloat8E5M2Type()
FloatType getFloat8E8M0FNUType()
FloatType getFloat8E4M3B11FNUZType()
FloatType getFloat6E3M2FNType()
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
FloatType getFloat8E3M4Type()
FloatType getFloat8E4M3Type()
FloatType getFloat4E2M1FNType()
FloatType getFloat8E4M3FNType()
FloatType getFloat6E2M3FNType()
FloatType getFloat8E4M3FNUZType()
FloatType getFloat8E5M2FNUZType()
void resetPointer(const char *newPointer)
Change the position of the lexer cursor.
This class implements Optional functionality for ParseResult.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
static std::optional< uint64_t > getUInt64IntegerValue(StringRef spelling)
For an integer token, return its value as a uint64_t.
std::optional< unsigned > getIntTypeBitwidth() const
For an inttype token, return its bitwidth.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
ParseResult parseXInDimensionList()
Parse an 'x' token in a dimension list, handling the case where the x is juxtaposed with an element type, as in "xf32", leaving the "f32" as the next token in the stream.
OptionalParseResult parseOptionalType(Type &type)
Optionally parse a type.
ParseResult parseToken(Token::Kind expectedToken, const Twine &message)
Consume the specified token if present and return success.
ParseResult parseCommaSeparatedListUntil(Token::Kind rightToken, function_ref< ParseResult()> parseElement, bool allowEmptyList=true)
Parse a comma-separated list of elements up until the specified end token.
Type parseType()
Parse an arbitrary type.
ParseResult parseTypeListParens(SmallVectorImpl< Type > &elements)
Parse a parenthesized list of types.
ParseResult parseVectorDimensionList(SmallVectorImpl< int64_t > &dimensions, SmallVectorImpl< bool > &scalableDims)
Parse a dimension list in a vector type.
Type parseMemRefType()
Parse a memref type.
Type parseNonFunctionType()
Parse a non function type.
Type parseExtendedType()
Parse an extended type.
Type parseTupleType()
Parse a tuple type.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error and return failure.
ParserState & state
The Parser is subclassed and reinstantiated.
Attribute parseAttribute(Type type={})
Parse an arbitrary attribute with an optional type.
StringRef getTokenSpelling() const
void consumeToken()
Advance the current lexer onto the next token.
ParseResult parseIntegerInDimensionList(int64_t &value)
Type parseComplexType()
Parse a complex type.
ParseResult parseDimensionListRanked(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)
Parse a dimension list of a tensor or memref type.
ParseResult parseFunctionResultTypes(SmallVectorImpl< Type > &elements)
Parse a function result type.
MLIRContext * getContext() const
InFlightDiagnostic emitWrongTokenError(const Twine &message={})
Emit an error about a "wrong token".
ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())
Parse a list of comma-separated items with an optional delimiter.
VectorType parseVectorType()
Parse a vector type.
Type parseFunctionType()
Parse a function type.
OptionalParseResult parseOptionalAttribute(Attribute &attribute, Type type={})
Parse an optional attribute with the provided type.
ParseResult parseTypeListNoParens(SmallVectorImpl< Type > &elements)
Parse a list of types without an enclosing parenthesis.
const Token & getToken() const
Return the current token the parser is inspecting.
Type parseTensorType()
Parse a tensor type.
bool consumeIf(Token::Kind kind)
If the current token has the specified kind, consume it and return true.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
Lexer lex
The lexer for the source file we're parsing.