MLIR  22.0.0git
DialectSymbolParser.cpp
Go to the documentation of this file.
1 //===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the parser for the dialect symbols, such as extended
10 // attributes and types.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "AsmParserImpl.h"
15 #include "Parser.h"
17 #include "mlir/IR/AsmState.h"
18 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Dialect.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/Support/LLVM.h"
26 #include "llvm/Support/MemoryBuffer.h"
27 #include "llvm/Support/SourceMgr.h"
28 #include <cassert>
29 #include <cstddef>
30 #include <utility>
31 
32 using namespace mlir;
33 using namespace mlir::detail;
34 using llvm::MemoryBuffer;
35 using llvm::SourceMgr;
36 
37 namespace {
38 /// This class provides the main implementation of the DialectAsmParser that
39 /// allows for dialects to parse attributes and types. This allows for dialect
40 /// hooking into the main MLIR parsing logic.
41 class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
42 public:
43  CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
44  : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
45  fullSpec(fullSpec) {}
46  ~CustomDialectAsmParser() override = default;
47 
48  /// Returns the full specification of the symbol being parsed. This allows
49  /// for using a separate parser if necessary.
50  StringRef getFullSymbolSpec() const override { return fullSpec; }
51 
52 private:
53  /// The full symbol specification.
54  StringRef fullSpec;
55 };
56 } // namespace
57 
58 ///
59 /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
60 /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
61 /// | '(' pretty-dialect-sym-contents+ ')'
62 /// | '[' pretty-dialect-sym-contents+ ']'
63 /// | '{' pretty-dialect-sym-contents+ '}'
64 /// | '[^[<({>\])}\0]+'
65 ///
66 ParseResult Parser::parseDialectSymbolBody(StringRef &body,
67  bool &isCodeCompletion) {
68  // Symbol bodies are a relatively unstructured format that contains a series
69  // of properly nested punctuation, with anything else in the middle. Scan
70  // ahead to find it and consume it if successful, otherwise emit an error.
71  const char *curPtr = getTokenSpelling().data();
72 
73  // Scan over the nested punctuation, bailing out on error and consuming until
74  // we find the end. We know that we're currently looking at the '<', so we can
75  // go until we find the matching '>' character.
76  assert(*curPtr == '<');
77  SmallVector<char, 8> nestedPunctuation;
78  const char *codeCompleteLoc = state.lex.getCodeCompleteLoc();
79 
80  // Functor used to emit an unbalanced punctuation error.
81  auto emitPunctError = [&] {
82  return emitError() << "unbalanced '" << nestedPunctuation.back()
83  << "' character in pretty dialect name";
84  };
85  // Functor used to check for unbalanced punctuation.
86  auto checkNestedPunctuation = [&](char expectedToken) -> ParseResult {
87  if (nestedPunctuation.back() != expectedToken)
88  return emitPunctError();
89  nestedPunctuation.pop_back();
90  return success();
91  };
92  const char *curBufferEnd = state.lex.getBufferEnd();
93  do {
94  // Handle code completions, which may appear in the middle of the symbol
95  // body.
96  if (curPtr == codeCompleteLoc) {
97  isCodeCompletion = true;
98  nestedPunctuation.clear();
99  break;
100  }
101 
102  if (curBufferEnd == curPtr) {
103  if (!nestedPunctuation.empty())
104  return emitPunctError();
105  return emitError("unexpected nul or EOF in pretty dialect name");
106  }
107 
108  char c = *curPtr++;
109  switch (c) {
110  case '\0':
111  // This also handles the EOF case.
112  if (!nestedPunctuation.empty())
113  return emitPunctError();
114  return emitError("unexpected nul or EOF in pretty dialect name");
115  case '<':
116  case '[':
117  case '(':
118  case '{':
119  nestedPunctuation.push_back(c);
120  continue;
121 
122  case '-':
123  // The sequence `->` is treated as special token.
124  if (*curPtr == '>')
125  ++curPtr;
126  continue;
127 
128  case '>':
129  if (failed(checkNestedPunctuation('<')))
130  return failure();
131  break;
132  case ']':
133  if (failed(checkNestedPunctuation('[')))
134  return failure();
135  break;
136  case ')':
137  if (failed(checkNestedPunctuation('(')))
138  return failure();
139  break;
140  case '}':
141  if (failed(checkNestedPunctuation('{')))
142  return failure();
143  break;
144  case '"': {
145  // Dispatch to the lexer to lex past strings.
146  resetToken(curPtr - 1);
147  curPtr = state.curToken.getEndLoc().getPointer();
148 
149  // Handle code completions, which may appear in the middle of the symbol
150  // body.
152  isCodeCompletion = true;
153  nestedPunctuation.clear();
154  break;
155  }
156 
157  // Otherwise, ensure this token was actually a string.
158  if (state.curToken.isNot(Token::string))
159  return failure();
160  break;
161  }
162 
163  default:
164  continue;
165  }
166  } while (!nestedPunctuation.empty());
167 
168  // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
169  // consuming all this stuff, and return.
170  resetToken(curPtr);
171 
172  unsigned length = curPtr - body.begin();
173  body = StringRef(body.data(), length);
174  return success();
175 }
176 
177 /// Parse an extended dialect symbol.
178 template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
179 static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
180  SymbolAliasMap &aliases,
181  CreateFn &&createSymbol) {
182  Token tok = p.getToken();
183 
184  // Handle code completion of the extended symbol.
185  StringRef identifier = tok.getSpelling().drop_front();
186  if (tok.isCodeCompletion() && identifier.empty())
187  return p.codeCompleteDialectSymbol(aliases);
188 
189  // Parse the dialect namespace.
190  SMRange range = p.getToken().getLocRange();
191  SMLoc loc = p.getToken().getLoc();
192  p.consumeToken();
193 
194  // Check to see if this is a pretty name.
195  auto [dialectName, symbolData] = identifier.split('.');
196  bool isPrettyName = !symbolData.empty() || identifier.back() == '.';
197 
198  // Check to see if the symbol has trailing data, i.e. has an immediately
199  // following '<'.
200  bool hasTrailingData =
201  p.getToken().is(Token::less) &&
202  identifier.bytes_end() == p.getTokenSpelling().bytes_begin();
203 
204  // If there is no '<' token following this, and if the typename contains no
205  // dot, then we are parsing a symbol alias.
206  if (!hasTrailingData && !isPrettyName) {
207  // Check for an alias for this type.
208  auto aliasIt = aliases.find(identifier);
209  if (aliasIt == aliases.end())
210  return (p.emitWrongTokenError("undefined symbol alias id '" + identifier +
211  "'"),
212  nullptr);
213  if (asmState) {
214  if constexpr (std::is_same_v<Symbol, Type>)
215  asmState->addTypeAliasUses(identifier, range);
216  else
217  asmState->addAttrAliasUses(identifier, range);
218  }
219  return aliasIt->second;
220  }
221 
222  // If this isn't an alias, we are parsing a dialect-specific symbol. If the
223  // name contains a dot, then this is the "pretty" form. If not, it is the
224  // verbose form that looks like <...>.
225  if (!isPrettyName) {
226  // Point the symbol data to the end of the dialect name to start.
227  symbolData = StringRef(dialectName.end(), 0);
228 
229  // Parse the body of the symbol.
230  bool isCodeCompletion = false;
231  if (p.parseDialectSymbolBody(symbolData, isCodeCompletion))
232  return nullptr;
233  symbolData = symbolData.drop_front();
234 
235  // If the body contained a code completion it won't have the trailing `>`
236  // token, so don't drop it.
237  if (!isCodeCompletion)
238  symbolData = symbolData.drop_back();
239  } else {
240  loc = SMLoc::getFromPointer(symbolData.data());
241 
242  // If the dialect's symbol is followed immediately by a <, then lex the body
243  // of it into prettyName.
244  if (hasTrailingData && p.parseDialectSymbolBody(symbolData))
245  return nullptr;
246  }
247 
248  return createSymbol(dialectName, symbolData, loc);
249 }
250 
251 /// Parse an extended attribute.
252 ///
253 /// extended-attribute ::= (dialect-attribute | attribute-alias)
254 /// dialect-attribute ::= `#` dialect-namespace `<` attr-data `>`
255 /// (`:` type)?
256 /// | `#` alias-name pretty-dialect-sym-body? (`:` type)?
257 /// attribute-alias ::= `#` alias-name
258 ///
260  MLIRContext *ctx = getContext();
261  Attribute attr = parseExtendedSymbol<Attribute>(
263  [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
264  // Parse an optional trailing colon type.
265  Type attrType = type;
266  if (consumeIf(Token::colon) && !(attrType = parseType()))
267  return Attribute();
268 
269  // If we found a registered dialect, then ask it to parse the attribute.
270  if (Dialect *dialect =
271  builder.getContext()->getOrLoadDialect(dialectName)) {
272  // Temporarily reset the lexer to let the dialect parse the attribute.
273  const char *curLexerPos = getToken().getLoc().getPointer();
274  resetToken(symbolData.data());
275 
276  // Parse the attribute.
277  CustomDialectAsmParser customParser(symbolData, *this);
278  Attribute attr = dialect->parseAttribute(customParser, attrType);
279  resetToken(curLexerPos);
280  return attr;
281  }
282 
283  // Otherwise, form a new opaque attribute.
284  return OpaqueAttr::getChecked(
285  [&] { return emitError(loc); }, StringAttr::get(ctx, dialectName),
286  symbolData, attrType ? attrType : NoneType::get(ctx));
287  });
288 
289  // Ensure that the attribute has the same type as requested.
290  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
291  if (type && typedAttr && typedAttr.getType() != type) {
292  emitError("attribute type different than expected: expected ")
293  << type << ", but got " << typedAttr.getType();
294  return nullptr;
295  }
296  return attr;
297 }
298 
299 /// Parse an extended type.
300 ///
301 /// extended-type ::= (dialect-type | type-alias)
302 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>`
303 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body?
304 /// type-alias ::= `!` alias-name
305 ///
306 Type Parser::parseExtendedType() {
307  MLIRContext *ctx = getContext();
308  return parseExtendedSymbol<Type>(
309  *this, state.asmState, state.symbols.typeAliasDefinitions,
310  [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
311  // If we found a registered dialect, then ask it to parse the type.
312  if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {
313  // Temporarily reset the lexer to let the dialect parse the type.
314  const char *curLexerPos = getToken().getLoc().getPointer();
315  resetToken(symbolData.data());
316 
317  // Parse the type.
318  CustomDialectAsmParser customParser(symbolData, *this);
319  Type type = dialect->parseType(customParser);
320  resetToken(curLexerPos);
321  return type;
322  }
323 
324  // Otherwise, form a new opaque type.
325  return OpaqueType::getChecked([&] { return emitError(loc); },
326  StringAttr::get(ctx, dialectName),
327  symbolData);
328  });
329 }
330 
331 //===----------------------------------------------------------------------===//
332 // mlir::parseAttribute/parseType
333 //===----------------------------------------------------------------------===//
334 
335 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If
336 /// parsing failed, nullptr is returned.
337 template <typename T, typename ParserFn>
338 static T parseSymbol(StringRef inputStr, MLIRContext *context,
339  size_t *numReadOut, bool isKnownNullTerminated,
340  ParserFn &&parserFn) {
341  // Set the buffer name to the string being parsed, so that it appears in error
342  // diagnostics.
343  auto memBuffer =
344  isKnownNullTerminated
345  ? MemoryBuffer::getMemBuffer(inputStr,
346  /*BufferName=*/inputStr)
347  : MemoryBuffer::getMemBufferCopy(inputStr, /*BufferName=*/inputStr);
348  SourceMgr sourceMgr;
349  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
350  SymbolState aliasState;
351  ParserConfig config(context);
352  ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
353  /*codeCompleteContext=*/nullptr);
354  Parser parser(state);
355 
356  Token startTok = parser.getToken();
357  T symbol = parserFn(parser);
358  if (!symbol)
359  return T();
360 
361  // Provide the number of bytes that were read.
362  Token endTok = parser.getToken();
363  size_t numRead =
364  endTok.getLoc().getPointer() - startTok.getLoc().getPointer();
365  if (numReadOut) {
366  *numReadOut = numRead;
367  } else if (numRead != inputStr.size()) {
368  parser.emitError(endTok.getLoc()) << "found trailing characters: '"
369  << inputStr.drop_front(numRead) << "'";
370  return T();
371  }
372  return symbol;
373 }
374 
375 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
376  Type type, size_t *numRead,
377  bool isKnownNullTerminated) {
378  return parseSymbol<Attribute>(
379  attrStr, context, numRead, isKnownNullTerminated,
380  [type](Parser &parser) { return parser.parseAttribute(type); });
381 }
382 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead,
383  bool isKnownNullTerminated) {
384  return parseSymbol<Type>(typeStr, context, numRead, isKnownNullTerminated,
385  [](Parser &parser) { return parser.parseType(); });
386 }
static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState, SymbolAliasMap &aliases, CreateFn &&createSymbol)
Parse an extended dialect symbol.
static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t *numReadOut, bool isKnownNullTerminated, ParserFn &&parserFn)
Parses a symbol, of type 'T', and returns it if parsing was successful.
static MLIRContext * getContext(OpFoldResult val)
This class represents state from a parsed MLIR textual format string.
void addTypeAliasUses(StringRef name, SMRange locations)
void addAttrAliasUses(StringRef name, SMRange locations)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
const char * getBufferEnd()
Returns the end of the buffer.
Definition: Lexer.h:44
const char * getCodeCompleteLoc() const
Return the code completion location of the lexer, or nullptr if there is none.
Definition: Lexer.h:48
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a configuration for the MLIR assembly parser.
Definition: AsmState.h:469
This represents a token in the MLIR syntax.
Definition: Token.h:20
SMRange getLocRange() const
Definition: Token.cpp:30
SMLoc getLoc() const
Definition: Token.cpp:24
bool is(Kind k) const
Definition: Token.h:38
SMLoc getEndLoc() const
Definition: Token.cpp:26
bool isNot(Kind k) const
Definition: Token.h:50
bool isCodeCompletion() const
Returns true if the current token represents a code completion.
Definition: Token.h:62
StringRef getSpelling() const
Definition: Token.h:34
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides the implementation of the generic parser methods within AsmParser.
Definition: AsmParserImpl.h:28
This class implement support for parsing global entities like attributes and types.
Definition: Parser.h:27
Type parseType()
Parse an arbitrary type.
Definition: TypeParser.cpp:74
InFlightDiagnostic emitError(const Twine &message={})
Emit an error and return failure.
Definition: Parser.cpp:192
ParserState & state
The Parser is subclassed and reinstantiated.
Definition: Parser.h:370
Attribute parseAttribute(Type type={})
Parse an arbitrary attribute with an optional type.
StringRef getTokenSpelling() const
Definition: Parser.h:104
void consumeToken()
Advance the current lexer onto the next token.
Definition: Parser.h:119
ParseResult parseDialectSymbolBody(StringRef &body, bool &isCodeCompletion)
Parse the body of a dialect symbol, which starts and ends with <>'s, and may be recursive.
MLIRContext * getContext() const
Definition: Parser.h:38
InFlightDiagnostic emitWrongTokenError(const Twine &message={})
Emit an error about a "wrong token".
Definition: Parser.cpp:215
void resetToken(const char *tokPos)
Reset the parser to the given lexer position.
Definition: Parser.h:140
Attribute parseExtendedAttr(Type type)
Parse an extended attribute.
const Token & getToken() const
Return the current token the parser is inspecting.
Definition: Parser.h:103
Attribute codeCompleteDialectSymbol(const llvm::StringMap< Attribute > &aliases)
Definition: Parser.cpp:555
AttrTypeReplacer.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
This class refers to all of the state maintained globally by the parser, such as the current lexer po...
Definition: ParserState.h:51
SymbolState & symbols
The current state for symbol parsing.
Definition: ParserState.h:75
Lexer lex
The lexer for the source file we're parsing.
Definition: ParserState.h:66
Token curToken
This is the next token that hasn't been consumed yet.
Definition: ParserState.h:69
AsmParserState * asmState
An optional pointer to a struct containing high level parser state to be populated during parsing.
Definition: ParserState.h:83
This class contains record of any parsed top-level symbols.
Definition: ParserState.h:28
llvm::StringMap< Attribute > attributeAliasDefinitions
A map from attribute alias identifier to Attribute.
Definition: ParserState.h:30