MLIR  14.0.0git
DialectSymbolParser.cpp
Go to the documentation of this file.
1 //===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the parser for the dialect symbols, such as extended
10 // attributes and types.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "AsmParserImpl.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Dialect.h"
18 #include "llvm/Support/SourceMgr.h"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 using llvm::MemoryBuffer;
23 using llvm::SMLoc;
24 using llvm::SourceMgr;
25 
26 namespace {
27 /// This class provides the main implementation of the DialectAsmParser that
28 /// allows for dialects to parse attributes and types. This allows for dialect
29 /// hooking into the main MLIR parsing logic.
30 class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
31 public:
32  CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
33  : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
34  fullSpec(fullSpec) {}
35  ~CustomDialectAsmParser() override = default;
36 
37  /// Returns the full specification of the symbol being parsed. This allows
38  /// for using a separate parser if necessary.
39  StringRef getFullSymbolSpec() const override { return fullSpec; }
40 
41 private:
42  /// The full symbol specification.
43  StringRef fullSpec;
44 };
45 } // namespace
46 
47 /// Parse the body of a pretty dialect symbol, which starts and ends with <>'s,
48 /// and may be recursive. Return with the 'prettyName' StringRef encompassing
49 /// the entire pretty name.
50 ///
51 /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
52 /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
53 /// | '(' pretty-dialect-sym-contents+ ')'
54 /// | '[' pretty-dialect-sym-contents+ ']'
55 /// | '{' pretty-dialect-sym-contents+ '}'
56 /// | '[^[<({>\])}\0]+'
57 ///
59  // Pretty symbol names are a relatively unstructured format that contains a
60  // series of properly nested punctuation, with anything else in the middle.
61  // Scan ahead to find it and consume it if successful, otherwise emit an
62  // error.
63  auto *curPtr = getTokenSpelling().data();
64 
65  SmallVector<char, 8> nestedPunctuation;
66 
67  // Scan over the nested punctuation, bailing out on error and consuming until
68  // we find the end. We know that we're currently looking at the '<', so we
69  // can go until we find the matching '>' character.
70  assert(*curPtr == '<');
71  do {
72  char c = *curPtr++;
73  switch (c) {
74  case '\0':
75  // This also handles the EOF case.
76  return emitError("unexpected nul or EOF in pretty dialect name");
77  case '<':
78  case '[':
79  case '(':
80  case '{':
81  nestedPunctuation.push_back(c);
82  continue;
83 
84  case '-':
85  // The sequence `->` is treated as special token.
86  if (*curPtr == '>')
87  ++curPtr;
88  continue;
89 
90  case '>':
91  if (nestedPunctuation.pop_back_val() != '<')
92  return emitError("unbalanced '>' character in pretty dialect name");
93  break;
94  case ']':
95  if (nestedPunctuation.pop_back_val() != '[')
96  return emitError("unbalanced ']' character in pretty dialect name");
97  break;
98  case ')':
99  if (nestedPunctuation.pop_back_val() != '(')
100  return emitError("unbalanced ')' character in pretty dialect name");
101  break;
102  case '}':
103  if (nestedPunctuation.pop_back_val() != '{')
104  return emitError("unbalanced '}' character in pretty dialect name");
105  break;
106 
107  default:
108  continue;
109  }
110  } while (!nestedPunctuation.empty());
111 
112  // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
113  // consuming all this stuff, and return.
114  state.lex.resetPointer(curPtr);
115 
116  unsigned length = curPtr - prettyName.begin();
117  prettyName = StringRef(prettyName.begin(), length);
118  consumeToken();
119  return success();
120 }
121 
122 /// Parse an extended dialect symbol.
123 template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
124 static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
125  SymbolAliasMap &aliases,
126  CreateFn &&createSymbol) {
127  // Parse the dialect namespace.
128  StringRef identifier = p.getTokenSpelling().drop_front();
129  auto loc = p.getToken().getLoc();
130  p.consumeToken(identifierTok);
131 
132  // If there is no '<' token following this, and if the typename contains no
133  // dot, then we are parsing a symbol alias.
134  if (p.getToken().isNot(Token::less) && !identifier.contains('.')) {
135  // Check for an alias for this type.
136  auto aliasIt = aliases.find(identifier);
137  if (aliasIt == aliases.end())
138  return (p.emitError("undefined symbol alias id '" + identifier + "'"),
139  nullptr);
140  return aliasIt->second;
141  }
142 
143  // Otherwise, we are parsing a dialect-specific symbol. If the name contains
144  // a dot, then this is the "pretty" form. If not, it is the verbose form that
145  // looks like <"...">.
146  std::string symbolData;
147  auto dialectName = identifier;
148 
149  // Handle the verbose form, where "identifier" is a simple dialect name.
150  if (!identifier.contains('.')) {
151  // Consume the '<'.
152  if (p.parseToken(Token::less, "expected '<' in dialect type"))
153  return nullptr;
154 
155  // Parse the symbol specific data.
156  if (p.getToken().isNot(Token::string))
157  return (p.emitError("expected string literal data in dialect symbol"),
158  nullptr);
159  symbolData = p.getToken().getStringValue();
160  loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1);
161  p.consumeToken(Token::string);
162 
163  // Consume the '>'.
164  if (p.parseToken(Token::greater, "expected '>' in dialect symbol"))
165  return nullptr;
166  } else {
167  // Ok, the dialect name is the part of the identifier before the dot, the
168  // part after the dot is the dialect's symbol, or the start thereof.
169  auto dotHalves = identifier.split('.');
170  dialectName = dotHalves.first;
171  auto prettyName = dotHalves.second;
172  loc = llvm::SMLoc::getFromPointer(prettyName.data());
173 
174  // If the dialect's symbol is followed immediately by a <, then lex the body
175  // of it into prettyName.
176  if (p.getToken().is(Token::less) &&
177  prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) {
178  if (p.parsePrettyDialectSymbolName(prettyName))
179  return nullptr;
180  }
181 
182  symbolData = prettyName.str();
183  }
184 
185  // Record the name location of the type remapped to the top level buffer.
186  llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc);
187  p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer);
188 
189  // Call into the provided symbol construction function.
190  Symbol sym = createSymbol(dialectName, symbolData, loc);
191 
192  // Pop the last parser location.
193  p.getState().symbols.nestedParserLocs.pop_back();
194  return sym;
195 }
196 
197 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If
198 /// parsing failed, nullptr is returned. The number of bytes read from the input
199 /// string is returned in 'numRead'.
200 template <typename T, typename ParserFn>
201 static T parseSymbol(StringRef inputStr, MLIRContext *context,
202  SymbolState &symbolState, ParserFn &&parserFn,
203  size_t *numRead = nullptr) {
204  SourceMgr sourceMgr;
205  auto memBuffer = MemoryBuffer::getMemBuffer(
206  inputStr, /*BufferName=*/"<mlir_parser_buffer>",
207  /*RequiresNullTerminator=*/false);
208  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
209  ParserState state(sourceMgr, context, symbolState, /*asmState=*/nullptr);
210  Parser parser(state);
211 
212  Token startTok = parser.getToken();
213  T symbol = parserFn(parser);
214  if (!symbol)
215  return T();
216 
217  // If 'numRead' is valid, then provide the number of bytes that were read.
218  Token endTok = parser.getToken();
219  if (numRead) {
220  *numRead = static_cast<size_t>(endTok.getLoc().getPointer() -
221  startTok.getLoc().getPointer());
222 
223  // Otherwise, ensure that all of the tokens were parsed.
224  } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) {
225  parser.emitError(endTok.getLoc(), "encountered unexpected token");
226  return T();
227  }
228  return symbol;
229 }
230 
231 /// Parse an extended attribute.
232 ///
233 /// extended-attribute ::= (dialect-attribute | attribute-alias)
234 /// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>`
235 /// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body?
236 /// attribute-alias ::= `#` alias-name
237 ///
239  Attribute attr = parseExtendedSymbol<Attribute>(
240  *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions,
241  [&](StringRef dialectName, StringRef symbolData,
242  llvm::SMLoc loc) -> Attribute {
243  // Parse an optional trailing colon type.
244  Type attrType = type;
245  if (consumeIf(Token::colon) && !(attrType = parseType()))
246  return Attribute();
247 
248  // If we found a registered dialect, then ask it to parse the attribute.
249  if (Dialect *dialect =
250  builder.getContext()->getOrLoadDialect(dialectName)) {
251  return parseSymbol<Attribute>(
252  symbolData, state.context, state.symbols, [&](Parser &parser) {
253  CustomDialectAsmParser customParser(symbolData, parser);
254  return dialect->parseAttribute(customParser, attrType);
255  });
256  }
257 
258  // Otherwise, form a new opaque attribute.
259  return OpaqueAttr::getChecked(
260  [&] { return emitError(loc); },
261  StringAttr::get(state.context, dialectName), symbolData,
262  attrType ? attrType : NoneType::get(state.context));
263  });
264 
265  // Ensure that the attribute has the same type as requested.
266  if (attr && type && attr.getType() != type) {
267  emitError("attribute type different than expected: expected ")
268  << type << ", but got " << attr.getType();
269  return nullptr;
270  }
271  return attr;
272 }
273 
274 /// Parse an extended type.
275 ///
276 /// extended-type ::= (dialect-type | type-alias)
277 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>`
278 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body?
279 /// type-alias ::= `!` alias-name
280 ///
282  return parseExtendedSymbol<Type>(
283  *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions,
284  [&](StringRef dialectName, StringRef symbolData,
285  llvm::SMLoc loc) -> Type {
286  // If we found a registered dialect, then ask it to parse the type.
287  auto *dialect = state.context->getOrLoadDialect(dialectName);
288 
289  if (dialect) {
290  return parseSymbol<Type>(
291  symbolData, state.context, state.symbols, [&](Parser &parser) {
292  CustomDialectAsmParser customParser(symbolData, parser);
293  return dialect->parseType(customParser);
294  });
295  }
296 
297  // Otherwise, form a new opaque type.
298  return OpaqueType::getChecked(
299  [&] { return emitError(loc); },
300  StringAttr::get(state.context, dialectName), symbolData);
301  });
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // mlir::parseAttribute/parseType
306 //===----------------------------------------------------------------------===//
307 
308 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If
309 /// parsing failed, nullptr is returned. The number of bytes read from the input
310 /// string is returned in 'numRead'.
311 template <typename T, typename ParserFn>
312 static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead,
313  ParserFn &&parserFn) {
314  SymbolState aliasState;
315  return parseSymbol<T>(
316  inputStr, context, aliasState,
317  [&](Parser &parser) {
319  const_cast<llvm::SourceMgr &>(parser.getSourceMgr()),
320  parser.getContext());
321  return parserFn(parser);
322  },
323  &numRead);
324 }
325 
326 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) {
327  size_t numRead = 0;
328  return parseAttribute(attrStr, context, numRead);
329 }
330 Attribute mlir::parseAttribute(StringRef attrStr, Type type) {
331  size_t numRead = 0;
332  return parseAttribute(attrStr, type, numRead);
333 }
334 
335 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
336  size_t &numRead) {
337  return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) {
338  return parser.parseAttribute();
339  });
340 }
341 Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) {
342  return parseSymbol<Attribute>(
343  attrStr, type.getContext(), numRead,
344  [type](Parser &parser) { return parser.parseAttribute(type); });
345 }
346 
347 Type mlir::parseType(StringRef typeStr, MLIRContext *context) {
348  size_t numRead = 0;
349  return parseType(typeStr, context, numRead);
350 }
351 
352 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) {
353  return parseSymbol<Type>(typeStr, context, numRead,
354  [](Parser &parser) { return parser.parseType(); });
355 }
Include the generated interface declarations.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isNot(Kind k) const
Definition: Token.h:49
SmallVector< llvm::SMLoc, 1 > nestedParserLocs
A set of locations into the main parser memory buffer for each of the active nested parsers...
Definition: ParserState.h:35
static T parseSymbol(StringRef inputStr, MLIRContext *context, SymbolState &symbolState, ParserFn &&parserFn, size_t *numRead=nullptr)
Parses a symbol, of type 'T', and returns it if parsing was successful.
bool is(Kind k) const
Definition: Token.h:37
StringRef getTokenSpelling() const
Definition: Parser.h:113
ParserState & getState() const
Definition: Parser.h:34
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, SymbolAliasMap &aliases, CreateFn &&createSymbol)
Parse an extended dialect symbol.
ParseResult parsePrettyDialectSymbolName(StringRef &prettyName)
Parse the body of a pretty dialect symbol, which starts and ends with <>'s, and may be recursive...
Type parseExtendedType()
Parse an extended type.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
This class provides the implementation of the generic parser methods within AsmParser.
Definition: AsmParserImpl.h:26
const Token & getToken() const
Return the current token the parser is inspecting.
Definition: Parser.h:112
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:42
std::string getStringValue() const
Given a token containing a string literal, return its value, including removing the quote characters ...
Definition: Token.cpp:81
MLIRContext * getContext() const
Definition: Dialect.h:56
Type parseType(llvm::StringRef typeStr, MLIRContext *context)
This parses a single MLIR type to an MLIR context if it was valid.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class is a utility diagnostic handler for use with llvm::SourceMgr.
Definition: Diagnostics.h:525
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
SymbolState & symbols
The current state for symbol parsing.
Definition: ParserState.h:77
This class refers to all of the state maintained globally by the parser, such as the current lexer po...
Definition: ParserState.h:49
Type getType() const
Return the type of this attribute.
Definition: Attributes.h:64
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:92
llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc)
Remaps the given SMLoc to the top level lexer of the parser.
Definition: Parser.h:91
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void consumeToken()
Advance the current lexer onto the next token.
Definition: Parser.h:125
This class implement support for parsing global entities like attributes and types.
Definition: Parser.h:25
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
This class contains record of any parsed top-level symbols.
Definition: ParserState.h:24
InFlightDiagnostic emitError(const Twine &message={})
Emit an error and return failure.
Definition: Parser.h:71
llvm::SMLoc getLoc() const
Definition: Token.cpp:19
This class represents success/failure for operation parsing.
Definition: OpDefinition.h:36
ParseResult parseToken(Token::Kind expectedToken, const Twine &message)
Consume the specified token if present and return success.
Definition: Parser.cpp:163
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
Attribute parseExtendedAttr(Type type)
Parse an extended attribute.
This represents a token in the MLIR syntax.
Definition: Token.h:19