MLIR  18.0.0git
DialectSymbolParser.cpp
Go to the documentation of this file.
1 //===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the parser for the dialect symbols, such as extended
10 // attributes and types.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "AsmParserImpl.h"
15 #include "Parser.h"
17 #include "mlir/IR/AsmState.h"
18 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Dialect.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/Support/LLVM.h"
27 #include "llvm/Support/MemoryBuffer.h"
28 #include "llvm/Support/SourceMgr.h"
29 #include <cassert>
30 #include <cstddef>
31 #include <utility>
32 
33 using namespace mlir;
34 using namespace mlir::detail;
35 using llvm::MemoryBuffer;
36 using llvm::SourceMgr;
37 
38 namespace {
39 /// This class provides the main implementation of the DialectAsmParser that
40 /// allows for dialects to parse attributes and types. This allows for dialect
41 /// hooking into the main MLIR parsing logic.
42 class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
43 public:
44  CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
45  : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
46  fullSpec(fullSpec) {}
47  ~CustomDialectAsmParser() override = default;
48 
49  /// Returns the full specification of the symbol being parsed. This allows
50  /// for using a separate parser if necessary.
51  StringRef getFullSymbolSpec() const override { return fullSpec; }
52 
53 private:
54  /// The full symbol specification.
55  StringRef fullSpec;
56 };
57 } // namespace
58 
59 ///
60 /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
61 /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
62 /// | '(' pretty-dialect-sym-contents+ ')'
63 /// | '[' pretty-dialect-sym-contents+ ']'
64 /// | '{' pretty-dialect-sym-contents+ '}'
65 /// | '[^[<({>\])}\0]+'
66 ///
68  bool &isCodeCompletion) {
69  // Symbol bodies are a relatively unstructured format that contains a series
70  // of properly nested punctuation, with anything else in the middle. Scan
71  // ahead to find it and consume it if successful, otherwise emit an error.
72  const char *curPtr = getTokenSpelling().data();
73 
74  // Scan over the nested punctuation, bailing out on error and consuming until
75  // we find the end. We know that we're currently looking at the '<', so we can
76  // go until we find the matching '>' character.
77  assert(*curPtr == '<');
78  SmallVector<char, 8> nestedPunctuation;
79  const char *codeCompleteLoc = state.lex.getCodeCompleteLoc();
80 
81  // Functor used to emit an unbalanced punctuation error.
82  auto emitPunctError = [&] {
83  return emitError() << "unbalanced '" << nestedPunctuation.back()
84  << "' character in pretty dialect name";
85  };
86  // Functor used to check for unbalanced punctuation.
87  auto checkNestedPunctuation = [&](char expectedToken) -> ParseResult {
88  if (nestedPunctuation.back() != expectedToken)
89  return emitPunctError();
90  nestedPunctuation.pop_back();
91  return success();
92  };
93  do {
94  // Handle code completions, which may appear in the middle of the symbol
95  // body.
96  if (curPtr == codeCompleteLoc) {
97  isCodeCompletion = true;
98  nestedPunctuation.clear();
99  break;
100  }
101 
102  char c = *curPtr++;
103  switch (c) {
104  case '\0':
105  // This also handles the EOF case.
106  if (!nestedPunctuation.empty())
107  return emitPunctError();
108  return emitError("unexpected nul or EOF in pretty dialect name");
109  case '<':
110  case '[':
111  case '(':
112  case '{':
113  nestedPunctuation.push_back(c);
114  continue;
115 
116  case '-':
117  // The sequence `->` is treated as special token.
118  if (*curPtr == '>')
119  ++curPtr;
120  continue;
121 
122  case '>':
123  if (failed(checkNestedPunctuation('<')))
124  return failure();
125  break;
126  case ']':
127  if (failed(checkNestedPunctuation('[')))
128  return failure();
129  break;
130  case ')':
131  if (failed(checkNestedPunctuation('(')))
132  return failure();
133  break;
134  case '}':
135  if (failed(checkNestedPunctuation('{')))
136  return failure();
137  break;
138  case '"': {
139  // Dispatch to the lexer to lex past strings.
140  resetToken(curPtr - 1);
141  curPtr = state.curToken.getEndLoc().getPointer();
142 
143  // Handle code completions, which may appear in the middle of the symbol
144  // body.
146  isCodeCompletion = true;
147  nestedPunctuation.clear();
148  break;
149  }
150 
151  // Otherwise, ensure this token was actually a string.
152  if (state.curToken.isNot(Token::string))
153  return failure();
154  break;
155  }
156 
157  default:
158  continue;
159  }
160  } while (!nestedPunctuation.empty());
161 
162  // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
163  // consuming all this stuff, and return.
164  resetToken(curPtr);
165 
166  unsigned length = curPtr - body.begin();
167  body = StringRef(body.data(), length);
168  return success();
169 }
170 
171 /// Parse an extended dialect symbol.
172 template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
173 static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
174  SymbolAliasMap &aliases,
175  CreateFn &&createSymbol) {
176  Token tok = p.getToken();
177 
178  // Handle code completion of the extended symbol.
179  StringRef identifier = tok.getSpelling().drop_front();
180  if (tok.isCodeCompletion() && identifier.empty())
181  return p.codeCompleteDialectSymbol(aliases);
182 
183  // Parse the dialect namespace.
184  SMRange range = p.getToken().getLocRange();
185  SMLoc loc = p.getToken().getLoc();
186  p.consumeToken();
187 
188  // Check to see if this is a pretty name.
189  auto [dialectName, symbolData] = identifier.split('.');
190  bool isPrettyName = !symbolData.empty() || identifier.back() == '.';
191 
192  // Check to see if the symbol has trailing data, i.e. has an immediately
193  // following '<'.
194  bool hasTrailingData =
195  p.getToken().is(Token::less) &&
196  identifier.bytes_end() == p.getTokenSpelling().bytes_begin();
197 
198  // If there is no '<' token following this, and if the typename contains no
199  // dot, then we are parsing a symbol alias.
200  if (!hasTrailingData && !isPrettyName) {
201  // Check for an alias for this type.
202  auto aliasIt = aliases.find(identifier);
203  if (aliasIt == aliases.end())
204  return (p.emitWrongTokenError("undefined symbol alias id '" + identifier +
205  "'"),
206  nullptr);
207  if (asmState) {
208  if constexpr (std::is_same_v<Symbol, Type>)
209  asmState->addTypeAliasUses(identifier, range);
210  else
211  asmState->addAttrAliasUses(identifier, range);
212  }
213  return aliasIt->second;
214  }
215 
216  // If this isn't an alias, we are parsing a dialect-specific symbol. If the
217  // name contains a dot, then this is the "pretty" form. If not, it is the
218  // verbose form that looks like <...>.
219  if (!isPrettyName) {
220  // Point the symbol data to the end of the dialect name to start.
221  symbolData = StringRef(dialectName.end(), 0);
222 
223  // Parse the body of the symbol.
224  bool isCodeCompletion = false;
225  if (p.parseDialectSymbolBody(symbolData, isCodeCompletion))
226  return nullptr;
227  symbolData = symbolData.drop_front();
228 
229  // If the body contained a code completion it won't have the trailing `>`
230  // token, so don't drop it.
231  if (!isCodeCompletion)
232  symbolData = symbolData.drop_back();
233  } else {
234  loc = SMLoc::getFromPointer(symbolData.data());
235 
236  // If the dialect's symbol is followed immediately by a <, then lex the body
237  // of it into prettyName.
238  if (hasTrailingData && p.parseDialectSymbolBody(symbolData))
239  return nullptr;
240  }
241 
242  return createSymbol(dialectName, symbolData, loc);
243 }
244 
245 /// Parse an extended attribute.
246 ///
247 /// extended-attribute ::= (dialect-attribute | attribute-alias)
248 /// dialect-attribute ::= `#` dialect-namespace `<` attr-data `>`
249 /// (`:` type)?
250 /// | `#` alias-name pretty-dialect-sym-body? (`:` type)?
251 /// attribute-alias ::= `#` alias-name
252 ///
254  MLIRContext *ctx = getContext();
255  Attribute attr = parseExtendedSymbol<Attribute>(
257  [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
258  // Parse an optional trailing colon type.
259  Type attrType = type;
260  if (consumeIf(Token::colon) && !(attrType = parseType()))
261  return Attribute();
262 
263  // If we found a registered dialect, then ask it to parse the attribute.
264  if (Dialect *dialect =
265  builder.getContext()->getOrLoadDialect(dialectName)) {
266  // Temporarily reset the lexer to let the dialect parse the attribute.
267  const char *curLexerPos = getToken().getLoc().getPointer();
268  resetToken(symbolData.data());
269 
270  // Parse the attribute.
271  CustomDialectAsmParser customParser(symbolData, *this);
272  Attribute attr = dialect->parseAttribute(customParser, attrType);
273  resetToken(curLexerPos);
274  return attr;
275  }
276 
277  // Otherwise, form a new opaque attribute.
278  return OpaqueAttr::getChecked(
279  [&] { return emitError(loc); }, StringAttr::get(ctx, dialectName),
280  symbolData, attrType ? attrType : NoneType::get(ctx));
281  });
282 
283  // Ensure that the attribute has the same type as requested.
284  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
285  if (type && typedAttr && typedAttr.getType() != type) {
286  emitError("attribute type different than expected: expected ")
287  << type << ", but got " << typedAttr.getType();
288  return nullptr;
289  }
290  return attr;
291 }
292 
293 /// Parse an extended type.
294 ///
295 /// extended-type ::= (dialect-type | type-alias)
296 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>`
297 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body?
298 /// type-alias ::= `!` alias-name
299 ///
300 Type Parser::parseExtendedType() {
301  MLIRContext *ctx = getContext();
302  return parseExtendedSymbol<Type>(
303  *this, state.asmState, state.symbols.typeAliasDefinitions,
304  [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
305  // If we found a registered dialect, then ask it to parse the type.
306  if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {
307  // Temporarily reset the lexer to let the dialect parse the type.
308  const char *curLexerPos = getToken().getLoc().getPointer();
309  resetToken(symbolData.data());
310 
311  // Parse the type.
312  CustomDialectAsmParser customParser(symbolData, *this);
313  Type type = dialect->parseType(customParser);
314  resetToken(curLexerPos);
315  return type;
316  }
317 
318  // Otherwise, form a new opaque type.
319  return OpaqueType::getChecked([&] { return emitError(loc); },
320  StringAttr::get(ctx, dialectName),
321  symbolData);
322  });
323 }
324 
325 //===----------------------------------------------------------------------===//
326 // mlir::parseAttribute/parseType
327 //===----------------------------------------------------------------------===//
328 
329 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If
330 /// parsing failed, nullptr is returned.
331 template <typename T, typename ParserFn>
332 static T parseSymbol(StringRef inputStr, MLIRContext *context,
333  size_t *numReadOut, bool isKnownNullTerminated,
334  ParserFn &&parserFn) {
335  // Set the buffer name to the string being parsed, so that it appears in error
336  // diagnostics.
337  auto memBuffer =
338  isKnownNullTerminated
339  ? MemoryBuffer::getMemBuffer(inputStr,
340  /*BufferName=*/inputStr)
341  : MemoryBuffer::getMemBufferCopy(inputStr, /*BufferName=*/inputStr);
342  SourceMgr sourceMgr;
343  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
344  SymbolState aliasState;
345  ParserConfig config(context);
346  ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
347  /*codeCompleteContext=*/nullptr);
348  Parser parser(state);
349 
350  Token startTok = parser.getToken();
351  T symbol = parserFn(parser);
352  if (!symbol)
353  return T();
354 
355  // Provide the number of bytes that were read.
356  Token endTok = parser.getToken();
357  size_t numRead =
358  endTok.getLoc().getPointer() - startTok.getLoc().getPointer();
359  if (numReadOut) {
360  *numReadOut = numRead;
361  } else if (numRead != inputStr.size()) {
362  parser.emitError(endTok.getLoc()) << "found trailing characters: '"
363  << inputStr.drop_front(numRead) << "'";
364  return T();
365  }
366  return symbol;
367 }
368 
369 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
370  Type type, size_t *numRead,
371  bool isKnownNullTerminated) {
372  return parseSymbol<Attribute>(
373  attrStr, context, numRead, isKnownNullTerminated,
374  [type](Parser &parser) { return parser.parseAttribute(type); });
375 }
376 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead,
377  bool isKnownNullTerminated) {
378  return parseSymbol<Type>(typeStr, context, numRead, isKnownNullTerminated,
379  [](Parser &parser) { return parser.parseType(); });
380 }
static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState, SymbolAliasMap &aliases, CreateFn &&createSymbol)
Parse an extended dialect symbol.
static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t *numReadOut, bool isKnownNullTerminated, ParserFn &&parserFn)
Parses a symbol, of type 'T', and returns it if parsing was successful.
static MLIRContext * getContext(OpFoldResult val)
This class represents state from a parsed MLIR textual format string.
void addTypeAliasUses(StringRef name, SMRange locations)
void addAttrAliasUses(StringRef name, SMRange locations)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
const char * getCodeCompleteLoc() const
Return the code completion location of the lexer, or nullptr if there is none.
Definition: Lexer.h:45
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class represents a configuration for the MLIR assembly parser.
Definition: AsmState.h:460
This represents a token in the MLIR syntax.
Definition: Token.h:20
SMRange getLocRange() const
Definition: Token.cpp:30
SMLoc getLoc() const
Definition: Token.cpp:24
bool is(Kind k) const
Definition: Token.h:38
SMLoc getEndLoc() const
Definition: Token.cpp:26
bool isNot(Kind k) const
Definition: Token.h:50
bool isCodeCompletion() const
Returns true if the current token represents a code completion.
Definition: Token.h:62
StringRef getSpelling() const
Definition: Token.h:34
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides the implementation of the generic parser methods within AsmParser.
Definition: AsmParserImpl.h:28
This class implement support for parsing global entities like attributes and types.
Definition: Parser.h:26
Type parseType()
Parse an arbitrary type.
Definition: TypeParser.cpp:70
InFlightDiagnostic emitError(const Twine &message={})
Emit an error and return failure.
Definition: Parser.cpp:192
ParserState & state
The Parser is subclassed and reinstantiated.
Definition: Parser.h:347
Attribute parseAttribute(Type type={})
Parse an arbitrary attribute with an optional type.
StringRef getTokenSpelling() const
Definition: Parser.h:103
void consumeToken()
Advance the current lexer onto the next token.
Definition: Parser.h:115
ParseResult parseDialectSymbolBody(StringRef &body, bool &isCodeCompletion)
Parse the body of a dialect symbol, which starts and ends with <>'s, and may be recursive.
MLIRContext * getContext() const
Definition: Parser.h:37
InFlightDiagnostic emitWrongTokenError(const Twine &message={})
Emit an error about a "wrong token".
Definition: Parser.cpp:215
void resetToken(const char *tokPos)
Reset the parser to the given lexer position.
Definition: Parser.h:130
Attribute parseExtendedAttr(Type type)
Parse an extended attribute.
const Token & getToken() const
Return the current token the parser is inspecting.
Definition: Parser.h:102
Attribute codeCompleteDialectSymbol(const llvm::StringMap< Attribute > &aliases)
Definition: Parser.cpp:473
Detect if any of the given parameter types has a sub-element handler.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class refers to all of the state maintained globally by the parser, such as the current lexer po...
Definition: ParserState.h:51
SymbolState & symbols
The current state for symbol parsing.
Definition: ParserState.h:72
Lexer lex
The lexer for the source file we're parsing.
Definition: ParserState.h:66
Token curToken
This is the next token that hasn't been consumed yet.
Definition: ParserState.h:69
AsmParserState * asmState
An optional pointer to a struct containing high level parser state to be populated during parsing.
Definition: ParserState.h:80
This class contains record of any parsed top-level symbols.
Definition: ParserState.h:28
llvm::StringMap< Attribute > attributeAliasDefinitions
A map from attribute alias identifier to Attribute.
Definition: ParserState.h:30