MLIR  14.0.0git
AttributeParser.cpp
Go to the documentation of this file.
1 //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the parser for the MLIR Attributes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Parser.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/IntegerSet.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/Support/Endian.h"
21 
22 using namespace mlir;
23 using namespace mlir::detail;
24 
25 /// Parse an arbitrary attribute.
26 ///
27 /// attribute-value ::= `unit`
28 /// | bool-literal
29 /// | integer-literal (`:` (index-type | integer-type))?
30 /// | float-literal (`:` float-type)?
31 /// | string-literal (`:` type)?
32 /// | type
33 /// | `[` (attribute-value (`,` attribute-value)*)? `]`
34 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}`
35 /// | symbol-ref-id (`::` symbol-ref-id)*
36 /// | `dense` `<` attribute-value `>` `:`
37 /// (tensor-type | vector-type)
38 /// | `sparse` `<` attribute-value `,` attribute-value `>`
39 /// `:` (tensor-type | vector-type)
40 /// | `opaque` `<` dialect-namespace `,` hex-string-literal
41 /// `>` `:` (tensor-type | vector-type)
42 /// | extended-attribute
43 ///
45  switch (getToken().getKind()) {
46  // Parse an AffineMap or IntegerSet attribute.
47  case Token::kw_affine_map: {
48  consumeToken(Token::kw_affine_map);
49 
50  AffineMap map;
51  if (parseToken(Token::less, "expected '<' in affine map") ||
53  parseToken(Token::greater, "expected '>' in affine map"))
54  return Attribute();
55  return AffineMapAttr::get(map);
56  }
57  case Token::kw_affine_set: {
58  consumeToken(Token::kw_affine_set);
59 
60  IntegerSet set;
61  if (parseToken(Token::less, "expected '<' in integer set") ||
63  parseToken(Token::greater, "expected '>' in integer set"))
64  return Attribute();
65  return IntegerSetAttr::get(set);
66  }
67 
68  // Parse an array attribute.
69  case Token::l_square: {
71  auto parseElt = [&]() -> ParseResult {
72  elements.push_back(parseAttribute());
73  return elements.back() ? success() : failure();
74  };
75 
77  return nullptr;
78  return builder.getArrayAttr(elements);
79  }
80 
81  // Parse a boolean attribute.
82  case Token::kw_false:
83  consumeToken(Token::kw_false);
84  return builder.getBoolAttr(false);
85  case Token::kw_true:
86  consumeToken(Token::kw_true);
87  return builder.getBoolAttr(true);
88 
89  // Parse a dense elements attribute.
90  case Token::kw_dense:
91  return parseDenseElementsAttr(type);
92 
93  // Parse a dictionary attribute.
94  case Token::l_brace: {
95  NamedAttrList elements;
96  if (parseAttributeDict(elements))
97  return nullptr;
98  return elements.getDictionary(getContext());
99  }
100 
101  // Parse an extended attribute, i.e. alias or dialect attribute.
102  case Token::hash_identifier:
103  return parseExtendedAttr(type);
104 
105  // Parse floating point and integer attributes.
106  case Token::floatliteral:
107  return parseFloatAttr(type, /*isNegative=*/false);
108  case Token::integer:
109  return parseDecOrHexAttr(type, /*isNegative=*/false);
110  case Token::minus: {
111  consumeToken(Token::minus);
112  if (getToken().is(Token::integer))
113  return parseDecOrHexAttr(type, /*isNegative=*/true);
114  if (getToken().is(Token::floatliteral))
115  return parseFloatAttr(type, /*isNegative=*/true);
116 
117  return (emitError("expected constant integer or floating point value"),
118  nullptr);
119  }
120 
121  // Parse a location attribute.
122  case Token::kw_loc: {
123  consumeToken(Token::kw_loc);
124 
125  LocationAttr locAttr;
126  if (parseToken(Token::l_paren, "expected '(' in inline location") ||
127  parseLocationInstance(locAttr) ||
128  parseToken(Token::r_paren, "expected ')' in inline location"))
129  return Attribute();
130  return locAttr;
131  }
132 
133  // Parse an opaque elements attribute.
134  case Token::kw_opaque:
135  return parseOpaqueElementsAttr(type);
136 
137  // Parse a sparse elements attribute.
138  case Token::kw_sparse:
139  return parseSparseElementsAttr(type);
140 
141  // Parse a string attribute.
142  case Token::string: {
143  auto val = getToken().getStringValue();
144  consumeToken(Token::string);
145  // Parse the optional trailing colon type if one wasn't explicitly provided.
146  if (!type && consumeIf(Token::colon) && !(type = parseType()))
147  return Attribute();
148 
149  return type ? StringAttr::get(val, type)
150  : StringAttr::get(getContext(), val);
151  }
152 
153  // Parse a symbol reference attribute.
154  case Token::at_identifier: {
155  // When populating the parser state, this is a list of locations for all of
156  // the nested references.
157  SmallVector<llvm::SMRange> referenceLocations;
158  if (state.asmState)
159  referenceLocations.push_back(getToken().getLocRange());
160 
161  // Parse the top-level reference.
162  std::string nameStr = getToken().getSymbolReference();
163  consumeToken(Token::at_identifier);
164 
165  // Parse any nested references.
166  std::vector<FlatSymbolRefAttr> nestedRefs;
167  while (getToken().is(Token::colon)) {
168  // Check for the '::' prefix.
169  const char *curPointer = getToken().getLoc().getPointer();
170  consumeToken(Token::colon);
171  if (!consumeIf(Token::colon)) {
172  state.lex.resetPointer(curPointer);
173  consumeToken();
174  break;
175  }
176  // Parse the reference itself.
177  auto curLoc = getToken().getLoc();
178  if (getToken().isNot(Token::at_identifier)) {
179  emitError(curLoc, "expected nested symbol reference identifier");
180  return Attribute();
181  }
182 
183  // If we are populating the assembly state, add the location for this
184  // reference.
185  if (state.asmState)
186  referenceLocations.push_back(getToken().getLocRange());
187 
188  std::string nameStr = getToken().getSymbolReference();
189  consumeToken(Token::at_identifier);
190  nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
191  }
192  SymbolRefAttr symbolRefAttr =
193  SymbolRefAttr::get(getContext(), nameStr, nestedRefs);
194 
195  // If we are populating the assembly state, record this symbol reference.
196  if (state.asmState)
197  state.asmState->addUses(symbolRefAttr, referenceLocations);
198  return symbolRefAttr;
199  }
200 
201  // Parse a 'unit' attribute.
202  case Token::kw_unit:
203  consumeToken(Token::kw_unit);
204  return builder.getUnitAttr();
205 
206  default:
207  // Parse a type attribute.
208  if (Type type = parseType())
209  return TypeAttr::get(type);
210  return nullptr;
211  }
212 }
213 
214 /// Parse an optional attribute with the provided type.
216  Type type) {
217  switch (getToken().getKind()) {
218  case Token::at_identifier:
219  case Token::floatliteral:
220  case Token::integer:
221  case Token::hash_identifier:
222  case Token::kw_affine_map:
223  case Token::kw_affine_set:
224  case Token::kw_dense:
225  case Token::kw_false:
226  case Token::kw_loc:
227  case Token::kw_opaque:
228  case Token::kw_sparse:
229  case Token::kw_true:
230  case Token::kw_unit:
231  case Token::l_brace:
232  case Token::l_square:
233  case Token::minus:
234  case Token::string:
235  attribute = parseAttribute(type);
236  return success(attribute != nullptr);
237 
238  default:
239  // Parse an optional type attribute.
240  Type type;
241  OptionalParseResult result = parseOptionalType(type);
242  if (result.hasValue() && succeeded(*result))
243  attribute = TypeAttr::get(type);
244  return result;
245  }
246 }
248  Type type) {
249  return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
250 }
252  Type type) {
253  return parseOptionalAttributeWithToken(Token::string, attribute, type);
254 }
255 
256 /// Attribute dictionary.
257 ///
258 /// attribute-dict ::= `{` `}`
259 /// | `{` attribute-entry (`,` attribute-entry)* `}`
260 /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value
261 ///
263  llvm::SmallDenseSet<StringAttr> seenKeys;
264  auto parseElt = [&]() -> ParseResult {
265  // The name of an attribute can either be a bare identifier, or a string.
266  Optional<StringAttr> nameId;
267  if (getToken().is(Token::string))
268  nameId = builder.getStringAttr(getToken().getStringValue());
269  else if (getToken().isAny(Token::bare_identifier, Token::inttype) ||
270  getToken().isKeyword())
272  else
273  return emitError("expected attribute name");
274  if (!seenKeys.insert(*nameId).second)
275  return emitError("duplicate key '")
276  << nameId->getValue() << "' in dictionary attribute";
277  consumeToken();
278 
279  // Lazy load a dialect in the context if there is a possible namespace.
280  auto splitName = nameId->strref().split('.');
281  if (!splitName.second.empty())
282  getContext()->getOrLoadDialect(splitName.first);
283 
284  // Try to parse the '=' for the attribute value.
285  if (!consumeIf(Token::equal)) {
286  // If there is no '=', we treat this as a unit attribute.
287  attributes.push_back({*nameId, builder.getUnitAttr()});
288  return success();
289  }
290 
291  auto attr = parseAttribute();
292  if (!attr)
293  return failure();
294  attributes.push_back({*nameId, attr});
295  return success();
296  };
297 
299  " in attribute dictionary"))
300  return failure();
301 
302  return success();
303 }
304 
305 /// Parse a float attribute.
306 Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
307  auto val = getToken().getFloatingPointValue();
308  if (!val.hasValue())
309  return (emitError("floating point value too large for attribute"), nullptr);
310  consumeToken(Token::floatliteral);
311  if (!type) {
312  // Default to F64 when no type is specified.
313  if (!consumeIf(Token::colon))
314  type = builder.getF64Type();
315  else if (!(type = parseType()))
316  return nullptr;
317  }
318  if (!type.isa<FloatType>())
319  return (emitError("floating point value not valid for specified type"),
320  nullptr);
321  return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue());
322 }
323 
324 /// Construct an APint from a parsed value, a known attribute type and
325 /// sign.
326 static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
327  StringRef spelling) {
328  // Parse the integer value into an APInt that is big enough to hold the value.
329  APInt result;
330  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
331  if (spelling.getAsInteger(isHex ? 0 : 10, result))
332  return llvm::None;
333 
334  // Extend or truncate the bitwidth to the right size.
335  unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
336  : type.getIntOrFloatBitWidth();
337 
338  if (width > result.getBitWidth()) {
339  result = result.zext(width);
340  } else if (width < result.getBitWidth()) {
341  // The parser can return an unnecessarily wide result with leading zeros.
342  // This isn't a problem, but truncating off bits is bad.
343  if (result.countLeadingZeros() < result.getBitWidth() - width)
344  return llvm::None;
345 
346  result = result.trunc(width);
347  }
348 
349  if (width == 0) {
350  // 0 bit integers cannot be negative and manipulation of their sign bit will
351  // assert, so short-cut validation here.
352  if (isNegative)
353  return llvm::None;
354  } else if (isNegative) {
355  // The value is negative, we have an overflow if the sign bit is not set
356  // in the negated apInt.
357  result.negate();
358  if (!result.isSignBitSet())
359  return llvm::None;
360  } else if ((type.isSignedInteger() || type.isIndex()) &&
361  result.isSignBitSet()) {
362  // The value is a positive signed integer or index,
363  // we have an overflow if the sign bit is set.
364  return llvm::None;
365  }
366 
367  return result;
368 }
369 
370 /// Parse a decimal or a hexadecimal literal, which can be either an integer
371 /// or a float attribute.
372 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
373  Token tok = getToken();
374  StringRef spelling = tok.getSpelling();
375  llvm::SMLoc loc = tok.getLoc();
376 
377  consumeToken(Token::integer);
378  if (!type) {
379  // Default to i64 if not type is specified.
380  if (!consumeIf(Token::colon))
381  type = builder.getIntegerType(64);
382  else if (!(type = parseType()))
383  return nullptr;
384  }
385 
386  if (auto floatType = type.dyn_cast<FloatType>()) {
387  Optional<APFloat> result;
388  if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
389  floatType.getFloatSemantics(),
390  floatType.getWidth())))
391  return Attribute();
392  return FloatAttr::get(floatType, *result);
393  }
394 
395  if (!type.isa<IntegerType, IndexType>())
396  return emitError(loc, "integer literal not valid for specified type"),
397  nullptr;
398 
399  if (isNegative && type.isUnsignedInteger()) {
400  emitError(loc,
401  "negative integer literal not valid for unsigned integer type");
402  return nullptr;
403  }
404 
405  Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling);
406  if (!apInt)
407  return emitError(loc, "integer constant out of range for attribute"),
408  nullptr;
409  return builder.getIntegerAttr(type, *apInt);
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // TensorLiteralParser
414 //===----------------------------------------------------------------------===//
415 
416 /// Parse elements values stored within a hex string. On success, the values are
417 /// stored into 'result'.
419  std::string &result) {
421  result = std::move(*value);
422  return success();
423  }
424  return parser.emitError(
425  tok.getLoc(), "expected string containing hex digits starting with `0x`");
426 }
427 
428 namespace {
429 /// This class implements a parser for TensorLiterals. A tensor literal is
430 /// either a single element (e.g, 5) or a multi-dimensional list of elements
431 /// (e.g., [[5, 5]]).
432 class TensorLiteralParser {
433 public:
434  TensorLiteralParser(Parser &p) : p(p) {}
435 
436  /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
437  /// may also parse a tensor literal that is store as a hex string.
438  ParseResult parse(bool allowHex);
439 
440  /// Build a dense attribute instance with the parsed elements and the given
441  /// shaped type.
442  DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type);
443 
444  ArrayRef<int64_t> getShape() const { return shape; }
445 
446 private:
447  /// Get the parsed elements for an integer attribute.
448  ParseResult getIntAttrElements(llvm::SMLoc loc, Type eltTy,
449  std::vector<APInt> &intValues);
450 
451  /// Get the parsed elements for a float attribute.
452  ParseResult getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
453  std::vector<APFloat> &floatValues);
454 
455  /// Build a Dense String attribute for the given type.
456  DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
457 
458  /// Build a Dense attribute with hex data for the given type.
459  DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type);
460 
461  /// Parse a single element, returning failure if it isn't a valid element
462  /// literal. For example:
463  /// parseElement(1) -> Success, 1
464  /// parseElement([1]) -> Failure
465  ParseResult parseElement();
466 
467  /// Parse a list of either lists or elements, returning the dimensions of the
468  /// parsed sub-tensors in dims. For example:
469  /// parseList([1, 2, 3]) -> Success, [3]
470  /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
471  /// parseList([[1, 2], 3]) -> Failure
472  /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
473  ParseResult parseList(SmallVectorImpl<int64_t> &dims);
474 
475  /// Parse a literal that was printed as a hex string.
476  ParseResult parseHexElements();
477 
478  Parser &p;
479 
480  /// The shape inferred from the parsed elements.
482 
483  /// Storage used when parsing elements, this is a pair of <is_negated, token>.
484  std::vector<std::pair<bool, Token>> storage;
485 
486  /// Storage used when parsing elements that were stored as hex values.
487  Optional<Token> hexStorage;
488 };
489 } // namespace
490 
491 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
492 /// may also parse a tensor literal that is store as a hex string.
493 ParseResult TensorLiteralParser::parse(bool allowHex) {
494  // If hex is allowed, check for a string literal.
495  if (allowHex && p.getToken().is(Token::string)) {
496  hexStorage = p.getToken();
497  p.consumeToken(Token::string);
498  return success();
499  }
500  // Otherwise, parse a list or an individual element.
501  if (p.getToken().is(Token::l_square))
502  return parseList(shape);
503  return parseElement();
504 }
505 
506 /// Build a dense attribute instance with the parsed elements and the given
507 /// shaped type.
508 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
509  ShapedType type) {
510  Type eltType = type.getElementType();
511 
512  // Check to see if we parse the literal from a hex string.
513  if (hexStorage.hasValue() &&
514  (eltType.isIntOrIndexOrFloat() || eltType.isa<ComplexType>()))
515  return getHexAttr(loc, type);
516 
517  // Check that the parsed storage size has the same number of elements to the
518  // type, or is a known splat.
519  if (!shape.empty() && getShape() != type.getShape()) {
520  p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
521  << "]) does not match type ([" << type.getShape() << "])";
522  return nullptr;
523  }
524 
525  // Handle the case where no elements were parsed.
526  if (!hexStorage.hasValue() && storage.empty() && type.getNumElements()) {
527  p.emitError(loc) << "parsed zero elements, but type (" << type
528  << ") expected at least 1";
529  return nullptr;
530  }
531 
532  // Handle complex types in the specific element type cases below.
533  bool isComplex = false;
534  if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
535  eltType = complexTy.getElementType();
536  isComplex = true;
537  }
538 
539  // Handle integer and index types.
540  if (eltType.isIntOrIndex()) {
541  std::vector<APInt> intValues;
542  if (failed(getIntAttrElements(loc, eltType, intValues)))
543  return nullptr;
544  if (isComplex) {
545  // If this is a complex, treat the parsed values as complex values.
546  auto complexData = llvm::makeArrayRef(
547  reinterpret_cast<std::complex<APInt> *>(intValues.data()),
548  intValues.size() / 2);
549  return DenseElementsAttr::get(type, complexData);
550  }
551  return DenseElementsAttr::get(type, intValues);
552  }
553  // Handle floating point types.
554  if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
555  std::vector<APFloat> floatValues;
556  if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
557  return nullptr;
558  if (isComplex) {
559  // If this is a complex, treat the parsed values as complex values.
560  auto complexData = llvm::makeArrayRef(
561  reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
562  floatValues.size() / 2);
563  return DenseElementsAttr::get(type, complexData);
564  }
565  return DenseElementsAttr::get(type, floatValues);
566  }
567 
568  // Other types are assumed to be string representations.
569  return getStringAttr(loc, type, type.getElementType());
570 }
571 
572 /// Build a Dense Integer attribute for the given type.
574 TensorLiteralParser::getIntAttrElements(llvm::SMLoc loc, Type eltTy,
575  std::vector<APInt> &intValues) {
576  intValues.reserve(storage.size());
577  bool isUintType = eltTy.isUnsignedInteger();
578  for (const auto &signAndToken : storage) {
579  bool isNegative = signAndToken.first;
580  const Token &token = signAndToken.second;
581  auto tokenLoc = token.getLoc();
582 
583  if (isNegative && isUintType) {
584  return p.emitError(tokenLoc)
585  << "expected unsigned integer elements, but parsed negative value";
586  }
587 
588  // Check to see if floating point values were parsed.
589  if (token.is(Token::floatliteral)) {
590  return p.emitError(tokenLoc)
591  << "expected integer elements, but parsed floating-point";
592  }
593 
594  assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
595  "unexpected token type");
596  if (token.isAny(Token::kw_true, Token::kw_false)) {
597  if (!eltTy.isInteger(1)) {
598  return p.emitError(tokenLoc)
599  << "expected i1 type for 'true' or 'false' values";
600  }
601  APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
602  intValues.push_back(apInt);
603  continue;
604  }
605 
606  // Create APInt values for each element with the correct bitwidth.
607  Optional<APInt> apInt =
608  buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
609  if (!apInt)
610  return p.emitError(tokenLoc, "integer constant out of range for type");
611  intValues.push_back(*apInt);
612  }
613  return success();
614 }
615 
616 /// Build a Dense Float attribute for the given type.
618 TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
619  std::vector<APFloat> &floatValues) {
620  floatValues.reserve(storage.size());
621  for (const auto &signAndToken : storage) {
622  bool isNegative = signAndToken.first;
623  const Token &token = signAndToken.second;
624 
625  // Handle hexadecimal float literals.
626  if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
627  Optional<APFloat> result;
628  if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative,
629  eltTy.getFloatSemantics(),
630  eltTy.getWidth())))
631  return failure();
632 
633  floatValues.push_back(*result);
634  continue;
635  }
636 
637  // Check to see if any decimal integers or booleans were parsed.
638  if (!token.is(Token::floatliteral))
639  return p.emitError()
640  << "expected floating-point elements, but parsed integer";
641 
642  // Build the float values from tokens.
643  auto val = token.getFloatingPointValue();
644  if (!val.hasValue())
645  return p.emitError("floating point value too large for attribute");
646 
647  APFloat apVal(isNegative ? -*val : *val);
648  if (!eltTy.isF64()) {
649  bool unused;
650  apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
651  &unused);
652  }
653  floatValues.push_back(apVal);
654  }
655  return success();
656 }
657 
658 /// Build a Dense String attribute for the given type.
659 DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc,
660  ShapedType type,
661  Type eltTy) {
662  if (hexStorage.hasValue()) {
663  auto stringValue = hexStorage.getValue().getStringValue();
664  return DenseStringElementsAttr::get(type, {stringValue});
665  }
666 
667  std::vector<std::string> stringValues;
668  std::vector<StringRef> stringRefValues;
669  stringValues.reserve(storage.size());
670  stringRefValues.reserve(storage.size());
671 
672  for (auto val : storage) {
673  stringValues.push_back(val.second.getStringValue());
674  stringRefValues.emplace_back(stringValues.back());
675  }
676 
677  return DenseStringElementsAttr::get(type, stringRefValues);
678 }
679 
680 /// Build a Dense attribute with hex data for the given type.
681 DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
682  ShapedType type) {
683  Type elementType = type.getElementType();
684  if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) {
685  p.emitError(loc)
686  << "expected floating-point, integer, or complex element type, got "
687  << elementType;
688  return nullptr;
689  }
690 
691  std::string data;
692  if (parseElementAttrHexValues(p, hexStorage.getValue(), data))
693  return nullptr;
694 
695  ArrayRef<char> rawData(data.data(), data.size());
696  bool detectedSplat = false;
697  if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
698  p.emitError(loc) << "elements hex data size is invalid for provided type: "
699  << type;
700  return nullptr;
701  }
702 
703  if (llvm::support::endian::system_endianness() ==
704  llvm::support::endianness::big) {
705  // Convert endianess in big-endian(BE) machines. `rawData` is
706  // little-endian(LE) because HEX in raw data of dense element attribute
707  // is always LE format. It is converted into BE here to be used in BE
708  // machines.
709  SmallVector<char, 64> outDataVec(rawData.size());
710  MutableArrayRef<char> convRawData(outDataVec);
711  DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
712  rawData, convRawData, type);
713  return DenseElementsAttr::getFromRawBuffer(type, convRawData,
714  detectedSplat);
715  }
716 
717  return DenseElementsAttr::getFromRawBuffer(type, rawData, detectedSplat);
718 }
719 
720 ParseResult TensorLiteralParser::parseElement() {
721  switch (p.getToken().getKind()) {
722  // Parse a boolean element.
723  case Token::kw_true:
724  case Token::kw_false:
725  case Token::floatliteral:
726  case Token::integer:
727  storage.emplace_back(/*isNegative=*/false, p.getToken());
728  p.consumeToken();
729  break;
730 
731  // Parse a signed integer or a negative floating-point element.
732  case Token::minus:
733  p.consumeToken(Token::minus);
734  if (!p.getToken().isAny(Token::floatliteral, Token::integer))
735  return p.emitError("expected integer or floating point literal");
736  storage.emplace_back(/*isNegative=*/true, p.getToken());
737  p.consumeToken();
738  break;
739 
740  case Token::string:
741  storage.emplace_back(/*isNegative=*/false, p.getToken());
742  p.consumeToken();
743  break;
744 
745  // Parse a complex element of the form '(' element ',' element ')'.
746  case Token::l_paren:
747  p.consumeToken(Token::l_paren);
748  if (parseElement() ||
749  p.parseToken(Token::comma, "expected ',' between complex elements") ||
750  parseElement() ||
751  p.parseToken(Token::r_paren, "expected ')' after complex elements"))
752  return failure();
753  break;
754 
755  default:
756  return p.emitError("expected element literal of primitive type");
757  }
758 
759  return success();
760 }
761 
762 /// Parse a list of either lists or elements, returning the dimensions of the
763 /// parsed sub-tensors in dims. For example:
764 /// parseList([1, 2, 3]) -> Success, [3]
765 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
766 /// parseList([[1, 2], 3]) -> Failure
767 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
768 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
769  auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
770  const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
771  if (prevDims == newDims)
772  return success();
773  return p.emitError("tensor literal is invalid; ranks are not consistent "
774  "between elements");
775  };
776 
777  bool first = true;
778  SmallVector<int64_t, 4> newDims;
779  unsigned size = 0;
780  auto parseOneElement = [&]() -> ParseResult {
781  SmallVector<int64_t, 4> thisDims;
782  if (p.getToken().getKind() == Token::l_square) {
783  if (parseList(thisDims))
784  return failure();
785  } else if (parseElement()) {
786  return failure();
787  }
788  ++size;
789  if (!first)
790  return checkDims(newDims, thisDims);
791  newDims = thisDims;
792  first = false;
793  return success();
794  };
795  if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOneElement))
796  return failure();
797 
798  // Return the sublists' dimensions with 'size' prepended.
799  dims.clear();
800  dims.push_back(size);
801  dims.append(newDims.begin(), newDims.end());
802  return success();
803 }
804 
805 //===----------------------------------------------------------------------===//
806 // ElementsAttr Parser
807 //===----------------------------------------------------------------------===//
808 
809 /// Parse a dense elements attribute.
811  auto attribLoc = getToken().getLoc();
812  consumeToken(Token::kw_dense);
813  if (parseToken(Token::less, "expected '<' after 'dense'"))
814  return nullptr;
815 
816  // Parse the literal data if necessary.
817  TensorLiteralParser literalParser(*this);
818  if (!consumeIf(Token::greater)) {
819  if (literalParser.parse(/*allowHex=*/true) ||
820  parseToken(Token::greater, "expected '>'"))
821  return nullptr;
822  }
823 
824  // If the type is specified `parseElementsLiteralType` will not parse a type.
825  // Use the attribute location as the location for error reporting in that
826  // case.
827  auto loc = attrType ? attribLoc : getToken().getLoc();
828  auto type = parseElementsLiteralType(attrType);
829  if (!type)
830  return nullptr;
831  return literalParser.getAttr(loc, type);
832 }
833 
834 /// Parse an opaque elements attribute.
836  llvm::SMLoc loc = getToken().getLoc();
837  consumeToken(Token::kw_opaque);
838  if (parseToken(Token::less, "expected '<' after 'opaque'"))
839  return nullptr;
840 
841  if (getToken().isNot(Token::string))
842  return (emitError("expected dialect namespace"), nullptr);
843 
844  std::string name = getToken().getStringValue();
845  consumeToken(Token::string);
846 
847  if (parseToken(Token::comma, "expected ','"))
848  return nullptr;
849 
850  Token hexTok = getToken();
851  if (parseToken(Token::string, "elements hex string should start with '0x'") ||
852  parseToken(Token::greater, "expected '>'"))
853  return nullptr;
854  auto type = parseElementsLiteralType(attrType);
855  if (!type)
856  return nullptr;
857 
858  std::string data;
859  if (parseElementAttrHexValues(*this, hexTok, data))
860  return nullptr;
861  return getChecked<OpaqueElementsAttr>(loc, builder.getStringAttr(name), type,
862  data);
863 }
864 
865 /// Shaped type for elements attribute.
866 ///
867 /// elements-literal-type ::= vector-type | ranked-tensor-type
868 ///
869 /// This method also checks the type has static shape.
871  // If the user didn't provide a type, parse the colon type for the literal.
872  if (!type) {
873  if (parseToken(Token::colon, "expected ':'"))
874  return nullptr;
875  if (!(type = parseType()))
876  return nullptr;
877  }
878 
879  if (!type.isa<RankedTensorType, VectorType>()) {
880  emitError("elements literal must be a ranked tensor or vector type");
881  return nullptr;
882  }
883 
884  auto sType = type.cast<ShapedType>();
885  if (!sType.hasStaticShape())
886  return (emitError("elements literal type must have static shape"), nullptr);
887 
888  return sType;
889 }
890 
891 /// Parse a sparse elements attribute.
893  llvm::SMLoc loc = getToken().getLoc();
894  consumeToken(Token::kw_sparse);
895  if (parseToken(Token::less, "Expected '<' after 'sparse'"))
896  return nullptr;
897 
898  // Check for the case where all elements are sparse. The indices are
899  // represented by a 2-dimensional shape where the second dimension is the rank
900  // of the type.
901  Type indiceEltType = builder.getIntegerType(64);
902  if (consumeIf(Token::greater)) {
903  ShapedType type = parseElementsLiteralType(attrType);
904  if (!type)
905  return nullptr;
906 
907  // Construct the sparse elements attr using zero element indice/value
908  // attributes.
909  ShapedType indicesType =
910  RankedTensorType::get({0, type.getRank()}, indiceEltType);
911  ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
912  return getChecked<SparseElementsAttr>(
913  loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
915  }
916 
917  /// Parse the indices. We don't allow hex values here as we may need to use
918  /// the inferred shape.
919  auto indicesLoc = getToken().getLoc();
920  TensorLiteralParser indiceParser(*this);
921  if (indiceParser.parse(/*allowHex=*/false))
922  return nullptr;
923 
924  if (parseToken(Token::comma, "expected ','"))
925  return nullptr;
926 
927  /// Parse the values.
928  auto valuesLoc = getToken().getLoc();
929  TensorLiteralParser valuesParser(*this);
930  if (valuesParser.parse(/*allowHex=*/true))
931  return nullptr;
932 
933  if (parseToken(Token::greater, "expected '>'"))
934  return nullptr;
935 
936  auto type = parseElementsLiteralType(attrType);
937  if (!type)
938  return nullptr;
939 
940  // If the indices are a splat, i.e. the literal parser parsed an element and
941  // not a list, we set the shape explicitly. The indices are represented by a
942  // 2-dimensional shape where the second dimension is the rank of the type.
943  // Given that the parsed indices is a splat, we know that we only have one
944  // indice and thus one for the first dimension.
945  ShapedType indicesType;
946  if (indiceParser.getShape().empty()) {
947  indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
948  } else {
949  // Otherwise, set the shape to the one parsed by the literal parser.
950  indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
951  }
952  auto indices = indiceParser.getAttr(indicesLoc, indicesType);
953 
954  // If the values are a splat, set the shape explicitly based on the number of
955  // indices. The number of indices is encoded in the first dimension of the
956  // indice shape type.
957  auto valuesEltType = type.getElementType();
958  ShapedType valuesType =
959  valuesParser.getShape().empty()
960  ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
961  : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
962  auto values = valuesParser.getAttr(valuesLoc, valuesType);
963 
964  // Build the sparse elements attribute by the indices and values.
965  return getChecked<SparseElementsAttr>(loc, type, indices, values);
966 }
Include the generated interface declarations.
Optional< std::string > getHexStringValue() const
Given a token containing a hex string literal, return its value or None if the token does not contain...
Definition: Token.cpp:127
MLIRContext * getContext() const
Definition: Parser.h:35
void addUses(Value value, ArrayRef< llvm::SMLoc > locations)
Add a source uses of the given value.
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::string getSymbolReference() const
Given a token containing a symbol reference, return the unescaped string value.
Definition: Token.cpp:142
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:89
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
Definition: Types.cpp:31
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Builder builder
Definition: Parser.h:29
bool consumeIf(Token::Kind kind)
If the current token has the specified kind, consume it and return true.
Definition: Parser.h:117
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Location objects represent source locations information in MLIR.
Definition: Location.h:31
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
OptionalParseResult parseOptionalType(Type &type)
Optionally parse a type.
Definition: TypeParser.cpp:23
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:61
ParserState & state
The Parser is subclassed and reinstantiated.
Definition: Parser.h:305
static ParseResult parseElementAttrHexValues(Parser &parser, Token tok, std::string &result)
Parse elements values stored within a hex string.
static constexpr const bool value
bool isAny(Kind k1, Kind k2) const
Definition: Token.h:40
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
Attribute parseAttribute(Type type={})
Parse an arbitrary attribute with an optional type.
bool is(Kind k) const
Definition: Token.h:38
StringRef getTokenSpelling() const
Definition: Parser.h:113
AsmParserState * asmState
An optional pointer to a struct containing high level parser state to be populated during parsing...
Definition: ParserState.h:84
ParseResult parseLocationInstance(LocationAttr &loc)
Parse a raw location instance.
ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())
Parse a list of comma-separated items with an optional delimiter.
Definition: Parser.cpp:43
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Curly braces `{}` surrounding zero or more operands.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
unsigned getWidth()
Return the bitwidth of this float type.
An attribute that represents a reference to a dense vector or tensor object.
OptionalParseResult parseOptionalAttribute(Attribute &attribute, Type type={})
Parse an optional attribute with the provided type.
Attribute parseDecOrHexAttr(Type type, bool isNegative)
Parse a decimal or a hexadecimal literal, which can be either an integer or a float attribute...
ParseResult parseFloatFromIntegerLiteral(Optional< APFloat > &result, const Token &tok, bool isNegative, const llvm::fltSemantics &semantics, size_t typeSizeInBits)
Parse a floating point value from an integer literal token.
Definition: Parser.cpp:198
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:170
U dyn_cast() const
Definition: Types.h:244
UnitAttr getUnitAttr()
Definition: Builders.cpp:85
ShapedType parseElementsLiteralType(Type type)
Shaped type for elements attribute.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
bool isIndex() const
Definition: Types.cpp:28
const Token & getToken() const
Return the current token the parser is inspecting.
Definition: Parser.h:112
std::string getStringValue() const
Given a token containing a string literal, return its value, including removing the quote characters ...
Definition: Token.cpp:81
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued...
Definition: AffineMap.h:38
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:49
Attribute parseOpaqueElementsAttr(Type attrType)
Parse an opaque elements attribute.
void resetPointer(const char *newPointer)
Change the position of the lexer cursor.
Definition: Lexer.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:52
Attribute parseDenseElementsAttr(Type attrType)
Parse a dense elements attribute.
ParseResult parseAttributeDict(NamedAttrList &attributes)
Parse an attribute dictionary.
bool isF64() const
Definition: Types.cpp:24
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:92
Type parseType()
Parse an arbitrary type.
Definition: TypeParser.cpp:52
StringRef getSpelling() const
Definition: Token.h:34
static Optional< APInt > buildAttributeAPInt(Type type, bool isNegative, StringRef spelling)
Construct an APInt from a parsed value, a known attribute type and sign.
ParseResult parseAffineMapReference(AffineMap &map)
void consumeToken()
Advance the current lexer onto the next token.
Definition: Parser.h:125
Attribute parseSparseElementsAttr(Type attrType)
Parse a sparse elements attribute.
Lexer lex
The lexer for the source file we're parsing.
Definition: ParserState.h:71
FloatType getF64Type()
Definition: Builders.cpp:42
OptionalParseResult parseOptionalAttributeWithToken(Token::Kind kind, AttributeT &attr, Type type={})
Parse an optional attribute that is demarcated by a specific token.
Definition: Parser.h:229
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:87
This class implements support for parsing global entities like attributes and types.
Definition: Parser.h:25
Attribute parseFloatAttr(Type type, bool isNegative)
Parse a float attribute.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:91
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:85
bool hasValue() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:62
InFlightDiagnostic emitError(const Twine &message={})
Emit an error and return failure.
Definition: Parser.h:71
llvm::SMLoc getLoc() const
Definition: Token.cpp:19
ParseResult parseIntegerSetReference(IntegerSet &set)
bool isa() const
Definition: Types.h:234
This class represents success/failure for operation parsing.
Definition: OpDefinition.h:36
ParseResult parseToken(Token::Kind expectedToken, const Twine &message)
Consume the specified token if present and return success.
Definition: Parser.cpp:163
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:205
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:201
Optional< double > getFloatingPointValue() const
For a float-literal token, return its value as a double.
Definition: Token.cpp:51
Attribute parseExtendedAttr(Type type)
Parse an extended attribute.
Square brackets surrounding zero or more operands.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool isSplatBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute...
U cast() const
Definition: Types.h:250
This represents a token in the MLIR syntax.
Definition: Token.h:20
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition: IntegerSet.h:44