MLIR 23.0.0git
TypeParser.cpp
Go to the documentation of this file.
1//===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the parser for the MLIR Types.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Parser.h"
14#include "mlir/IR/AffineMap.h"
20#include "mlir/IR/Types.h"
21#include "mlir/Support/LLVM.h"
22#include <cassert>
23#include <cstdint>
24#include <limits>
25#include <optional>
26
27using namespace mlir;
28using namespace mlir::detail;
29
30/// Optionally parse a type.
32 // There are many different starting tokens for a type, check them here.
33 switch (getToken().getKind()) {
34 case Token::l_paren:
35 case Token::kw_memref:
36 case Token::kw_tensor:
37 case Token::kw_complex:
38 case Token::kw_tuple:
39 case Token::kw_vector:
40 case Token::inttype:
41 case Token::kw_f4E2M1FN:
42 case Token::kw_f6E2M3FN:
43 case Token::kw_f6E3M2FN:
44 case Token::kw_f8E5M2:
45 case Token::kw_f8E4M3:
46 case Token::kw_f8E4M3FN:
47 case Token::kw_f8E5M2FNUZ:
48 case Token::kw_f8E4M3FNUZ:
49 case Token::kw_f8E4M3B11FNUZ:
50 case Token::kw_f8E3M4:
51 case Token::kw_f8E8M0FNU:
52 case Token::kw_bf16:
53 case Token::kw_f16:
54 case Token::kw_tf32:
55 case Token::kw_f32:
56 case Token::kw_f64:
57 case Token::kw_f80:
58 case Token::kw_f128:
59 case Token::kw_index:
60 case Token::kw_none:
61 case Token::kw_token:
62 case Token::exclamation_identifier:
63 return failure(!(type = parseType()));
64
65 default:
66 return std::nullopt;
67 }
68}
69
70/// Parse an arbitrary type.
71///
72/// type ::= function-type
73/// | non-function-type
74///
76 if (getToken().is(Token::l_paren))
77 return parseFunctionType();
78 return parseNonFunctionType();
79}
80
81/// Parse a function result type.
82///
83/// function-result-type ::= type-list-parens
84/// | non-function-type
85///
87 if (getToken().is(Token::l_paren))
88 return parseTypeListParens(elements);
89
91 if (!t)
92 return failure();
93 elements.push_back(t);
94 return success();
95}
96
97/// Parse a list of types without enclosing parentheses. The list must have
98/// at least one member.
99///
100/// type-list-no-parens ::= type (`,` type)*
101///
103 auto parseElt = [&]() -> ParseResult {
104 auto elt = parseType();
105 elements.push_back(elt);
106 return elt ? success() : failure();
107 };
108
109 return parseCommaSeparatedList(parseElt);
110}
111
112/// Parse a parenthesized list of types.
113///
114/// type-list-parens ::= `(` `)`
115/// | `(` type-list-no-parens `)`
116///
118 if (parseToken(Token::l_paren, "expected '('"))
119 return failure();
120
121 // Handle empty lists.
122 if (getToken().is(Token::r_paren))
123 return consumeToken(), success();
124
125 if (parseTypeListNoParens(elements) ||
126 parseToken(Token::r_paren, "expected ')'"))
127 return failure();
128 return success();
129}
130
131/// Parse a complex type.
132///
133/// complex-type ::= `complex` `<` type `>`
134///
136 consumeToken(Token::kw_complex);
137
138 // Parse the '<'.
139 if (parseToken(Token::less, "expected '<' in complex type"))
140 return nullptr;
141
142 SMLoc elementTypeLoc = getToken().getLoc();
143 auto elementType = parseType();
144 if (!elementType ||
145 parseToken(Token::greater, "expected '>' in complex type"))
146 return nullptr;
147 if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
148 return emitError(elementTypeLoc, "invalid element type for complex"),
149 nullptr;
150
151 return ComplexType::get(elementType);
152}
153
154/// Parse a function type.
155///
156/// function-type ::= type-list-parens `->` function-result-type
157///
159 assert(getToken().is(Token::l_paren));
160
161 SmallVector<Type, 4> arguments, results;
162 if (parseTypeListParens(arguments) ||
163 parseToken(Token::arrow, "expected '->' in function type") ||
165 return nullptr;
166
167 return builder.getFunctionType(arguments, results);
168}
169
170/// Parse a memref type.
171///
172/// memref-type ::= ranked-memref-type | unranked-memref-type
173///
174/// ranked-memref-type ::= `memref` `<` dimension-list-ranked type
175/// (`,` layout-specification)? (`,` memory-space)? `>`
176///
177/// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
178///
179/// stride-list ::= `[` (dimension (`,` dimension)*)? `]`
180/// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
181/// layout-specification ::= semi-affine-map | strided-layout | attribute
182/// memory-space ::= integer-literal | attribute
183///
185 SMLoc loc = getToken().getLoc();
186 consumeToken(Token::kw_memref);
187
188 if (parseToken(Token::less, "expected '<' in memref type"))
189 return nullptr;
190
191 bool isUnranked;
192 SmallVector<int64_t, 4> dimensions;
193
194 if (consumeIf(Token::star)) {
195 // This is an unranked memref type.
196 isUnranked = true;
198 return nullptr;
199
200 } else {
201 isUnranked = false;
202 if (parseDimensionListRanked(dimensions))
203 return nullptr;
204 }
205
206 // Parse the element type.
207 auto typeLoc = getToken().getLoc();
208 auto elementType = parseType();
209 if (!elementType)
210 return nullptr;
211
212 // Check that memref is formed from allowed types.
213 if (!BaseMemRefType::isValidElementType(elementType))
214 return emitError(typeLoc, "invalid memref element type"), nullptr;
215
216 MemRefLayoutAttrInterface layout;
217 Attribute memorySpace;
218
219 auto parseElt = [&]() -> ParseResult {
220 // Either it is MemRefLayoutAttrInterface or memory space attribute.
221 Attribute attr = parseAttribute();
222 if (!attr)
223 return failure();
224
225 if (isa<MemRefLayoutAttrInterface>(attr)) {
226 layout = cast<MemRefLayoutAttrInterface>(attr);
227 } else if (memorySpace) {
228 return emitError("multiple memory spaces specified in memref type");
229 } else {
230 memorySpace = attr;
231 return success();
232 }
233
234 if (isUnranked)
235 return emitError("cannot have affine map for unranked memref type");
236 if (memorySpace)
237 return emitError("expected memory space to be last in memref type");
238
239 return success();
240 };
241
242 // Parse a list of mappings and address space if present.
243 if (!consumeIf(Token::greater)) {
244 // Parse comma separated list of affine maps, followed by memory space.
245 if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
246 parseCommaSeparatedListUntil(Token::greater, parseElt,
247 /*allowEmptyList=*/false)) {
248 return nullptr;
249 }
250 }
251
252 if (isUnranked)
253 return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
254
255 return getChecked<MemRefType>(loc, dimensions, elementType, layout,
256 memorySpace);
257}
258
259/// Parse any type except the function type.
260///
261/// non-function-type ::= integer-type
262/// | index-type
263/// | float-type
264/// | extended-type
265/// | vector-type
266/// | tensor-type
267/// | memref-type
268/// | complex-type
269/// | tuple-type
270/// | none-type
271///
272/// index-type ::= `index`
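273/// float-type ::= `f4E2M1FN` | `f6E2M3FN` | `f6E3M2FN` | `f8E5M2` | `f8E4M3`
///              | `f8E4M3FN` | `f8E5M2FNUZ` | `f8E4M3FNUZ` | `f8E4M3B11FNUZ`
///              | `f8E3M4` | `f8E8M0FNU` | `bf16` | `f16` | `tf32` | `f32`
///              | `f64` | `f80` | `f128`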
274/// none-type ::= `none`
275///
277 switch (getToken().getKind()) {
278 default:
279 return (emitWrongTokenError("expected non-function type"), nullptr);
280 case Token::kw_memref:
281 return parseMemRefType();
282 case Token::kw_tensor:
283 return parseTensorType();
284 case Token::kw_complex:
285 return parseComplexType();
286 case Token::kw_tuple:
287 return parseTupleType();
288 case Token::kw_vector:
289 return parseVectorType();
290 // integer-type
291 case Token::inttype: {
292 auto width = getToken().getIntTypeBitwidth();
293 if (!width.has_value())
294 return (emitError("invalid integer width"), nullptr);
295 if (*width > IntegerType::kMaxWidth) {
296 emitError(getToken().getLoc(), "integer bitwidth is limited to ")
297 << IntegerType::kMaxWidth << " bits";
298 return nullptr;
299 }
300
301 IntegerType::SignednessSemantics signSemantics = IntegerType::Signless;
302 if (std::optional<bool> signedness = getToken().getIntTypeSignedness())
303 signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned;
304
305 consumeToken(Token::inttype);
306 return IntegerType::get(getContext(), *width, signSemantics);
307 }
308
309 // float-type
310 case Token::kw_f4E2M1FN:
311 consumeToken(Token::kw_f4E2M1FN);
312 return builder.getType<Float4E2M1FNType>();
313 case Token::kw_f6E2M3FN:
314 consumeToken(Token::kw_f6E2M3FN);
315 return builder.getType<Float6E2M3FNType>();
316 case Token::kw_f6E3M2FN:
317 consumeToken(Token::kw_f6E3M2FN);
318 return builder.getType<Float6E3M2FNType>();
319 case Token::kw_f8E5M2:
320 consumeToken(Token::kw_f8E5M2);
321 return builder.getType<Float8E5M2Type>();
322 case Token::kw_f8E4M3:
323 consumeToken(Token::kw_f8E4M3);
324 return builder.getType<Float8E4M3Type>();
325 case Token::kw_f8E4M3FN:
326 consumeToken(Token::kw_f8E4M3FN);
327 return builder.getType<Float8E4M3FNType>();
328 case Token::kw_f8E5M2FNUZ:
329 consumeToken(Token::kw_f8E5M2FNUZ);
330 return builder.getType<Float8E5M2FNUZType>();
331 case Token::kw_f8E4M3FNUZ:
332 consumeToken(Token::kw_f8E4M3FNUZ);
333 return builder.getType<Float8E4M3FNUZType>();
334 case Token::kw_f8E4M3B11FNUZ:
335 consumeToken(Token::kw_f8E4M3B11FNUZ);
336 return builder.getType<Float8E4M3B11FNUZType>();
337 case Token::kw_f8E3M4:
338 consumeToken(Token::kw_f8E3M4);
339 return builder.getType<Float8E3M4Type>();
340 case Token::kw_f8E8M0FNU:
341 consumeToken(Token::kw_f8E8M0FNU);
342 return builder.getType<Float8E8M0FNUType>();
343 case Token::kw_bf16:
344 consumeToken(Token::kw_bf16);
345 return builder.getType<BFloat16Type>();
346 case Token::kw_f16:
347 consumeToken(Token::kw_f16);
348 return builder.getType<Float16Type>();
349 case Token::kw_tf32:
350 consumeToken(Token::kw_tf32);
351 return builder.getType<FloatTF32Type>();
352 case Token::kw_f32:
353 consumeToken(Token::kw_f32);
354 return builder.getType<Float32Type>();
355 case Token::kw_f64:
356 consumeToken(Token::kw_f64);
357 return builder.getType<Float64Type>();
358 case Token::kw_f80:
359 consumeToken(Token::kw_f80);
360 return builder.getType<Float80Type>();
361 case Token::kw_f128:
362 consumeToken(Token::kw_f128);
363 return builder.getType<Float128Type>();
364
365 // index-type
366 case Token::kw_index:
367 consumeToken(Token::kw_index);
368 return builder.getIndexType();
369
370 // none-type
371 case Token::kw_none:
372 consumeToken(Token::kw_none);
373 return builder.getNoneType();
374
375 // token-type
376 case Token::kw_token:
377 consumeToken(Token::kw_token);
378 return builder.getType<TokenType>();
379
380 // extended type
381 case Token::exclamation_identifier:
382 return parseExtendedType();
383
384 // Handle completion of a dialect type.
385 case Token::code_complete:
386 if (getToken().isCodeCompletionFor(Token::exclamation_identifier))
387 return parseExtendedType();
388 return codeCompleteType();
389 }
390}
391
392/// Parse a tensor type.
393///
394/// tensor-type ::= `tensor` `<` dimension-list type `>`
395/// dimension-list ::= dimension-list-ranked | `*x`
396///
398 consumeToken(Token::kw_tensor);
399
400 if (parseToken(Token::less, "expected '<' in tensor type"))
401 return nullptr;
402
403 bool isUnranked;
404 SmallVector<int64_t, 4> dimensions;
405
406 if (consumeIf(Token::star)) {
407 // This is an unranked tensor type.
408 isUnranked = true;
409
411 return nullptr;
412
413 } else {
414 isUnranked = false;
415 if (parseDimensionListRanked(dimensions))
416 return nullptr;
417 }
418
419 // Parse the element type.
420 auto elementTypeLoc = getToken().getLoc();
421 auto elementType = parseType();
422
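423 // Parse an optional encoding attribute. NOTE(review): `elementType` is only null-checked after this block, so `verifyEncoding` below may be handed a null type when `parseType()` failed — confirm encodings tolerate that or check earlier.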
424 Attribute encoding;
425 if (consumeIf(Token::comma)) {
426 auto parseResult = parseOptionalAttribute(encoding);
427 if (parseResult.has_value()) {
428 if (failed(parseResult.value()))
429 return nullptr;
430 if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
431 if (failed(v.verifyEncoding(dimensions, elementType,
432 [&] { return emitError(); })))
433 return nullptr;
434 }
435 }
436 }
437
438 if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
439 return nullptr;
440 if (!TensorType::isValidElementType(elementType))
441 return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
442
443 if (isUnranked) {
444 if (encoding)
445 return emitError("cannot apply encoding to unranked tensor"), nullptr;
446 return UnrankedTensorType::get(elementType);
447 }
448 return RankedTensorType::get(dimensions, elementType, encoding);
449}
450
451/// Parse a tuple type.
452///
453/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
454///
456 consumeToken(Token::kw_tuple);
457
458 // Parse the '<'.
459 if (parseToken(Token::less, "expected '<' in tuple type"))
460 return nullptr;
461
462 // Check for an empty tuple by directly parsing '>'.
463 if (consumeIf(Token::greater))
464 return TupleType::get(getContext());
465
466 // Parse the element types and the '>'.
468 if (parseTypeListNoParens(types) ||
469 parseToken(Token::greater, "expected '>' in tuple type"))
470 return nullptr;
471
472 return TupleType::get(getContext(), types);
473}
474
475/// Parse a vector type.
476///
477/// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
478/// vector-dim-list ::= (static-dim-list `x`)?
479/// static-dim-list ::= static-dim (`x` static-dim)*
/// static-dim ::= (decimal-literal | `[` decimal-literal `]`)
480///
482 SMLoc loc = getToken().getLoc();
483 consumeToken(Token::kw_vector);
484
485 if (parseToken(Token::less, "expected '<' in vector type"))
486 return nullptr;
487
488 // Parse the dimensions.
489 SmallVector<int64_t, 4> dimensions;
490 SmallVector<bool, 4> scalableDims;
491 if (parseVectorDimensionList(dimensions, scalableDims))
492 return nullptr;
493
494 // Parse the element type.
495 auto elementType = parseType();
496 if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
497 return nullptr;
498
499 return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
500}
501
502/// Parse a dimension list in a vector type. This populates the dimension list.
503/// For i-th dimension, `scalableDims[i]` contains either:
504/// * `false` for a non-scalable dimension (e.g. `4`),
505/// * `true` for a scalable dimension (e.g. `[4]`).
506///
507/// vector-dim-list ::= (static-dim-list `x`)?
508/// static-dim-list ::= static-dim (`x` static-dim)*
509/// static-dim ::= (decimal-literal | `[` decimal-literal `]`)
510///
511ParseResult
513 SmallVectorImpl<bool> &scalableDims) {
514 // Consume the dimension list, if any; each dimension is either fixed-length (`4`) or scalable (`[4]`).
515 while (getToken().is(Token::integer) || getToken().is(Token::l_square)) {
516 int64_t value;
517 bool scalable = consumeIf(Token::l_square);
519 return failure();
520 dimensions.push_back(value);
521 if (scalable) {
522 if (!consumeIf(Token::r_square))
523 return emitWrongTokenError("missing ']' closing scalable dimension");
524 }
525 scalableDims.push_back(scalable);
526 // Make sure we have an 'x' or something like 'xbf16'.
528 return failure();
529 }
530
531 return success();
532}
533
534/// Parse a dimension list of a tensor or memref type. This populates the
535/// dimension list, using ShapedType::kDynamic for the `?` dimensions if
536/// `allowDynamic` is set and errors out on `?` otherwise. Parsing the trailing
537/// `x` is configurable.
538///
539/// dimension-list ::= eps | dimension (`x` dimension)*
540/// dimension-list-with-trailing-x ::= (dimension `x`)*
541/// dimension ::= `?` | decimal-literal
542///
543/// When `allowDynamic` is not set, this is used to parse:
544///
545/// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
546/// static-dimension-list-with-trailing-x ::= (decimal-literal `x`)*
547ParseResult
549 bool allowDynamic, bool withTrailingX) {
550 auto parseDim = [&]() -> LogicalResult {
551 auto loc = getToken().getLoc();
552 if (consumeIf(Token::question)) {
553 if (!allowDynamic)
554 return emitError(loc, "expected static shape");
555 dimensions.push_back(ShapedType::kDynamic);
556 } else {
557 int64_t value;
558 if (failed(parseIntegerInDimensionList(value)))
559 return failure();
560 dimensions.push_back(value);
561 }
562 return success();
563 };
564
565 if (withTrailingX) {
566 while (getToken().isAny(Token::integer, Token::question)) {
567 if (failed(parseDim()) || failed(parseXInDimensionList()))
568 return failure();
569 }
570 return success();
571 }
572
573 if (getToken().isAny(Token::integer, Token::question)) {
574 if (failed(parseDim()))
575 return failure();
576 while (getToken().is(Token::bare_identifier) &&
577 getTokenSpelling()[0] == 'x') {
578 if (failed(parseXInDimensionList()) || failed(parseDim()))
579 return failure();
580 }
581 }
582 return success();
583}
584
586 // Hexadecimal integer literals (starting with `0x`) are not allowed in
587 // aggregate type declarations. Therefore, `0xf32` should be processed as
588 // a sequence of separate elements `0`, `x`, `f32`.
589 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
590 // We can get here only if the token is an integer literal. Hexadecimal
591 // integer literals can only start with `0x` (`1x` wouldn't lex as a
592 // literal, just `1` would, at which point we don't get into this
593 // branch).
594 assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
595 value = 0;
596 state.lex.resetPointer(getTokenSpelling().data() + 1);
597 consumeToken();
598 } else {
599 // Make sure this integer value is in bound and valid.
600 std::optional<uint64_t> dimension = getToken().getUInt64IntegerValue();
601 if (!dimension ||
602 *dimension > (uint64_t)std::numeric_limits<int64_t>::max())
603 return emitError("invalid dimension");
604 value = (int64_t)*dimension;
605 consumeToken(Token::integer);
606 }
607 return success();
608}
609
610/// Parse an 'x' token in a dimension list, handling the case where the x is
611/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
612/// token.
614 if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x')
615 return emitWrongTokenError("expected 'x' in dimension list");
616
617 // If we had a prefix of 'x', lex the next token immediately after the 'x'.
618 if (getTokenSpelling().size() != 1)
619 state.lex.resetPointer(getTokenSpelling().data() + 1);
620
621 // Consume the 'x'.
622 consumeToken(Token::bare_identifier);
623
624 return success();
625}
return success()
Attributes are known-constant values of operations.
Definition Attributes.h:25
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
This class implements Optional functionality for ParseResult.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
SMLoc getLoc() const
Definition Token.cpp:24
static std::optional< uint64_t > getUInt64IntegerValue(StringRef spelling)
For an integer token, return its value as an uint64_t.
Definition Token.cpp:45
std::optional< unsigned > getIntTypeBitwidth() const
For an inttype token, return its bitwidth.
Definition Token.cpp:64
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
ParseResult parseXInDimensionList()
Parse an 'x' token in a dimension list, handling the case where the x is juxtaposed with an element type, as in "xf32", leaving the "f32" as the next token.
T getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
Definition Parser.h:198
OptionalParseResult parseOptionalType(Type &type)
Optionally parse a type.
ParseResult parseToken(Token::Kind expectedToken, const Twine &message)
Consume the specified token if present and return success.
Definition Parser.cpp:305
ParseResult parseCommaSeparatedListUntil(Token::Kind rightToken, function_ref< ParseResult()> parseElement, bool allowEmptyList=true)
Parse a comma-separated list of elements up until the specified end token.
Definition Parser.cpp:173
Type parseType()
Parse an arbitrary type.
ParseResult parseTypeListParens(SmallVectorImpl< Type > &elements)
Parse a parenthesized list of types.
ParseResult parseVectorDimensionList(SmallVectorImpl< int64_t > &dimensions, SmallVectorImpl< bool > &scalableDims)
Parse a dimension list in a vector type.
Type parseMemRefType()
Parse a memref type.
Type parseNonFunctionType()
Parse a non function type.
Type parseExtendedType()
Parse an extended type.
Type parseTupleType()
Parse a tuple type.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error and return failure.
Definition Parser.cpp:192
ParserState & state
The Parser is subclassed and reinstantiated.
Definition Parser.h:370
Attribute parseAttribute(Type type={})
Parse an arbitrary attribute with an optional type.
StringRef getTokenSpelling() const
Definition Parser.h:104
void consumeToken()
Advance the current lexer onto the next token.
Definition Parser.h:119
ParseResult parseIntegerInDimensionList(int64_t &value)
Type parseComplexType()
Parse a complex type.
ParseResult parseDimensionListRanked(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)
Parse a dimension list of a tensor or memref type.
ParseResult parseFunctionResultTypes(SmallVectorImpl< Type > &elements)
Parse a function result type.
InFlightDiagnostic emitWrongTokenError(const Twine &message={})
Emit an error about a "wrong token".
Definition Parser.cpp:254
ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())
Parse a list of comma-separated items with an optional delimiter.
Definition Parser.cpp:84
VectorType parseVectorType()
Parse a vector type.
Type parseFunctionType()
Parse a function type.
OptionalParseResult parseOptionalAttribute(Attribute &attribute, Type type={})
Parse an optional attribute with the provided type.
ParseResult parseTypeListNoParens(SmallVectorImpl< Type > &elements)
Parse a list of types without an enclosing parenthesis.
const Token & getToken() const
Return the current token the parser is inspecting.
Definition Parser.h:103
MLIRContext * getContext() const
Definition Parser.h:38
Type parseTensorType()
Parse a tensor type.
bool consumeIf(Token::Kind kind)
If the current token has the specified kind, consume it and return true.
Definition Parser.h:111
AttrTypeReplacer.
Include the generated interface declarations.