MLIR  20.0.0git
DialectSymbolParser.cpp
Go to the documentation of this file.
1 //===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the parser for the dialect symbols, such as extended
10 // attributes and types.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "AsmParserImpl.h"
15 #include "Parser.h"
17 #include "mlir/IR/AsmState.h"
18 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Dialect.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/Support/LLVM.h"
26 #include "llvm/Support/MemoryBuffer.h"
27 #include "llvm/Support/SourceMgr.h"
28 #include <cassert>
29 #include <cstddef>
30 #include <utility>
31 
32 using namespace mlir;
33 using namespace mlir::detail;
34 using llvm::MemoryBuffer;
35 using llvm::SourceMgr;
36 
37 namespace {
38 /// This class provides the main implementation of the DialectAsmParser that
39 /// allows for dialects to parse attributes and types. This allows for dialect
40 /// hooking into the main MLIR parsing logic.
41 class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
42 public:
43  CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
44  : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
45  fullSpec(fullSpec) {}
46  ~CustomDialectAsmParser() override = default;
47 
48  /// Returns the full specification of the symbol being parsed. This allows
49  /// for using a separate parser if necessary.
50  StringRef getFullSymbolSpec() const override { return fullSpec; }
51 
52 private:
53  /// The full symbol specification.
54  StringRef fullSpec;
55 };
56 } // namespace
57 
58 ///
59 /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
60 /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
61 /// | '(' pretty-dialect-sym-contents+ ')'
62 /// | '[' pretty-dialect-sym-contents+ ']'
63 /// | '{' pretty-dialect-sym-contents+ '}'
64 /// | '[^[<({>\])}\0]+'
65 ///
66 ParseResult Parser::parseDialectSymbolBody(StringRef &body,
67  bool &isCodeCompletion) {
68  // Symbol bodies are a relatively unstructured format that contains a series
69  // of properly nested punctuation, with anything else in the middle. Scan
70  // ahead to find it and consume it if successful, otherwise emit an error.
71  const char *curPtr = getTokenSpelling().data();
72 
73  // Scan over the nested punctuation, bailing out on error and consuming until
74  // we find the end. We know that we're currently looking at the '<', so we can
75  // go until we find the matching '>' character.
76  assert(*curPtr == '<');
77  SmallVector<char, 8> nestedPunctuation;
78  const char *codeCompleteLoc = state.lex.getCodeCompleteLoc();
79 
80  // Functor used to emit an unbalanced punctuation error.
81  auto emitPunctError = [&] {
82  return emitError() << "unbalanced '" << nestedPunctuation.back()
83  << "' character in pretty dialect name";
84  };
85  // Functor used to check for unbalanced punctuation.
86  auto checkNestedPunctuation = [&](char expectedToken) -> ParseResult {
87  if (nestedPunctuation.back() != expectedToken)
88  return emitPunctError();
89  nestedPunctuation.pop_back();
90  return success();
91  };
92  do {
93  // Handle code completions, which may appear in the middle of the symbol
94  // body.
95  if (curPtr == codeCompleteLoc) {
96  isCodeCompletion = true;
97  nestedPunctuation.clear();
98  break;
99  }
100 
101  char c = *curPtr++;
102  switch (c) {
103  case '\0':
104  // This also handles the EOF case.
105  if (!nestedPunctuation.empty())
106  return emitPunctError();
107  return emitError("unexpected nul or EOF in pretty dialect name");
108  case '<':
109  case '[':
110  case '(':
111  case '{':
112  nestedPunctuation.push_back(c);
113  continue;
114 
115  case '-':
116  // The sequence `->` is treated as special token.
117  if (*curPtr == '>')
118  ++curPtr;
119  continue;
120 
121  case '>':
122  if (failed(checkNestedPunctuation('<')))
123  return failure();
124  break;
125  case ']':
126  if (failed(checkNestedPunctuation('[')))
127  return failure();
128  break;
129  case ')':
130  if (failed(checkNestedPunctuation('(')))
131  return failure();
132  break;
133  case '}':
134  if (failed(checkNestedPunctuation('{')))
135  return failure();
136  break;
137  case '"': {
138  // Dispatch to the lexer to lex past strings.
139  resetToken(curPtr - 1);
140  curPtr = state.curToken.getEndLoc().getPointer();
141 
142  // Handle code completions, which may appear in the middle of the symbol
143  // body.
145  isCodeCompletion = true;
146  nestedPunctuation.clear();
147  break;
148  }
149 
150  // Otherwise, ensure this token was actually a string.
151  if (state.curToken.isNot(Token::string))
152  return failure();
153  break;
154  }
155 
156  default:
157  continue;
158  }
159  } while (!nestedPunctuation.empty());
160 
161  // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
162  // consuming all this stuff, and return.
163  resetToken(curPtr);
164 
165  unsigned length = curPtr - body.begin();
166  body = StringRef(body.data(), length);
167  return success();
168 }
169 
170 /// Parse an extended dialect symbol.
171 template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
172 static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
173  SymbolAliasMap &aliases,
174  CreateFn &&createSymbol) {
175  Token tok = p.getToken();
176 
177  // Handle code completion of the extended symbol.
178  StringRef identifier = tok.getSpelling().drop_front();
179  if (tok.isCodeCompletion() && identifier.empty())
180  return p.codeCompleteDialectSymbol(aliases);
181 
182  // Parse the dialect namespace.
183  SMRange range = p.getToken().getLocRange();
184  SMLoc loc = p.getToken().getLoc();
185  p.consumeToken();
186 
187  // Check to see if this is a pretty name.
188  auto [dialectName, symbolData] = identifier.split('.');
189  bool isPrettyName = !symbolData.empty() || identifier.back() == '.';
190 
191  // Check to see if the symbol has trailing data, i.e. has an immediately
192  // following '<'.
193  bool hasTrailingData =
194  p.getToken().is(Token::less) &&
195  identifier.bytes_end() == p.getTokenSpelling().bytes_begin();
196 
197  // If there is no '<' token following this, and if the typename contains no
198  // dot, then we are parsing a symbol alias.
199  if (!hasTrailingData && !isPrettyName) {
200  // Check for an alias for this type.
201  auto aliasIt = aliases.find(identifier);
202  if (aliasIt == aliases.end())
203  return (p.emitWrongTokenError("undefined symbol alias id '" + identifier +
204  "'"),
205  nullptr);
206  if (asmState) {
207  if constexpr (std::is_same_v<Symbol, Type>)
208  asmState->addTypeAliasUses(identifier, range);
209  else
210  asmState->addAttrAliasUses(identifier, range);
211  }
212  return aliasIt->second;
213  }
214 
215  // If this isn't an alias, we are parsing a dialect-specific symbol. If the
216  // name contains a dot, then this is the "pretty" form. If not, it is the
217  // verbose form that looks like <...>.
218  if (!isPrettyName) {
219  // Point the symbol data to the end of the dialect name to start.
220  symbolData = StringRef(dialectName.end(), 0);
221 
222  // Parse the body of the symbol.
223  bool isCodeCompletion = false;
224  if (p.parseDialectSymbolBody(symbolData, isCodeCompletion))
225  return nullptr;
226  symbolData = symbolData.drop_front();
227 
228  // If the body contained a code completion it won't have the trailing `>`
229  // token, so don't drop it.
230  if (!isCodeCompletion)
231  symbolData = symbolData.drop_back();
232  } else {
233  loc = SMLoc::getFromPointer(symbolData.data());
234 
235  // If the dialect's symbol is followed immediately by a <, then lex the body
236  // of it into prettyName.
237  if (hasTrailingData && p.parseDialectSymbolBody(symbolData))
238  return nullptr;
239  }
240 
241  return createSymbol(dialectName, symbolData, loc);
242 }
243 
244 /// Parse an extended attribute.
245 ///
246 /// extended-attribute ::= (dialect-attribute | attribute-alias)
247 /// dialect-attribute ::= `#` dialect-namespace `<` attr-data `>`
248 /// (`:` type)?
249 /// | `#` alias-name pretty-dialect-sym-body? (`:` type)?
250 /// attribute-alias ::= `#` alias-name
251 ///
253  MLIRContext *ctx = getContext();
254  Attribute attr = parseExtendedSymbol<Attribute>(
256  [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
257  // Parse an optional trailing colon type.
258  Type attrType = type;
259  if (consumeIf(Token::colon) && !(attrType = parseType()))
260  return Attribute();
261 
262  // If we found a registered dialect, then ask it to parse the attribute.
263  if (Dialect *dialect =
264  builder.getContext()->getOrLoadDialect(dialectName)) {
265  // Temporarily reset the lexer to let the dialect parse the attribute.
266  const char *curLexerPos = getToken().getLoc().getPointer();
267  resetToken(symbolData.data());
268 
269  // Parse the attribute.
270  CustomDialectAsmParser customParser(symbolData, *this);
271  Attribute attr = dialect->parseAttribute(customParser, attrType);
272  resetToken(curLexerPos);
273  return attr;
274  }
275 
276  // Otherwise, form a new opaque attribute.
277  return OpaqueAttr::getChecked(
278  [&] { return emitError(loc); }, StringAttr::get(ctx, dialectName),
279  symbolData, attrType ? attrType : NoneType::get(ctx));
280  });
281 
282  // Ensure that the attribute has the same type as requested.
283  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
284  if (type && typedAttr && typedAttr.getType() != type) {
285  emitError("attribute type different than expected: expected ")
286  << type << ", but got " << typedAttr.getType();
287  return nullptr;
288  }
289  return attr;
290 }
291 
292 /// Parse an extended type.
293 ///
294 /// extended-type ::= (dialect-type | type-alias)
295 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>`
296 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body?
297 /// type-alias ::= `!` alias-name
298 ///
299 Type Parser::parseExtendedType() {
300  MLIRContext *ctx = getContext();
301  return parseExtendedSymbol<Type>(
302  *this, state.asmState, state.symbols.typeAliasDefinitions,
303  [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
304  // If we found a registered dialect, then ask it to parse the type.
305  if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {
306  // Temporarily reset the lexer to let the dialect parse the type.
307  const char *curLexerPos = getToken().getLoc().getPointer();
308  resetToken(symbolData.data());
309 
310  // Parse the type.
311  CustomDialectAsmParser customParser(symbolData, *this);
312  Type type = dialect->parseType(customParser);
313  resetToken(curLexerPos);
314  return type;
315  }
316 
317  // Otherwise, form a new opaque type.
318  return OpaqueType::getChecked([&] { return emitError(loc); },
319  StringAttr::get(ctx, dialectName),
320  symbolData);
321  });
322 }
323 
324 //===----------------------------------------------------------------------===//
325 // mlir::parseAttribute/parseType
326 //===----------------------------------------------------------------------===//
327 
328 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If
329 /// parsing failed, nullptr is returned.
330 template <typename T, typename ParserFn>
331 static T parseSymbol(StringRef inputStr, MLIRContext *context,
332  size_t *numReadOut, bool isKnownNullTerminated,
333  ParserFn &&parserFn) {
334  // Set the buffer name to the string being parsed, so that it appears in error
335  // diagnostics.
336  auto memBuffer =
337  isKnownNullTerminated
338  ? MemoryBuffer::getMemBuffer(inputStr,
339  /*BufferName=*/inputStr)
340  : MemoryBuffer::getMemBufferCopy(inputStr, /*BufferName=*/inputStr);
341  SourceMgr sourceMgr;
342  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
343  SymbolState aliasState;
344  ParserConfig config(context);
345  ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
346  /*codeCompleteContext=*/nullptr);
347  Parser parser(state);
348 
349  Token startTok = parser.getToken();
350  T symbol = parserFn(parser);
351  if (!symbol)
352  return T();
353 
354  // Provide the number of bytes that were read.
355  Token endTok = parser.getToken();
356  size_t numRead =
357  endTok.getLoc().getPointer() - startTok.getLoc().getPointer();
358  if (numReadOut) {
359  *numReadOut = numRead;
360  } else if (numRead != inputStr.size()) {
361  parser.emitError(endTok.getLoc()) << "found trailing characters: '"
362  << inputStr.drop_front(numRead) << "'";
363  return T();
364  }
365  return symbol;
366 }
367 
368 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
369  Type type, size_t *numRead,
370  bool isKnownNullTerminated) {
371  return parseSymbol<Attribute>(
372  attrStr, context, numRead, isKnownNullTerminated,
373  [type](Parser &parser) { return parser.parseAttribute(type); });
374 }
375 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead,
376  bool isKnownNullTerminated) {
377  return parseSymbol<Type>(typeStr, context, numRead, isKnownNullTerminated,
378  [](Parser &parser) { return parser.parseType(); });
379 }
static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState, SymbolAliasMap &aliases, CreateFn &&createSymbol)
Parse an extended dialect symbol.
static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t *numReadOut, bool isKnownNullTerminated, ParserFn &&parserFn)
Parses a symbol, of type 'T', and returns it if parsing was successful.
static MLIRContext * getContext(OpFoldResult val)
This class represents state from a parsed MLIR textual format string.
void addTypeAliasUses(StringRef name, SMRange locations)
void addAttrAliasUses(StringRef name, SMRange locations)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
const char * getCodeCompleteLoc() const
Return the code completion location of the lexer, or nullptr if there is none.
Definition: Lexer.h:45
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a configuration for the MLIR assembly parser.
Definition: AsmState.h:467
This represents a token in the MLIR syntax.
Definition: Token.h:20
SMRange getLocRange() const
Definition: Token.cpp:30
SMLoc getLoc() const
Definition: Token.cpp:24
bool is(Kind k) const
Definition: Token.h:38
SMLoc getEndLoc() const
Definition: Token.cpp:26
bool isNot(Kind k) const
Definition: Token.h:50
bool isCodeCompletion() const
Returns true if the current token represents a code completion.
Definition: Token.h:62
StringRef getSpelling() const
Definition: Token.h:34
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides the implementation of the generic parser methods within AsmParser.
Definition: AsmParserImpl.h:28
This class implement support for parsing global entities like attributes and types.
Definition: Parser.h:27
Type parseType()
Parse an arbitrary type.
Definition: TypeParser.cpp:75
InFlightDiagnostic emitError(const Twine &message={})
Emit an error and return failure.
Definition: Parser.cpp:192
ParserState & state
The Parser is subclassed and reinstantiated.
Definition: Parser.h:364
Attribute parseAttribute(Type type={})
Parse an arbitrary attribute with an optional type.
StringRef getTokenSpelling() const
Definition: Parser.h:104
void consumeToken()
Advance the current lexer onto the next token.
Definition: Parser.h:119
ParseResult parseDialectSymbolBody(StringRef &body, bool &isCodeCompletion)
Parse the body of a dialect symbol, which starts and ends with <>'s, and may be recursive.
MLIRContext * getContext() const
Definition: Parser.h:38
InFlightDiagnostic emitWrongTokenError(const Twine &message={})
Emit an error about a "wrong token".
Definition: Parser.cpp:215
void resetToken(const char *tokPos)
Reset the parser to the given lexer position.
Definition: Parser.h:140
Attribute parseExtendedAttr(Type type)
Parse an extended attribute.
const Token & getToken() const
Return the current token the parser is inspecting.
Definition: Parser.h:103
Attribute codeCompleteDialectSymbol(const llvm::StringMap< Attribute > &aliases)
Definition: Parser.cpp:532
AttrTypeReplacer.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
This class refers to all of the state maintained globally by the parser, such as the current lexer po...
Definition: ParserState.h:51
SymbolState & symbols
The current state for symbol parsing.
Definition: ParserState.h:75
Lexer lex
The lexer for the source file we're parsing.
Definition: ParserState.h:66
Token curToken
This is the next token that hasn't been consumed yet.
Definition: ParserState.h:69
AsmParserState * asmState
An optional pointer to a struct containing high level parser state to be populated during parsing.
Definition: ParserState.h:83
This class contains record of any parsed top-level symbols.
Definition: ParserState.h:28
llvm::StringMap< Attribute > attributeAliasDefinitions
A map from attribute alias identifier to Attribute.
Definition: ParserState.h:30