MLIR 22.0.0git
TypeParser.cpp
Go to the documentation of this file.
1//===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the parser for the MLIR Types.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Parser.h"
14#include "mlir/IR/AffineMap.h"
20#include "mlir/IR/Types.h"
21#include "mlir/Support/LLVM.h"
22#include <cassert>
23#include <cstdint>
24#include <limits>
25#include <optional>
26
27using namespace mlir;
28using namespace mlir::detail;
29
30/// Optionally parse a type.
32 // There are many different starting tokens for a type, check them here.
33 switch (getToken().getKind()) {
34 case Token::l_paren:
35 case Token::kw_memref:
36 case Token::kw_tensor:
37 case Token::kw_complex:
38 case Token::kw_tuple:
39 case Token::kw_vector:
40 case Token::inttype:
41 case Token::kw_f4E2M1FN:
42 case Token::kw_f6E2M3FN:
43 case Token::kw_f6E3M2FN:
44 case Token::kw_f8E5M2:
45 case Token::kw_f8E4M3:
46 case Token::kw_f8E4M3FN:
47 case Token::kw_f8E5M2FNUZ:
48 case Token::kw_f8E4M3FNUZ:
49 case Token::kw_f8E4M3B11FNUZ:
50 case Token::kw_f8E3M4:
51 case Token::kw_f8E8M0FNU:
52 case Token::kw_bf16:
53 case Token::kw_f16:
54 case Token::kw_tf32:
55 case Token::kw_f32:
56 case Token::kw_f64:
57 case Token::kw_f80:
58 case Token::kw_f128:
59 case Token::kw_index:
60 case Token::kw_none:
61 case Token::exclamation_identifier:
62 return failure(!(type = parseType()));
63
64 default:
65 return std::nullopt;
66 }
67}
68
69/// Parse an arbitrary type.
70///
71/// type ::= function-type
72/// | non-function-type
73///
75 if (getToken().is(Token::l_paren))
76 return parseFunctionType();
77 return parseNonFunctionType();
78}
79
80/// Parse a function result type.
81///
82/// function-result-type ::= type-list-parens
83/// | non-function-type
84///
86 if (getToken().is(Token::l_paren))
87 return parseTypeListParens(elements);
88
90 if (!t)
91 return failure();
92 elements.push_back(t);
93 return success();
94}
95
96/// Parse a list of types without enclosing parentheses. The list must have
97/// at least one member.
98///
99/// type-list-no-parens ::= type (`,` type)*
100///
102 auto parseElt = [&]() -> ParseResult {
103 auto elt = parseType();
104 elements.push_back(elt);
105 return elt ? success() : failure();
106 };
107
108 return parseCommaSeparatedList(parseElt);
109}
110
111/// Parse a parenthesized list of types.
112///
113/// type-list-parens ::= `(` `)`
114/// | `(` type-list-no-parens `)`
115///
117 if (parseToken(Token::l_paren, "expected '('"))
118 return failure();
119
120 // Handle empty lists.
121 if (getToken().is(Token::r_paren))
122 return consumeToken(), success();
123
124 if (parseTypeListNoParens(elements) ||
125 parseToken(Token::r_paren, "expected ')'"))
126 return failure();
127 return success();
128}
129
130/// Parse a complex type.
131///
132/// complex-type ::= `complex` `<` type `>`
133///
135 consumeToken(Token::kw_complex);
136
137 // Parse the '<'.
138 if (parseToken(Token::less, "expected '<' in complex type"))
139 return nullptr;
140
141 SMLoc elementTypeLoc = getToken().getLoc();
142 auto elementType = parseType();
143 if (!elementType ||
144 parseToken(Token::greater, "expected '>' in complex type"))
145 return nullptr;
146 if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
147 return emitError(elementTypeLoc, "invalid element type for complex"),
148 nullptr;
149
150 return ComplexType::get(elementType);
151}
152
153/// Parse a function type.
154///
155/// function-type ::= type-list-parens `->` function-result-type
156///
158 assert(getToken().is(Token::l_paren));
159
160 SmallVector<Type, 4> arguments, results;
161 if (parseTypeListParens(arguments) ||
162 parseToken(Token::arrow, "expected '->' in function type") ||
164 return nullptr;
165
166 return builder.getFunctionType(arguments, results);
167}
168
169/// Parse a memref type.
170///
171/// memref-type ::= ranked-memref-type | unranked-memref-type
172///
173/// ranked-memref-type ::= `memref` `<` dimension-list-ranked type
174/// (`,` layout-specification)? (`,` memory-space)? `>`
175///
176/// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
177///
178/// stride-list ::= `[` (dimension (`,` dimension)*)? `]`
179/// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
180/// layout-specification ::= semi-affine-map | strided-layout | attribute
181/// memory-space ::= integer-literal | attribute
182///
184 SMLoc loc = getToken().getLoc();
185 consumeToken(Token::kw_memref);
186
187 if (parseToken(Token::less, "expected '<' in memref type"))
188 return nullptr;
189
190 bool isUnranked;
191 SmallVector<int64_t, 4> dimensions;
192
193 if (consumeIf(Token::star)) {
194 // This is an unranked memref type.
195 isUnranked = true;
197 return nullptr;
198
199 } else {
200 isUnranked = false;
201 if (parseDimensionListRanked(dimensions))
202 return nullptr;
203 }
204
205 // Parse the element type.
206 auto typeLoc = getToken().getLoc();
207 auto elementType = parseType();
208 if (!elementType)
209 return nullptr;
210
211 // Check that memref is formed from allowed types.
212 if (!BaseMemRefType::isValidElementType(elementType))
213 return emitError(typeLoc, "invalid memref element type"), nullptr;
214
215 MemRefLayoutAttrInterface layout;
216 Attribute memorySpace;
217
218 auto parseElt = [&]() -> ParseResult {
219 // Either it is MemRefLayoutAttrInterface or memory space attribute.
220 Attribute attr = parseAttribute();
221 if (!attr)
222 return failure();
223
224 if (isa<MemRefLayoutAttrInterface>(attr)) {
225 layout = cast<MemRefLayoutAttrInterface>(attr);
226 } else if (memorySpace) {
227 return emitError("multiple memory spaces specified in memref type");
228 } else {
229 memorySpace = attr;
230 return success();
231 }
232
233 if (isUnranked)
234 return emitError("cannot have affine map for unranked memref type");
235 if (memorySpace)
236 return emitError("expected memory space to be last in memref type");
237
238 return success();
239 };
240
241 // Parse a list of mappings and address space if present.
242 if (!consumeIf(Token::greater)) {
243 // Parse comma separated list of affine maps, followed by memory space.
244 if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
245 parseCommaSeparatedListUntil(Token::greater, parseElt,
246 /*allowEmptyList=*/false)) {
247 return nullptr;
248 }
249 }
250
251 if (isUnranked)
252 return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
253
254 return getChecked<MemRefType>(loc, dimensions, elementType, layout,
255 memorySpace);
256}
257
258/// Parse any type except the function type.
259///
260/// non-function-type ::= integer-type
261/// | index-type
262/// | float-type
263/// | extended-type
264/// | vector-type
265/// | tensor-type
266/// | memref-type
267/// | complex-type
268/// | tuple-type
269/// | none-type
270///
271/// index-type ::= `index`
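272/// float-type ::= `f4E2M1FN` | `f6E2M3FN` | `f6E3M2FN` | `f8E5M2` | `f8E4M3`
///              | `f8E4M3FN` | `f8E5M2FNUZ` | `f8E4M3FNUZ` | `f8E4M3B11FNUZ`
///              | `f8E3M4` | `f8E8M0FNU` | `bf16` | `f16` | `tf32` | `f32`
///              | `f64` | `f80` | `f128`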
273/// none-type ::= `none`
274///
276 switch (getToken().getKind()) {
277 default:
278 return (emitWrongTokenError("expected non-function type"), nullptr);
279 case Token::kw_memref:
280 return parseMemRefType();
281 case Token::kw_tensor:
282 return parseTensorType();
283 case Token::kw_complex:
284 return parseComplexType();
285 case Token::kw_tuple:
286 return parseTupleType();
287 case Token::kw_vector:
288 return parseVectorType();
289 // integer-type
290 case Token::inttype: {
291 auto width = getToken().getIntTypeBitwidth();
292 if (!width.has_value())
293 return (emitError("invalid integer width"), nullptr);
294 if (*width > IntegerType::kMaxWidth) {
295 emitError(getToken().getLoc(), "integer bitwidth is limited to ")
296 << IntegerType::kMaxWidth << " bits";
297 return nullptr;
298 }
299
300 IntegerType::SignednessSemantics signSemantics = IntegerType::Signless;
301 if (std::optional<bool> signedness = getToken().getIntTypeSignedness())
302 signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned;
303
304 consumeToken(Token::inttype);
305 return IntegerType::get(getContext(), *width, signSemantics);
306 }
307
308 // float-type
309 case Token::kw_f4E2M1FN:
310 consumeToken(Token::kw_f4E2M1FN);
311 return builder.getType<Float4E2M1FNType>();
312 case Token::kw_f6E2M3FN:
313 consumeToken(Token::kw_f6E2M3FN);
314 return builder.getType<Float6E2M3FNType>();
315 case Token::kw_f6E3M2FN:
316 consumeToken(Token::kw_f6E3M2FN);
317 return builder.getType<Float6E3M2FNType>();
318 case Token::kw_f8E5M2:
319 consumeToken(Token::kw_f8E5M2);
320 return builder.getType<Float8E5M2Type>();
321 case Token::kw_f8E4M3:
322 consumeToken(Token::kw_f8E4M3);
323 return builder.getType<Float8E4M3Type>();
324 case Token::kw_f8E4M3FN:
325 consumeToken(Token::kw_f8E4M3FN);
326 return builder.getType<Float8E4M3FNType>();
327 case Token::kw_f8E5M2FNUZ:
328 consumeToken(Token::kw_f8E5M2FNUZ);
329 return builder.getType<Float8E5M2FNUZType>();
330 case Token::kw_f8E4M3FNUZ:
331 consumeToken(Token::kw_f8E4M3FNUZ);
332 return builder.getType<Float8E4M3FNUZType>();
333 case Token::kw_f8E4M3B11FNUZ:
334 consumeToken(Token::kw_f8E4M3B11FNUZ);
335 return builder.getType<Float8E4M3B11FNUZType>();
336 case Token::kw_f8E3M4:
337 consumeToken(Token::kw_f8E3M4);
338 return builder.getType<Float8E3M4Type>();
339 case Token::kw_f8E8M0FNU:
340 consumeToken(Token::kw_f8E8M0FNU);
341 return builder.getType<Float8E8M0FNUType>();
342 case Token::kw_bf16:
343 consumeToken(Token::kw_bf16);
344 return builder.getType<BFloat16Type>();
345 case Token::kw_f16:
346 consumeToken(Token::kw_f16);
347 return builder.getType<Float16Type>();
348 case Token::kw_tf32:
349 consumeToken(Token::kw_tf32);
350 return builder.getType<FloatTF32Type>();
351 case Token::kw_f32:
352 consumeToken(Token::kw_f32);
353 return builder.getType<Float32Type>();
354 case Token::kw_f64:
355 consumeToken(Token::kw_f64);
356 return builder.getType<Float64Type>();
357 case Token::kw_f80:
358 consumeToken(Token::kw_f80);
359 return builder.getType<Float80Type>();
360 case Token::kw_f128:
361 consumeToken(Token::kw_f128);
362 return builder.getType<Float128Type>();
363
364 // index-type
365 case Token::kw_index:
366 consumeToken(Token::kw_index);
367 return builder.getIndexType();
368
369 // none-type
370 case Token::kw_none:
371 consumeToken(Token::kw_none);
372 return builder.getNoneType();
373
374 // extended type
375 case Token::exclamation_identifier:
376 return parseExtendedType();
377
378 // Handle completion of a dialect type.
379 case Token::code_complete:
380 if (getToken().isCodeCompletionFor(Token::exclamation_identifier))
381 return parseExtendedType();
382 return codeCompleteType();
383 }
384}
385
386/// Parse a tensor type.
387///
388/// tensor-type ::= `tensor` `<` dimension-list type `>`
389/// dimension-list ::= dimension-list-ranked | `*x`
390///
392 consumeToken(Token::kw_tensor);
393
394 if (parseToken(Token::less, "expected '<' in tensor type"))
395 return nullptr;
396
397 bool isUnranked;
398 SmallVector<int64_t, 4> dimensions;
399
400 if (consumeIf(Token::star)) {
401 // This is an unranked tensor type.
402 isUnranked = true;
403
405 return nullptr;
406
407 } else {
408 isUnranked = false;
409 if (parseDimensionListRanked(dimensions))
410 return nullptr;
411 }
412
413 // Parse the element type.
414 auto elementTypeLoc = getToken().getLoc();
415 auto elementType = parseType();
416
417 // Parse an optional encoding attribute.
418 Attribute encoding;
419 if (consumeIf(Token::comma)) {
420 auto parseResult = parseOptionalAttribute(encoding);
421 if (parseResult.has_value()) {
422 if (failed(parseResult.value()))
423 return nullptr;
424 if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
425 if (failed(v.verifyEncoding(dimensions, elementType,
426 [&] { return emitError(); })))
427 return nullptr;
428 }
429 }
430 }
431
432 if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
433 return nullptr;
434 if (!TensorType::isValidElementType(elementType))
435 return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
436
437 if (isUnranked) {
438 if (encoding)
439 return emitError("cannot apply encoding to unranked tensor"), nullptr;
440 return UnrankedTensorType::get(elementType);
441 }
442 return RankedTensorType::get(dimensions, elementType, encoding);
443}
444
445/// Parse a tuple type.
446///
447/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
448///
450 consumeToken(Token::kw_tuple);
451
452 // Parse the '<'.
453 if (parseToken(Token::less, "expected '<' in tuple type"))
454 return nullptr;
455
456 // Check for an empty tuple by directly parsing '>'.
457 if (consumeIf(Token::greater))
458 return TupleType::get(getContext());
459
460 // Parse the element types and the '>'.
462 if (parseTypeListNoParens(types) ||
463 parseToken(Token::greater, "expected '>' in tuple type"))
464 return nullptr;
465
466 return TupleType::get(getContext(), types);
467}
468
469/// Parse a vector type.
470///
471/// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
472/// vector-dim-list ::= (static-dim-list `x`)?
473/// static-dim-list ::= static-dim (`x` static-dim)*
/// static-dim ::= (decimal-literal | `[` decimal-literal `]`)
474///
476 SMLoc loc = getToken().getLoc();
477 consumeToken(Token::kw_vector);
478
479 if (parseToken(Token::less, "expected '<' in vector type"))
480 return nullptr;
481
482 // Parse the dimensions.
483 SmallVector<int64_t, 4> dimensions;
484 SmallVector<bool, 4> scalableDims;
485 if (parseVectorDimensionList(dimensions, scalableDims))
486 return nullptr;
487
488 // Parse the element type.
489 auto elementType = parseType();
490 if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
491 return nullptr;
492
493 return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
494}
495
496/// Parse a dimension list in a vector type. This populates the dimension list.
497/// For i-th dimension, `scalableDims[i]` contains either:
498/// * `false` for a non-scalable dimension (e.g. `4`),
499/// * `true` for a scalable dimension (e.g. `[4]`).
500///
501/// vector-dim-list ::= (static-dim-list `x`)?
502/// static-dim-list ::= static-dim (`x` static-dim)*
503/// static-dim ::= (decimal-literal | `[` decimal-literal `]`)
504///
505ParseResult
507 SmallVectorImpl<bool> &scalableDims) {
508 // If there is a set of fixed-length dimensions, consume it
509 while (getToken().is(Token::integer) || getToken().is(Token::l_square)) {
510 int64_t value;
511 bool scalable = consumeIf(Token::l_square);
513 return failure();
514 dimensions.push_back(value);
515 if (scalable) {
516 if (!consumeIf(Token::r_square))
517 return emitWrongTokenError("missing ']' closing scalable dimension");
518 }
519 scalableDims.push_back(scalable);
520 // Make sure we have an 'x' or something like 'xbf16'.
522 return failure();
523 }
524
525 return success();
526}
527
528/// Parse a dimension list of a tensor or memref type. This populates the
529/// dimension list, using ShapedType::kDynamic for the `?` dimensions if
530/// `allowDynamic` is set and errors out on `?` otherwise. Parsing the trailing
531/// `x` is configurable.
532///
533/// dimension-list ::= eps | dimension (`x` dimension)*
534/// dimension-list-with-trailing-x ::= (dimension `x`)*
535/// dimension ::= `?` | decimal-literal
536///
537/// When `allowDynamic` is not set, this is used to parse:
538///
539/// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
540/// static-dimension-list-with-trailing-x ::= (decimal-literal `x`)*
541ParseResult
543 bool allowDynamic, bool withTrailingX) {
544 auto parseDim = [&]() -> LogicalResult {
545 auto loc = getToken().getLoc();
546 if (consumeIf(Token::question)) {
547 if (!allowDynamic)
548 return emitError(loc, "expected static shape");
549 dimensions.push_back(ShapedType::kDynamic);
550 } else {
551 int64_t value;
552 if (failed(parseIntegerInDimensionList(value)))
553 return failure();
554 dimensions.push_back(value);
555 }
556 return success();
557 };
558
559 if (withTrailingX) {
560 while (getToken().isAny(Token::integer, Token::question)) {
561 if (failed(parseDim()) || failed(parseXInDimensionList()))
562 return failure();
563 }
564 return success();
565 }
566
567 if (getToken().isAny(Token::integer, Token::question)) {
568 if (failed(parseDim()))
569 return failure();
570 while (getToken().is(Token::bare_identifier) &&
571 getTokenSpelling()[0] == 'x') {
572 if (failed(parseXInDimensionList()) || failed(parseDim()))
573 return failure();
574 }
575 }
576 return success();
577}
578
580 // Hexadecimal integer literals (starting with `0x`) are not allowed in
581 // aggregate type declarations. Therefore, `0xf32` should be processed as
582 // a sequence of separate elements `0`, `x`, `f32`.
583 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
584 // We can get here only if the token is an integer literal. Hexadecimal
585 // integer literals can only start with `0x` (`1x` wouldn't lex as a
586 // literal, just `1` would, at which point we don't get into this
587 // branch).
588 assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
589 value = 0;
590 state.lex.resetPointer(getTokenSpelling().data() + 1);
591 consumeToken();
592 } else {
593 // Make sure this integer value is in bound and valid.
594 std::optional<uint64_t> dimension = getToken().getUInt64IntegerValue();
595 if (!dimension ||
596 *dimension > (uint64_t)std::numeric_limits<int64_t>::max())
597 return emitError("invalid dimension");
598 value = (int64_t)*dimension;
599 consumeToken(Token::integer);
600 }
601 return success();
602}
603
604/// Parse an 'x' token in a dimension list, handling the case where the x is
605/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
606/// token.
608 if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x')
609 return emitWrongTokenError("expected 'x' in dimension list");
610
611 // If we had a prefix of 'x', lex the next token immediately after the 'x'.
612 if (getTokenSpelling().size() != 1)
613 state.lex.resetPointer(getTokenSpelling().data() + 1);
614
615 // Consume the 'x'.
616 consumeToken(Token::bare_identifier);
617
618 return success();
619}
return success()
Attributes are known-constant values of operations.
Definition Attributes.h:25
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a memref.
This class implements Optional functionality for ParseResult.
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
SMLoc getLoc() const
Definition Token.cpp:24
static std::optional< uint64_t > getUInt64IntegerValue(StringRef spelling)
For an integer token, return its value as an uint64_t.
Definition Token.cpp:45
std::optional< unsigned > getIntTypeBitwidth() const
For an inttype token, return its bitwidth.
Definition Token.cpp:64
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
ParseResult parseXInDimensionList()
Parse an 'x' token in a dimension list, handling the case where the x is juxtaposed with an element t...
T getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
Definition Parser.h:198
OptionalParseResult parseOptionalType(Type &type)
Optionally parse a type.
ParseResult parseToken(Token::Kind expectedToken, const Twine &message)
Consume the specified token if present and return success.
Definition Parser.cpp:267
ParseResult parseCommaSeparatedListUntil(Token::Kind rightToken, function_ref< ParseResult()> parseElement, bool allowEmptyList=true)
Parse a comma-separated list of elements up until the specified end token.
Definition Parser.cpp:173
Type parseType()
Parse an arbitrary type.
ParseResult parseTypeListParens(SmallVectorImpl< Type > &elements)
Parse a parenthesized list of types.
ParseResult parseVectorDimensionList(SmallVectorImpl< int64_t > &dimensions, SmallVectorImpl< bool > &scalableDims)
Parse a dimension list in a vector type.
Type parseMemRefType()
Parse a memref type.
Type parseNonFunctionType()
Parse a non function type.
Type parseExtendedType()
Parse an extended type.
Type parseTupleType()
Parse a tuple type.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error and return failure.
Definition Parser.cpp:192
ParserState & state
The Parser is subclassed and reinstantiated.
Definition Parser.h:370
Attribute parseAttribute(Type type={})
Parse an arbitrary attribute with an optional type.
StringRef getTokenSpelling() const
Definition Parser.h:104
void consumeToken()
Advance the current lexer onto the next token.
Definition Parser.h:119
ParseResult parseIntegerInDimensionList(int64_t &value)
Type parseComplexType()
Parse a complex type.
ParseResult parseDimensionListRanked(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)
Parse a dimension list of a tensor or memref type.
ParseResult parseFunctionResultTypes(SmallVectorImpl< Type > &elements)
Parse a function result type.
InFlightDiagnostic emitWrongTokenError(const Twine &message={})
Emit an error about a "wrong token".
Definition Parser.cpp:215
ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())
Parse a list of comma-separated items with an optional delimiter.
Definition Parser.cpp:84
VectorType parseVectorType()
Parse a vector type.
Type parseFunctionType()
Parse a function type.
OptionalParseResult parseOptionalAttribute(Attribute &attribute, Type type={})
Parse an optional attribute with the provided type.
ParseResult parseTypeListNoParens(SmallVectorImpl< Type > &elements)
Parse a list of types without an enclosing parenthesis.
const Token & getToken() const
Return the current token the parser is inspecting.
Definition Parser.h:103
MLIRContext * getContext() const
Definition Parser.h:38
Type parseTensorType()
Parse a tensor type.
bool consumeIf(Token::Kind kind)
If the current token has the specified kind, consume it and return true.
Definition Parser.h:111
AttrTypeReplacer.
Include the generated interface declarations.