MLIR  16.0.0git
AttributeParser.cpp
Go to the documentation of this file.
1 //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the parser for the MLIR Attributes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Parser.h"
14 
15 #include "AsmParserImpl.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/BuiltinDialect.h"
19 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/IntegerSet.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/Support/Endian.h"
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 /// Parse an arbitrary attribute.
30 ///
31 /// attribute-value ::= `unit`
32 /// | bool-literal
33 /// | integer-literal (`:` (index-type | integer-type))?
34 /// | float-literal (`:` float-type)?
35 /// | string-literal (`:` type)?
36 /// | type
37 /// | `[` `:` (integer-type | float-type) tensor-literal `]`
38 /// | `[` (attribute-value (`,` attribute-value)*)? `]`
39 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}`
40 /// | symbol-ref-id (`::` symbol-ref-id)*
41 /// | `dense` `<` tensor-literal `>` `:`
42 /// (tensor-type | vector-type)
43 /// | `sparse` `<` attribute-value `,` attribute-value `>`
44 /// `:` (tensor-type | vector-type)
45 /// | extended-attribute
46 ///
48  switch (getToken().getKind()) {
49  // Parse an AffineMap or IntegerSet attribute.
50  case Token::kw_affine_map: {
51  consumeToken(Token::kw_affine_map);
52 
53  AffineMap map;
54  if (parseToken(Token::less, "expected '<' in affine map") ||
56  parseToken(Token::greater, "expected '>' in affine map"))
57  return Attribute();
58  return AffineMapAttr::get(map);
59  }
60  case Token::kw_affine_set: {
61  consumeToken(Token::kw_affine_set);
62 
63  IntegerSet set;
64  if (parseToken(Token::less, "expected '<' in integer set") ||
66  parseToken(Token::greater, "expected '>' in integer set"))
67  return Attribute();
68  return IntegerSetAttr::get(set);
69  }
70 
71  // Parse an array attribute.
72  case Token::l_square: {
73  consumeToken(Token::l_square);
75  auto parseElt = [&]() -> ParseResult {
76  elements.push_back(parseAttribute());
77  return elements.back() ? success() : failure();
78  };
79 
80  if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
81  return nullptr;
82  return builder.getArrayAttr(elements);
83  }
84 
85  // Parse a boolean attribute.
86  case Token::kw_false:
87  consumeToken(Token::kw_false);
88  return builder.getBoolAttr(false);
89  case Token::kw_true:
90  consumeToken(Token::kw_true);
91  return builder.getBoolAttr(true);
92 
93  // Parse a dense elements attribute.
94  case Token::kw_dense:
95  return parseDenseElementsAttr(type);
96 
97  // Parse a dense resource elements attribute.
98  case Token::kw_dense_resource:
99  return parseDenseResourceElementsAttr(type);
100 
101  // Parse a dense array attribute.
102  case Token::kw_array:
103  return parseDenseArrayAttr(type);
104 
105  // Parse a dictionary attribute.
106  case Token::l_brace: {
107  NamedAttrList elements;
108  if (parseAttributeDict(elements))
109  return nullptr;
110  return elements.getDictionary(getContext());
111  }
112 
113  // Parse an extended attribute, i.e. alias or dialect attribute.
114  case Token::hash_identifier:
115  return parseExtendedAttr(type);
116 
117  // Parse floating point and integer attributes.
118  case Token::floatliteral:
119  return parseFloatAttr(type, /*isNegative=*/false);
120  case Token::integer:
121  return parseDecOrHexAttr(type, /*isNegative=*/false);
122  case Token::minus: {
123  consumeToken(Token::minus);
124  if (getToken().is(Token::integer))
125  return parseDecOrHexAttr(type, /*isNegative=*/true);
126  if (getToken().is(Token::floatliteral))
127  return parseFloatAttr(type, /*isNegative=*/true);
128 
129  return (emitWrongTokenError(
130  "expected constant integer or floating point value"),
131  nullptr);
132  }
133 
134  // Parse a location attribute.
135  case Token::kw_loc: {
136  consumeToken(Token::kw_loc);
137 
138  LocationAttr locAttr;
139  if (parseToken(Token::l_paren, "expected '(' in inline location") ||
140  parseLocationInstance(locAttr) ||
141  parseToken(Token::r_paren, "expected ')' in inline location"))
142  return Attribute();
143  return locAttr;
144  }
145 
146  // Parse a sparse elements attribute.
147  case Token::kw_sparse:
148  return parseSparseElementsAttr(type);
149 
150  // Parse a string attribute.
151  case Token::string: {
152  auto val = getToken().getStringValue();
153  consumeToken(Token::string);
154  // Parse the optional trailing colon type if one wasn't explicitly provided.
155  if (!type && consumeIf(Token::colon) && !(type = parseType()))
156  return Attribute();
157 
158  return type ? StringAttr::get(val, type)
159  : StringAttr::get(getContext(), val);
160  }
161 
162  // Parse a symbol reference attribute.
163  case Token::at_identifier: {
164  // When populating the parser state, this is a list of locations for all of
165  // the nested references.
166  SmallVector<SMRange> referenceLocations;
167  if (state.asmState)
168  referenceLocations.push_back(getToken().getLocRange());
169 
170  // Parse the top-level reference.
171  std::string nameStr = getToken().getSymbolReference();
172  consumeToken(Token::at_identifier);
173 
174  // Parse any nested references.
175  std::vector<FlatSymbolRefAttr> nestedRefs;
176  while (getToken().is(Token::colon)) {
177  // Check for the '::' prefix.
178  const char *curPointer = getToken().getLoc().getPointer();
179  consumeToken(Token::colon);
180  if (!consumeIf(Token::colon)) {
181  if (getToken().isNot(Token::eof, Token::error)) {
182  state.lex.resetPointer(curPointer);
183  consumeToken();
184  }
185  break;
186  }
187  // Parse the reference itself.
188  auto curLoc = getToken().getLoc();
189  if (getToken().isNot(Token::at_identifier)) {
190  emitError(curLoc, "expected nested symbol reference identifier");
191  return Attribute();
192  }
193 
194  // If we are populating the assembly state, add the location for this
195  // reference.
196  if (state.asmState)
197  referenceLocations.push_back(getToken().getLocRange());
198 
199  std::string nameStr = getToken().getSymbolReference();
200  consumeToken(Token::at_identifier);
201  nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
202  }
203  SymbolRefAttr symbolRefAttr =
204  SymbolRefAttr::get(getContext(), nameStr, nestedRefs);
205 
206  // If we are populating the assembly state, record this symbol reference.
207  if (state.asmState)
208  state.asmState->addUses(symbolRefAttr, referenceLocations);
209  return symbolRefAttr;
210  }
211 
212  // Parse a 'unit' attribute.
213  case Token::kw_unit:
214  consumeToken(Token::kw_unit);
215  return builder.getUnitAttr();
216 
217  // Handle completion of an attribute.
218  case Token::code_complete:
219  if (getToken().isCodeCompletionFor(Token::hash_identifier))
220  return parseExtendedAttr(type);
221  return codeCompleteAttribute();
222 
223  default:
224  // Parse a type attribute. We parse `Optional` here to allow for providing a
225  // better error message.
226  Type type;
227  OptionalParseResult result = parseOptionalType(type);
228  if (!result.has_value())
229  return emitWrongTokenError("expected attribute value"), Attribute();
230  return failed(*result) ? Attribute() : TypeAttr::get(type);
231  }
232 }
233 
234 /// Parse an optional attribute with the provided type.
236  Type type) {
237  switch (getToken().getKind()) {
238  case Token::at_identifier:
239  case Token::floatliteral:
240  case Token::integer:
241  case Token::hash_identifier:
242  case Token::kw_affine_map:
243  case Token::kw_affine_set:
244  case Token::kw_dense:
245  case Token::kw_dense_resource:
246  case Token::kw_false:
247  case Token::kw_loc:
248  case Token::kw_sparse:
249  case Token::kw_true:
250  case Token::kw_unit:
251  case Token::l_brace:
252  case Token::l_square:
253  case Token::minus:
254  case Token::string:
255  attribute = parseAttribute(type);
256  return success(attribute != nullptr);
257 
258  default:
259  // Parse an optional type attribute.
260  Type type;
261  OptionalParseResult result = parseOptionalType(type);
262  if (result.has_value() && succeeded(*result))
263  attribute = TypeAttr::get(type);
264  return result;
265  }
266 }
268  Type type) {
269  return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
270 }
272  Type type) {
273  return parseOptionalAttributeWithToken(Token::string, attribute, type);
274 }
275 
276 /// Attribute dictionary.
277 ///
278 /// attribute-dict ::= `{` `}`
279 /// | `{` attribute-entry (`,` attribute-entry)* `}`
280 /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value
281 ///
283  llvm::SmallDenseSet<StringAttr> seenKeys;
284  auto parseElt = [&]() -> ParseResult {
285  // The name of an attribute can either be a bare identifier, or a string.
286  Optional<StringAttr> nameId;
287  if (getToken().is(Token::string))
288  nameId = builder.getStringAttr(getToken().getStringValue());
289  else if (getToken().isAny(Token::bare_identifier, Token::inttype) ||
290  getToken().isKeyword())
292  else
293  return emitWrongTokenError("expected attribute name");
294 
295  if (nameId->size() == 0)
296  return emitError("expected valid attribute name");
297 
298  if (!seenKeys.insert(*nameId).second)
299  return emitError("duplicate key '")
300  << nameId->getValue() << "' in dictionary attribute";
301  consumeToken();
302 
303  // Lazy load a dialect in the context if there is a possible namespace.
304  auto splitName = nameId->strref().split('.');
305  if (!splitName.second.empty())
306  getContext()->getOrLoadDialect(splitName.first);
307 
308  // Try to parse the '=' for the attribute value.
309  if (!consumeIf(Token::equal)) {
310  // If there is no '=', we treat this as a unit attribute.
311  attributes.push_back({*nameId, builder.getUnitAttr()});
312  return success();
313  }
314 
315  auto attr = parseAttribute();
316  if (!attr)
317  return failure();
318  attributes.push_back({*nameId, attr});
319  return success();
320  };
321 
323  " in attribute dictionary");
324 }
325 
326 /// Parse a float attribute.
327 Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
328  auto val = getToken().getFloatingPointValue();
329  if (!val)
330  return (emitError("floating point value too large for attribute"), nullptr);
331  consumeToken(Token::floatliteral);
332  if (!type) {
333  // Default to F64 when no type is specified.
334  if (!consumeIf(Token::colon))
335  type = builder.getF64Type();
336  else if (!(type = parseType()))
337  return nullptr;
338  }
339  if (!type.isa<FloatType>())
340  return (emitError("floating point value not valid for specified type"),
341  nullptr);
342  return FloatAttr::get(type, isNegative ? -*val : *val);
343 }
344 
345 /// Construct an APint from a parsed value, a known attribute type and
346 /// sign.
347 static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
348  StringRef spelling) {
349  // Parse the integer value into an APInt that is big enough to hold the value.
350  APInt result;
351  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
352  if (spelling.getAsInteger(isHex ? 0 : 10, result))
353  return llvm::None;
354 
355  // Extend or truncate the bitwidth to the right size.
356  unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
357  : type.getIntOrFloatBitWidth();
358 
359  if (width > result.getBitWidth()) {
360  result = result.zext(width);
361  } else if (width < result.getBitWidth()) {
362  // The parser can return an unnecessarily wide result with leading zeros.
363  // This isn't a problem, but truncating off bits is bad.
364  if (result.countLeadingZeros() < result.getBitWidth() - width)
365  return llvm::None;
366 
367  result = result.trunc(width);
368  }
369 
370  if (width == 0) {
371  // 0 bit integers cannot be negative and manipulation of their sign bit will
372  // assert, so short-cut validation here.
373  if (isNegative)
374  return llvm::None;
375  } else if (isNegative) {
376  // The value is negative, we have an overflow if the sign bit is not set
377  // in the negated apInt.
378  result.negate();
379  if (!result.isSignBitSet())
380  return llvm::None;
381  } else if ((type.isSignedInteger() || type.isIndex()) &&
382  result.isSignBitSet()) {
383  // The value is a positive signed integer or index,
384  // we have an overflow if the sign bit is set.
385  return llvm::None;
386  }
387 
388  return result;
389 }
390 
391 /// Parse a decimal or a hexadecimal literal, which can be either an integer
392 /// or a float attribute.
393 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
394  Token tok = getToken();
395  StringRef spelling = tok.getSpelling();
396  SMLoc loc = tok.getLoc();
397 
398  consumeToken(Token::integer);
399  if (!type) {
400  // Default to i64 if not type is specified.
401  if (!consumeIf(Token::colon))
402  type = builder.getIntegerType(64);
403  else if (!(type = parseType()))
404  return nullptr;
405  }
406 
407  if (auto floatType = type.dyn_cast<FloatType>()) {
408  Optional<APFloat> result;
409  if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
410  floatType.getFloatSemantics(),
411  floatType.getWidth())))
412  return Attribute();
413  return FloatAttr::get(floatType, *result);
414  }
415 
416  if (!type.isa<IntegerType, IndexType>())
417  return emitError(loc, "integer literal not valid for specified type"),
418  nullptr;
419 
420  if (isNegative && type.isUnsignedInteger()) {
421  emitError(loc,
422  "negative integer literal not valid for unsigned integer type");
423  return nullptr;
424  }
425 
426  Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling);
427  if (!apInt)
428  return emitError(loc, "integer constant out of range for attribute"),
429  nullptr;
430  return builder.getIntegerAttr(type, *apInt);
431 }
432 
433 //===----------------------------------------------------------------------===//
434 // TensorLiteralParser
435 //===----------------------------------------------------------------------===//
436 
437 /// Parse elements values stored within a hex string. On success, the values are
438 /// stored into 'result'.
440  std::string &result) {
442  result = std::move(*value);
443  return success();
444  }
445  return parser.emitError(
446  tok.getLoc(), "expected string containing hex digits starting with `0x`");
447 }
448 
449 namespace {
450 /// This class implements a parser for TensorLiterals. A tensor literal is
451 /// either a single element (e.g, 5) or a multi-dimensional list of elements
452 /// (e.g., [[5, 5]]).
453 class TensorLiteralParser {
454 public:
455  TensorLiteralParser(Parser &p) : p(p) {}
456 
457  /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
458  /// may also parse a tensor literal that is store as a hex string.
459  ParseResult parse(bool allowHex);
460 
461  /// Build a dense attribute instance with the parsed elements and the given
462  /// shaped type.
463  DenseElementsAttr getAttr(SMLoc loc, ShapedType type);
464 
465  ArrayRef<int64_t> getShape() const { return shape; }
466 
467 private:
468  /// Get the parsed elements for an integer attribute.
469  ParseResult getIntAttrElements(SMLoc loc, Type eltTy,
470  std::vector<APInt> &intValues);
471 
472  /// Get the parsed elements for a float attribute.
473  ParseResult getFloatAttrElements(SMLoc loc, FloatType eltTy,
474  std::vector<APFloat> &floatValues);
475 
476  /// Build a Dense String attribute for the given type.
477  DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy);
478 
479  /// Build a Dense attribute with hex data for the given type.
480  DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type);
481 
482  /// Parse a single element, returning failure if it isn't a valid element
483  /// literal. For example:
484  /// parseElement(1) -> Success, 1
485  /// parseElement([1]) -> Failure
486  ParseResult parseElement();
487 
488  /// Parse a list of either lists or elements, returning the dimensions of the
489  /// parsed sub-tensors in dims. For example:
490  /// parseList([1, 2, 3]) -> Success, [3]
491  /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
492  /// parseList([[1, 2], 3]) -> Failure
493  /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
494  ParseResult parseList(SmallVectorImpl<int64_t> &dims);
495 
496  /// Parse a literal that was printed as a hex string.
497  ParseResult parseHexElements();
498 
499  Parser &p;
500 
501  /// The shape inferred from the parsed elements.
503 
504  /// Storage used when parsing elements, this is a pair of <is_negated, token>.
505  std::vector<std::pair<bool, Token>> storage;
506 
507  /// Storage used when parsing elements that were stored as hex values.
508  Optional<Token> hexStorage;
509 };
510 } // namespace
511 
512 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
513 /// may also parse a tensor literal that is store as a hex string.
514 ParseResult TensorLiteralParser::parse(bool allowHex) {
515  // If hex is allowed, check for a string literal.
516  if (allowHex && p.getToken().is(Token::string)) {
517  hexStorage = p.getToken();
518  p.consumeToken(Token::string);
519  return success();
520  }
521  // Otherwise, parse a list or an individual element.
522  if (p.getToken().is(Token::l_square))
523  return parseList(shape);
524  return parseElement();
525 }
526 
527 /// Build a dense attribute instance with the parsed elements and the given
528 /// shaped type.
529 DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
530  Type eltType = type.getElementType();
531 
532  // Check to see if we parse the literal from a hex string.
533  if (hexStorage &&
534  (eltType.isIntOrIndexOrFloat() || eltType.isa<ComplexType>()))
535  return getHexAttr(loc, type);
536 
537  // Check that the parsed storage size has the same number of elements to the
538  // type, or is a known splat.
539  if (!shape.empty() && getShape() != type.getShape()) {
540  p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
541  << "]) does not match type ([" << type.getShape() << "])";
542  return nullptr;
543  }
544 
545  // Handle the case where no elements were parsed.
546  if (!hexStorage && storage.empty() && type.getNumElements()) {
547  p.emitError(loc) << "parsed zero elements, but type (" << type
548  << ") expected at least 1";
549  return nullptr;
550  }
551 
552  // Handle complex types in the specific element type cases below.
553  bool isComplex = false;
554  if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
555  eltType = complexTy.getElementType();
556  isComplex = true;
557  }
558 
559  // Handle integer and index types.
560  if (eltType.isIntOrIndex()) {
561  std::vector<APInt> intValues;
562  if (failed(getIntAttrElements(loc, eltType, intValues)))
563  return nullptr;
564  if (isComplex) {
565  // If this is a complex, treat the parsed values as complex values.
566  auto complexData = llvm::makeArrayRef(
567  reinterpret_cast<std::complex<APInt> *>(intValues.data()),
568  intValues.size() / 2);
569  return DenseElementsAttr::get(type, complexData);
570  }
571  return DenseElementsAttr::get(type, intValues);
572  }
573  // Handle floating point types.
574  if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
575  std::vector<APFloat> floatValues;
576  if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
577  return nullptr;
578  if (isComplex) {
579  // If this is a complex, treat the parsed values as complex values.
580  auto complexData = llvm::makeArrayRef(
581  reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
582  floatValues.size() / 2);
583  return DenseElementsAttr::get(type, complexData);
584  }
585  return DenseElementsAttr::get(type, floatValues);
586  }
587 
588  // Other types are assumed to be string representations.
589  return getStringAttr(loc, type, type.getElementType());
590 }
591 
592 /// Build a Dense Integer attribute for the given type.
594 TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
595  std::vector<APInt> &intValues) {
596  intValues.reserve(storage.size());
597  bool isUintType = eltTy.isUnsignedInteger();
598  for (const auto &signAndToken : storage) {
599  bool isNegative = signAndToken.first;
600  const Token &token = signAndToken.second;
601  auto tokenLoc = token.getLoc();
602 
603  if (isNegative && isUintType) {
604  return p.emitError(tokenLoc)
605  << "expected unsigned integer elements, but parsed negative value";
606  }
607 
608  // Check to see if floating point values were parsed.
609  if (token.is(Token::floatliteral)) {
610  return p.emitError(tokenLoc)
611  << "expected integer elements, but parsed floating-point";
612  }
613 
614  assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
615  "unexpected token type");
616  if (token.isAny(Token::kw_true, Token::kw_false)) {
617  if (!eltTy.isInteger(1)) {
618  return p.emitError(tokenLoc)
619  << "expected i1 type for 'true' or 'false' values";
620  }
621  APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
622  intValues.push_back(apInt);
623  continue;
624  }
625 
626  // Create APInt values for each element with the correct bitwidth.
627  Optional<APInt> apInt =
628  buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
629  if (!apInt)
630  return p.emitError(tokenLoc, "integer constant out of range for type");
631  intValues.push_back(*apInt);
632  }
633  return success();
634 }
635 
636 /// Build a Dense Float attribute for the given type.
638 TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
639  std::vector<APFloat> &floatValues) {
640  floatValues.reserve(storage.size());
641  for (const auto &signAndToken : storage) {
642  bool isNegative = signAndToken.first;
643  const Token &token = signAndToken.second;
644 
645  // Handle hexadecimal float literals.
646  if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
647  Optional<APFloat> result;
648  if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative,
649  eltTy.getFloatSemantics(),
650  eltTy.getWidth())))
651  return failure();
652 
653  floatValues.push_back(*result);
654  continue;
655  }
656 
657  // Check to see if any decimal integers or booleans were parsed.
658  if (!token.is(Token::floatliteral))
659  return p.emitError()
660  << "expected floating-point elements, but parsed integer";
661 
662  // Build the float values from tokens.
663  auto val = token.getFloatingPointValue();
664  if (!val)
665  return p.emitError("floating point value too large for attribute");
666 
667  APFloat apVal(isNegative ? -*val : *val);
668  if (!eltTy.isF64()) {
669  bool unused;
670  apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
671  &unused);
672  }
673  floatValues.push_back(apVal);
674  }
675  return success();
676 }
677 
678 /// Build a Dense String attribute for the given type.
679 DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type,
680  Type eltTy) {
681  if (hexStorage.has_value()) {
682  auto stringValue = hexStorage.value().getStringValue();
683  return DenseStringElementsAttr::get(type, {stringValue});
684  }
685 
686  std::vector<std::string> stringValues;
687  std::vector<StringRef> stringRefValues;
688  stringValues.reserve(storage.size());
689  stringRefValues.reserve(storage.size());
690 
691  for (auto val : storage) {
692  stringValues.push_back(val.second.getStringValue());
693  stringRefValues.emplace_back(stringValues.back());
694  }
695 
696  return DenseStringElementsAttr::get(type, stringRefValues);
697 }
698 
699 /// Build a Dense attribute with hex data for the given type.
700 DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) {
701  Type elementType = type.getElementType();
702  if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) {
703  p.emitError(loc)
704  << "expected floating-point, integer, or complex element type, got "
705  << elementType;
706  return nullptr;
707  }
708 
709  std::string data;
710  if (parseElementAttrHexValues(p, *hexStorage, data))
711  return nullptr;
712 
713  ArrayRef<char> rawData(data.data(), data.size());
714  bool detectedSplat = false;
715  if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
716  p.emitError(loc) << "elements hex data size is invalid for provided type: "
717  << type;
718  return nullptr;
719  }
720 
721  if (llvm::support::endian::system_endianness() ==
722  llvm::support::endianness::big) {
723  // Convert endianess in big-endian(BE) machines. `rawData` is
724  // little-endian(LE) because HEX in raw data of dense element attribute
725  // is always LE format. It is converted into BE here to be used in BE
726  // machines.
727  SmallVector<char, 64> outDataVec(rawData.size());
728  MutableArrayRef<char> convRawData(outDataVec);
729  DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
730  rawData, convRawData, type);
731  return DenseElementsAttr::getFromRawBuffer(type, convRawData);
732  }
733 
734  return DenseElementsAttr::getFromRawBuffer(type, rawData);
735 }
736 
737 ParseResult TensorLiteralParser::parseElement() {
738  switch (p.getToken().getKind()) {
739  // Parse a boolean element.
740  case Token::kw_true:
741  case Token::kw_false:
742  case Token::floatliteral:
743  case Token::integer:
744  storage.emplace_back(/*isNegative=*/false, p.getToken());
745  p.consumeToken();
746  break;
747 
748  // Parse a signed integer or a negative floating-point element.
749  case Token::minus:
750  p.consumeToken(Token::minus);
751  if (!p.getToken().isAny(Token::floatliteral, Token::integer))
752  return p.emitError("expected integer or floating point literal");
753  storage.emplace_back(/*isNegative=*/true, p.getToken());
754  p.consumeToken();
755  break;
756 
757  case Token::string:
758  storage.emplace_back(/*isNegative=*/false, p.getToken());
759  p.consumeToken();
760  break;
761 
762  // Parse a complex element of the form '(' element ',' element ')'.
763  case Token::l_paren:
764  p.consumeToken(Token::l_paren);
765  if (parseElement() ||
766  p.parseToken(Token::comma, "expected ',' between complex elements") ||
767  parseElement() ||
768  p.parseToken(Token::r_paren, "expected ')' after complex elements"))
769  return failure();
770  break;
771 
772  default:
773  return p.emitError("expected element literal of primitive type");
774  }
775 
776  return success();
777 }
778 
779 /// Parse a list of either lists or elements, returning the dimensions of the
780 /// parsed sub-tensors in dims. For example:
781 /// parseList([1, 2, 3]) -> Success, [3]
782 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
783 /// parseList([[1, 2], 3]) -> Failure
784 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
785 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
786  auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
787  const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
788  if (prevDims == newDims)
789  return success();
790  return p.emitError("tensor literal is invalid; ranks are not consistent "
791  "between elements");
792  };
793 
794  bool first = true;
795  SmallVector<int64_t, 4> newDims;
796  unsigned size = 0;
797  auto parseOneElement = [&]() -> ParseResult {
798  SmallVector<int64_t, 4> thisDims;
799  if (p.getToken().getKind() == Token::l_square) {
800  if (parseList(thisDims))
801  return failure();
802  } else if (parseElement()) {
803  return failure();
804  }
805  ++size;
806  if (!first)
807  return checkDims(newDims, thisDims);
808  newDims = thisDims;
809  first = false;
810  return success();
811  };
812  if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOneElement))
813  return failure();
814 
815  // Return the sublists' dimensions with 'size' prepended.
816  dims.clear();
817  dims.push_back(size);
818  dims.append(newDims.begin(), newDims.end());
819  return success();
820 }
821 
822 //===----------------------------------------------------------------------===//
823 // ElementsAttr Parser
824 //===----------------------------------------------------------------------===//
825 
826 namespace {
827 /// This class provides an implementation of AsmParser, allowing to call back
828 /// into the libMLIRIR-provided APIs for invoking attribute parsing code defined
829 /// in libMLIRIR.
830 class CustomAsmParser : public AsmParserImpl<AsmParser> {
831 public:
832  CustomAsmParser(Parser &parser)
833  : AsmParserImpl<AsmParser>(parser.getToken().getLoc(), parser) {}
834 };
835 } // namespace
836 
837 /// Parse a dense array attribute.
839  consumeToken(Token::kw_array);
840  SMLoc typeLoc = getToken().getLoc();
841  if (parseToken(Token::less, "expected '<' after 'array'") ||
842  (!type && !(type = parseType())))
843  return {};
844  CustomAsmParser parser(*this);
845  Attribute result;
846  // Check for empty list.
847  bool isEmptyList = getToken().is(Token::greater);
848  if (!isEmptyList &&
849  parseToken(Token::colon, "expected ':' after dense array type"))
850  return {};
851 
852  if (auto intType = type.dyn_cast<IntegerType>()) {
853  switch (type.getIntOrFloatBitWidth()) {
854  case 1:
855  if (isEmptyList)
856  result = DenseBoolArrayAttr::get(parser.getContext(), {});
857  else
858  result = DenseBoolArrayAttr::parseWithoutBraces(parser, Type{});
859  break;
860  case 8:
861  if (isEmptyList)
862  result = DenseI8ArrayAttr::get(parser.getContext(), {});
863  else
864  result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
865  break;
866  case 16:
867  if (isEmptyList)
868  result = DenseI16ArrayAttr::get(parser.getContext(), {});
869  else
870  result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
871  break;
872  case 32:
873  if (isEmptyList)
874  result = DenseI32ArrayAttr::get(parser.getContext(), {});
875  else
876  result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
877  break;
878  case 64:
879  if (isEmptyList)
880  result = DenseI64ArrayAttr::get(parser.getContext(), {});
881  else
882  result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
883  break;
884  default:
885  emitError(typeLoc, "expected i1, i8, i16, i32, or i64 but got: ") << type;
886  return {};
887  }
888  } else if (auto floatType = type.dyn_cast<FloatType>()) {
889  switch (type.getIntOrFloatBitWidth()) {
890  case 32:
891  if (isEmptyList)
892  result = DenseF32ArrayAttr::get(parser.getContext(), {});
893  else
894  result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
895  break;
896  case 64:
897  if (isEmptyList)
898  result = DenseF64ArrayAttr::get(parser.getContext(), {});
899  else
900  result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
901  break;
902  default:
903  emitError(typeLoc, "expected f32 or f64 but got: ") << type;
904  return {};
905  }
906  } else {
907  emitError(typeLoc, "expected integer or float type, got: ") << type;
908  return {};
909  }
910  if (parseToken(Token::greater, "expected '>' to close an array attribute"))
911  return {};
912  return result;
913 }
914 
915 /// Parse a dense elements attribute of the form `dense` `<` [literal] `>`,
/// followed by the elements literal type, which is read from the input only
/// when `attrType` is null. Returns nullptr if any step fails.
/// NOTE(review): the declaration line (916) is absent from this listing; the
/// body below reads `attrType`, presumably the function's sole parameter.
917  auto attribLoc = getToken().getLoc(); // taken before `dense` is consumed
918  consumeToken(Token::kw_dense);
919  if (parseToken(Token::less, "expected '<' after 'dense'"))
920  return nullptr;
921 
922  // Parse the literal data unless `>` follows immediately (empty literal).
923  TensorLiteralParser literalParser(*this);
924  if (!consumeIf(Token::greater)) {
925  if (literalParser.parse(/*allowHex=*/true) ||
926  parseToken(Token::greater, "expected '>'"))
927  return nullptr;
928  }
929 
930  // If the type is specified `parseElementsLiteralType` will not parse a type.
931  // Use the attribute location as the location for error reporting in that
932  // case; otherwise use the location of the token where the type begins.
933  auto loc = attrType ? attribLoc : getToken().getLoc();
934  auto type = parseElementsLiteralType(attrType);
935  if (!type)
936  return nullptr;
937  return literalParser.getAttr(loc, type); // build the attribute from the parsed literal
938 }
939 
/// Parse a dense resource elements attribute of the form
/// `dense_resource` `<` resource-handle `>`, followed by `:` type when
/// `attrType` is null. Returns nullptr if any step fails.
/// NOTE(review): listing lines 940 (the declaration) and 947 (the start of the
/// statement that defines `rawHandle`) are absent from this extraction.
941  auto loc = getToken().getLoc(); // taken before `dense_resource` is consumed
942  consumeToken(Token::kw_dense_resource);
943  if (parseToken(Token::less, "expected '<' after 'dense_resource'"))
944  return nullptr;
945 
946  // Parse the resource handle, passing the loaded builtin dialect.
948  parseResourceHandle(getContext()->getLoadedDialect<BuiltinDialect>());
949  if (failed(rawHandle) || parseToken(Token::greater, "expected '>'"))
950  return nullptr;
951 
952  auto *handle = dyn_cast<DenseResourceElementsHandle>(&*rawHandle);
953  if (!handle) // the parsed handle is not a DenseResourceElementsHandle
954  return emitError(loc, "invalid `dense_resource` handle type"), nullptr;
955 
956  // Parse the type of the attribute if the user didn't provide one.
957  SMLoc typeLoc = loc; // used for the shaped-type error when no type is parsed here
958  if (!attrType) {
959  typeLoc = getToken().getLoc();
960  if (parseToken(Token::colon, "expected ':'") || !(attrType = parseType()))
961  return nullptr;
962  }
963 
964  ShapedType shapedType = attrType.dyn_cast<ShapedType>();
965  if (!shapedType) { // only shaped types are accepted
966  emitError(typeLoc, "`dense_resource` expected a shaped type");
967  return nullptr;
968  }
969 
970  return DenseResourceElementsAttr::get(shapedType, *handle);
971 }
972 
973 /// Obtain and validate the shaped type of an elements attribute literal.
974 ///
975 /// elements-literal-type ::= vector-type | ranked-tensor-type
976 ///
977 /// This method also checks the type has static shape.
/// If `type` is null, `:` followed by a type is parsed from the input;
/// otherwise the given `type` is validated as-is. Emits an error and returns
/// nullptr when the type is not a ranked tensor or vector, or when its shape is
/// not static.
/// NOTE(review): the declaration line (978) is absent from this listing.
979  // If the user didn't provide a type, parse the colon type for the literal.
980  if (!type) {
981  if (parseToken(Token::colon, "expected ':'"))
982  return nullptr;
983  if (!(type = parseType()))
984  return nullptr;
985  }
986 
987  if (!type.isa<RankedTensorType, VectorType>()) {
988  emitError("elements literal must be a ranked tensor or vector type");
989  return nullptr;
990  }
991 
992  auto sType = type.cast<ShapedType>(); // reached only after the isa check above
993  if (!sType.hasStaticShape())
994  return (emitError("elements literal type must have static shape"), nullptr);
995 
996  return sType;
997 }
998 
999 /// Parse a sparse elements attribute of the form
/// `sparse` `<` [indices `,` values] `>`, followed by the elements literal type
/// (read from the input only when `attrType` is null). Returns nullptr when
/// parsing fails; the attribute itself is built through `getChecked`.
/// NOTE(review): listing lines 1000 (the declaration) and 1022 (the final
/// argument line of the first `getChecked` call) are absent from this
/// extraction.
1001  SMLoc loc = getToken().getLoc();
1002  consumeToken(Token::kw_sparse);
1003  if (parseToken(Token::less, "Expected '<' after 'sparse'"))
1004  return nullptr;
1005 
1006  // Check for the case where all elements are sparse. The indices are
1007  // represented by a 2-dimensional shape where the second dimension is the rank
1008  // of the type.
1009  Type indiceEltType = builder.getIntegerType(64); // indices use a 64-bit integer element type
1010  if (consumeIf(Token::greater)) {
1011  ShapedType type = parseElementsLiteralType(attrType);
1012  if (!type)
1013  return nullptr;
1014 
1015  // Construct the sparse elements attr using zero-element index/value
1016  // attributes.
1017  ShapedType indicesType =
1018  RankedTensorType::get({0, type.getRank()}, indiceEltType);
1019  ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
1020  return getChecked<SparseElementsAttr>(
1021  loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
1023  }
1024 
1025  /// Parse the indices. We don't allow hex values here as we may need to use
1026  /// the inferred shape.
1027  auto indicesLoc = getToken().getLoc();
1028  TensorLiteralParser indiceParser(*this);
1029  if (indiceParser.parse(/*allowHex=*/false))
1030  return nullptr;
1031 
1032  if (parseToken(Token::comma, "expected ','"))
1033  return nullptr;
1034 
1035  /// Parse the values. Hex values are allowed here.
1036  auto valuesLoc = getToken().getLoc();
1037  TensorLiteralParser valuesParser(*this);
1038  if (valuesParser.parse(/*allowHex=*/true))
1039  return nullptr;
1040 
1041  if (parseToken(Token::greater, "expected '>'"))
1042  return nullptr;
1043 
1044  auto type = parseElementsLiteralType(attrType);
1045  if (!type)
1046  return nullptr;
1047 
1048  // If the indices are a splat, i.e. the literal parser parsed an element and
1049  // not a list, we set the shape explicitly. The indices are represented by a
1050  // 2-dimensional shape where the second dimension is the rank of the type.
1051  // Given that the parsed indices is a splat, we know that we only have one
1052  // index and thus a size of one for the first dimension.
1053  ShapedType indicesType;
1054  if (indiceParser.getShape().empty()) {
1055  indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
1056  } else {
1057  // Otherwise, set the shape to the one parsed by the literal parser.
1058  indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
1059  }
1060  auto indices = indiceParser.getAttr(indicesLoc, indicesType);
1061 
1062  // If the values are a splat, set the shape explicitly based on the number of
1063  // indices. The number of indices is encoded in the first dimension of the
1064  // indices shape type.
1065  auto valuesEltType = type.getElementType();
1066  ShapedType valuesType =
1067  valuesParser.getShape().empty()
1068  ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
1069  : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
1070  auto values = valuesParser.getAttr(valuesLoc, valuesType);
1071 
1072  // Build the sparse elements attribute from the indices and values.
1073  return getChecked<SparseElementsAttr>(loc, type, indices, values);
1074 }
Include the generated interface declarations.
Optional< std::string > getHexStringValue() const
Given a token containing a hex string literal, return its value or None if the token does not contain...
Definition: Token.cpp:129
MLIRContext * getContext() const
Definition: Parser.h:36
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::string getSymbolReference() const
Given a token containing a symbol reference, return the unescaped string value.
Definition: Token.cpp:147
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:89
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
Definition: Types.cpp:31
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Builder builder
Definition: Parser.h:29
InFlightDiagnostic emitWrongTokenError(const Twine &message={})
Emit an error about a "wrong token".
Definition: Parser.cpp:180
bool consumeIf(Token::Kind kind)
If the current token has the specified kind, consume it and return true.
Definition: Parser.h:106
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Location objects represent source location information in MLIR.
Definition: Location.h:31
Attribute parseDenseArrayAttr(Type type)
Parse a DenseArrayAttr.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
OptionalParseResult parseOptionalType(Type &type)
Optionally parse a type.
Definition: TypeParser.cpp:23
void addUses(Value value, ArrayRef< SMLoc > locations)
Add source uses of the given value.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
FailureOr< AsmDialectResourceHandle > parseResourceHandle(const OpAsmDialectInterface *dialect, StringRef &name)
Parse a handle to a dialect resource within the assembly format.
Definition: Parser.cpp:324
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:61
ParserState & state
The Parser is subclassed and reinstantiated.
Definition: Parser.h:344
static ParseResult parseElementAttrHexValues(Parser &parser, Token tok, std::string &result)
Parse elements values stored within a hex string.
static constexpr const bool value
bool isAny(Kind k1, Kind k2) const
Definition: Token.h:40
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
Attribute parseAttribute(Type type={})
Parse an arbitrary attribute with an optional type.
bool is(Kind k) const
Definition: Token.h:38
StringRef getTokenSpelling() const
Definition: Parser.h:102
AsmParserState * asmState
An optional pointer to a struct containing high level parser state to be populated during parsing...
Definition: ParserState.h:72
ParseResult parseLocationInstance(LocationAttr &loc)
Parse a raw location instance.
ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())
Parse a list of comma-separated items with an optional delimiter.
Definition: Parser.cpp:49
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
{} brackets surrounding zero or more operands.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class defines a dialect specific handle to a resource blob.
unsigned getWidth()
Return the bitwidth of this float type.
This class provides support for representing a failure result, or a valid value of type T...
Definition: LogicalResult.h:78
An attribute that represents a reference to a dense vector or tensor object.
OptionalParseResult parseOptionalAttribute(Attribute &attribute, Type type={})
Parse an optional attribute with the provided type.
Attribute parseDecOrHexAttr(Type type, bool isNegative)
Parse a decimal or a hexadecimal literal, which can be either an integer or a float attribute...
ParseResult parseFloatFromIntegerLiteral(Optional< APFloat > &result, const Token &tok, bool isNegative, const llvm::fltSemantics &semantics, size_t typeSizeInBits)
Parse a floating point value from an integer literal token.
Definition: Parser.cpp:276
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:194
U dyn_cast() const
Definition: Types.h:270
UnitAttr getUnitAttr()
Definition: Builders.cpp:85
ShapedType parseElementsLiteralType(Type type)
Shaped type for elements attribute.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
This class provides the implementation of the generic parser methods within AsmParser.
Definition: AsmParserImpl.h:26
static DenseArrayAttr get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
bool isIndex() const
Definition: Types.cpp:28
const Token & getToken() const
Return the current token the parser is inspecting.
Definition: Parser.h:101
std::string getStringValue() const
Given a token containing a string literal, return its value, including removing the quote characters ...
Definition: Token.cpp:80
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued...
Definition: AffineMap.h:42
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:49
void resetPointer(const char *newPointer)
Change the position of the lexer cursor.
Definition: Lexer.h:38
InFlightDiagnostic emitError(const Twine &message={})
Emit an error and return failure.
Definition: Parser.cpp:157
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:37
Attribute parseDenseElementsAttr(Type attrType)
Parse a dense elements attribute.
ParseResult parseAttributeDict(NamedAttrList &attributes)
Parse an attribute dictionary.
bool isF64() const
Definition: Types.cpp:24
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:92
Type parseType()
Parse an arbitrary type.
Definition: TypeParser.cpp:54
StringRef getSpelling() const
Definition: Token.h:34
static Optional< APInt > buildAttributeAPInt(Type type, bool isNegative, StringRef spelling)
Construct an APInt from a parsed value, a known attribute type and sign.
ParseResult parseAffineMapReference(AffineMap &map)
void consumeToken()
Advance the current lexer onto the next token.
Definition: Parser.h:114
Attribute parseSparseElementsAttr(Type attrType)
Parse a sparse elements attribute.
Attribute parseDenseResourceElementsAttr(Type attrType)
Parse a dense resource elements attribute.
ParseResult parseCommaSeparatedListUntil(Token::Kind rightToken, function_ref< ParseResult()> parseElement, bool allowEmptyList=true)
Parse a comma-separated list of elements up until the specified end token.
Definition: Parser.cpp:138
Lexer lex
The lexer for the source file we're parsing.
Definition: ParserState.h:62
FloatType getF64Type()
Definition: Builders.cpp:42
OptionalParseResult parseOptionalAttributeWithToken(Token::Kind kind, AttributeT &attr, Type type={})
Parse an optional attribute that is demarcated by a specific token.
Definition: Parser.h:243
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:87
This class implements support for parsing global entities like attributes and types.
Definition: Parser.h:25
Attribute parseFloatAttr(Type type, bool isNegative)
Parse a float attribute.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:91
Attribute codeCompleteAttribute()
Definition: Parser.cpp:426
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:47
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:85
static Attribute parseWithoutBraces(AsmParser &parser, Type odsType)
Parse the short form 42, 100, -1 without any type prefix or braces.
SMLoc getLoc() const
Definition: Token.cpp:18
ParseResult parseIntegerSetReference(IntegerSet &set)
bool isa() const
Definition: Types.h:254
This class represents success/failure for parsing-like operations that find it important to chain tog...
ParseResult parseToken(Token::Kind expectedToken, const Twine &message)
Consume the specified token if present and return success.
Definition: Parser.cpp:232
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute...
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:229
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:225
Optional< double > getFloatingPointValue() const
For a floatliteral token, return its value as a double.
Definition: Token.cpp:50
Attribute parseExtendedAttr(Type type)
Parse an extended attribute.
Square brackets surrounding zero or more operands.
U cast() const
Definition: Types.h:278
This represents a token in the MLIR syntax.
Definition: Token.h:20
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition: IntegerSet.h:44