MLIR 22.0.0git
DialectSymbolParser.cpp
Go to the documentation of this file.
1//===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the parser for the dialect symbols, such as extended
10// attributes and types.
11//
12//===----------------------------------------------------------------------===//
13
14#include "AsmParserImpl.h"
15#include "Parser.h"
17#include "mlir/IR/AsmState.h"
18#include "mlir/IR/Attributes.h"
22#include "mlir/IR/Dialect.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/Support/LLVM.h"
26#include "llvm/Support/MemoryBuffer.h"
27#include "llvm/Support/SourceMgr.h"
28#include <cassert>
29#include <cstddef>
30#include <utility>
31
32using namespace mlir;
33using namespace mlir::detail;
34using llvm::MemoryBuffer;
35using llvm::SourceMgr;
36
37namespace {
38/// This class provides the main implementation of the DialectAsmParser that
39/// allows for dialects to parse attributes and types. This allows for dialect
40/// hooking into the main MLIR parsing logic.
41class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
42public:
43 CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
44 : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
45 fullSpec(fullSpec) {}
46 ~CustomDialectAsmParser() override = default;
47
48 /// Returns the full specification of the symbol being parsed. This allows
49 /// for using a separate parser if necessary.
50 StringRef getFullSymbolSpec() const override { return fullSpec; }
51
52private:
53 /// The full symbol specification.
54 StringRef fullSpec;
55};
56} // namespace
57
58///
59/// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
60/// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
61/// | '(' pretty-dialect-sym-contents+ ')'
62/// | '[' pretty-dialect-sym-contents+ ']'
63/// | '{' pretty-dialect-sym-contents+ '}'
64/// | '[^[<({>\])}\0]+'
65///
66ParseResult Parser::parseDialectSymbolBody(StringRef &body,
67 bool &isCodeCompletion) {
68 // Symbol bodies are a relatively unstructured format that contains a series
69 // of properly nested punctuation, with anything else in the middle. Scan
70 // ahead to find it and consume it if successful, otherwise emit an error.
71 const char *curPtr = getTokenSpelling().data();
72
73 // Scan over the nested punctuation, bailing out on error and consuming until
74 // we find the end. We know that we're currently looking at the '<', so we can
75 // go until we find the matching '>' character.
76 assert(*curPtr == '<');
77 SmallVector<char, 8> nestedPunctuation;
78 const char *codeCompleteLoc = state.lex.getCodeCompleteLoc();
79
80 // Functor used to emit an unbalanced punctuation error.
81 auto emitPunctError = [&] {
82 return emitError() << "unbalanced '" << nestedPunctuation.back()
83 << "' character in pretty dialect name";
84 };
85 // Functor used to check for unbalanced punctuation.
86 auto checkNestedPunctuation = [&](char expectedToken) -> ParseResult {
87 if (nestedPunctuation.back() != expectedToken)
88 return emitPunctError();
89 nestedPunctuation.pop_back();
90 return success();
91 };
92 const char *curBufferEnd = state.lex.getBufferEnd();
93 do {
94 // Handle code completions, which may appear in the middle of the symbol
95 // body.
96 if (curPtr == codeCompleteLoc) {
97 isCodeCompletion = true;
98 nestedPunctuation.clear();
99 break;
100 }
101
102 if (curBufferEnd == curPtr) {
103 if (!nestedPunctuation.empty())
104 return emitPunctError();
105 return emitError("unexpected nul or EOF in pretty dialect name");
106 }
107
108 char c = *curPtr++;
109 switch (c) {
110 case '\0':
111 // This also handles the EOF case.
112 if (!nestedPunctuation.empty())
113 return emitPunctError();
114 return emitError("unexpected nul or EOF in pretty dialect name");
115 case '<':
116 case '[':
117 case '(':
118 case '{':
119 nestedPunctuation.push_back(c);
120 continue;
121
122 case '-':
123 // The sequence `->` is treated as special token.
124 if (*curPtr == '>')
125 ++curPtr;
126 continue;
127
128 case '>':
129 if (failed(checkNestedPunctuation('<')))
130 return failure();
131 break;
132 case ']':
133 if (failed(checkNestedPunctuation('[')))
134 return failure();
135 break;
136 case ')':
137 if (failed(checkNestedPunctuation('(')))
138 return failure();
139 break;
140 case '}':
141 if (failed(checkNestedPunctuation('{')))
142 return failure();
143 break;
144 case '"': {
145 // Dispatch to the lexer to lex past strings.
146 resetToken(curPtr - 1);
147 curPtr = state.curToken.getEndLoc().getPointer();
148
149 // Handle code completions, which may appear in the middle of the symbol
150 // body.
151 if (state.curToken.isCodeCompletion()) {
152 isCodeCompletion = true;
153 nestedPunctuation.clear();
154 break;
155 }
156
157 // Otherwise, ensure this token was actually a string.
158 if (state.curToken.isNot(Token::string))
159 return failure();
160 break;
161 }
162
163 default:
164 continue;
165 }
166 } while (!nestedPunctuation.empty());
167
168 // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
169 // consuming all this stuff, and return.
170 resetToken(curPtr);
171
172 unsigned length = curPtr - body.begin();
173 body = StringRef(body.data(), length);
174 return success();
175}
176
177/// Parse an extended dialect symbol.
178template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
179static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
180 SymbolAliasMap &aliases,
181 CreateFn &&createSymbol) {
182 Token tok = p.getToken();
183
184 // Handle code completion of the extended symbol.
185 StringRef identifier = tok.getSpelling().drop_front();
186 if (tok.isCodeCompletion() && identifier.empty())
187 return p.codeCompleteDialectSymbol(aliases);
188
189 // Parse the dialect namespace.
190 SMRange range = p.getToken().getLocRange();
191 SMLoc loc = p.getToken().getLoc();
192 p.consumeToken();
193
194 // Check to see if this is a pretty name.
195 auto [dialectName, symbolData] = identifier.split('.');
196 bool isPrettyName = !symbolData.empty() || identifier.back() == '.';
197
198 // Check to see if the symbol has trailing data, i.e. has an immediately
199 // following '<'.
200 bool hasTrailingData =
201 p.getToken().is(Token::less) &&
202 identifier.bytes_end() == p.getTokenSpelling().bytes_begin();
203
204 // If there is no '<' token following this, and if the typename contains no
205 // dot, then we are parsing a symbol alias.
206 if (!hasTrailingData && !isPrettyName) {
207 // Check for an alias for this type.
208 auto aliasIt = aliases.find(identifier);
209 if (aliasIt == aliases.end())
210 return (p.emitWrongTokenError("undefined symbol alias id '" + identifier +
211 "'"),
212 nullptr);
213 if (asmState) {
214 if constexpr (std::is_same_v<Symbol, Type>)
215 asmState->addTypeAliasUses(identifier, range);
216 else
217 asmState->addAttrAliasUses(identifier, range);
218 }
219 return aliasIt->second;
220 }
221
222 // If this isn't an alias, we are parsing a dialect-specific symbol. If the
223 // name contains a dot, then this is the "pretty" form. If not, it is the
224 // verbose form that looks like <...>.
225 if (!isPrettyName) {
226 // Point the symbol data to the end of the dialect name to start.
227 symbolData = StringRef(dialectName.end(), 0);
228
229 // Parse the body of the symbol.
230 bool isCodeCompletion = false;
231 if (p.parseDialectSymbolBody(symbolData, isCodeCompletion))
232 return nullptr;
233 symbolData = symbolData.drop_front();
234
235 // If the body contained a code completion it won't have the trailing `>`
236 // token, so don't drop it.
237 if (!isCodeCompletion)
238 symbolData = symbolData.drop_back();
239 } else {
240 loc = SMLoc::getFromPointer(symbolData.data());
241
242 // If the dialect's symbol is followed immediately by a <, then lex the body
243 // of it into prettyName.
244 if (hasTrailingData && p.parseDialectSymbolBody(symbolData))
245 return nullptr;
246 }
247
248 return createSymbol(dialectName, symbolData, loc);
249}
250
251/// Parse an extended attribute.
252///
253/// extended-attribute ::= (dialect-attribute | attribute-alias)
254/// dialect-attribute ::= `#` dialect-namespace `<` attr-data `>`
255/// (`:` type)?
256/// | `#` alias-name pretty-dialect-sym-body? (`:` type)?
257/// attribute-alias ::= `#` alias-name
258///
260 MLIRContext *ctx = getContext();
262 *this, state.asmState, state.symbols.attributeAliasDefinitions,
263 [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
264 // Parse an optional trailing colon type.
265 Type attrType = type;
266 if (consumeIf(Token::colon) && !(attrType = parseType()))
267 return Attribute();
268
269 // If we found a registered dialect, then ask it to parse the attribute.
270 if (Dialect *dialect =
271 builder.getContext()->getOrLoadDialect(dialectName)) {
272 // Temporarily reset the lexer to let the dialect parse the attribute.
273 const char *curLexerPos = getToken().getLoc().getPointer();
274 resetToken(symbolData.data());
275
276 // Parse the attribute.
277 CustomDialectAsmParser customParser(symbolData, *this);
278 Attribute attr = dialect->parseAttribute(customParser, attrType);
279 resetToken(curLexerPos);
280 return attr;
281 }
282
283 // Otherwise, form a new opaque attribute.
284 return OpaqueAttr::getChecked(
285 [&] { return emitError(loc); }, StringAttr::get(ctx, dialectName),
286 symbolData, attrType ? attrType : NoneType::get(ctx));
287 });
288
289 // Ensure that the attribute has the same type as requested.
290 auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
291 if (type && typedAttr && typedAttr.getType() != type) {
292 emitError("attribute type different than expected: expected ")
293 << type << ", but got " << typedAttr.getType();
294 return nullptr;
295 }
296 return attr;
297}
298
299/// Parse an extended type.
300///
301/// extended-type ::= (dialect-type | type-alias)
302/// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>`
303/// dialect-type ::= `!` alias-name pretty-dialect-attribute-body?
304/// type-alias ::= `!` alias-name
305///
307 MLIRContext *ctx = getContext();
309 *this, state.asmState, state.symbols.typeAliasDefinitions,
310 [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
311 // If we found a registered dialect, then ask it to parse the type.
312 if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {
313 // Temporarily reset the lexer to let the dialect parse the type.
314 const char *curLexerPos = getToken().getLoc().getPointer();
315 resetToken(symbolData.data());
316
317 // Parse the type.
318 CustomDialectAsmParser customParser(symbolData, *this);
319 Type type = dialect->parseType(customParser);
320 resetToken(curLexerPos);
321 return type;
322 }
323
324 // Otherwise, form a new opaque type.
325 return OpaqueType::getChecked([&] { return emitError(loc); },
326 StringAttr::get(ctx, dialectName),
327 symbolData);
328 });
329}
330
331//===----------------------------------------------------------------------===//
332// mlir::parseAttribute/parseType
333//===----------------------------------------------------------------------===//
334
335/// Parses a symbol, of type 'T', and returns it if parsing was successful. If
336/// parsing failed, nullptr is returned.
337template <typename T, typename ParserFn>
338static T parseSymbol(StringRef inputStr, MLIRContext *context,
339 size_t *numReadOut, bool isKnownNullTerminated,
340 ParserFn &&parserFn) {
341 // Set the buffer name to the string being parsed, so that it appears in error
342 // diagnostics.
343 auto memBuffer =
344 isKnownNullTerminated
345 ? MemoryBuffer::getMemBuffer(inputStr,
346 /*BufferName=*/inputStr)
347 : MemoryBuffer::getMemBufferCopy(inputStr, /*BufferName=*/inputStr);
348 SourceMgr sourceMgr;
349 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
350 SymbolState aliasState;
351 ParserConfig config(context);
352 ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
353 /*codeCompleteContext=*/nullptr);
354 Parser parser(state);
355
356 Token startTok = parser.getToken();
357 T symbol = parserFn(parser);
358 if (!symbol)
359 return T();
360
361 // Provide the number of bytes that were read.
362 Token endTok = parser.getToken();
363 size_t numRead =
364 endTok.getLoc().getPointer() - startTok.getLoc().getPointer();
365 if (numReadOut) {
366 *numReadOut = numRead;
367 } else if (numRead != inputStr.size()) {
368 parser.emitError(endTok.getLoc()) << "found trailing characters: '"
369 << inputStr.drop_front(numRead) << "'";
370 return T();
371 }
372 return symbol;
373}
374
375Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
376 Type type, size_t *numRead,
377 bool isKnownNullTerminated) {
379 attrStr, context, numRead, isKnownNullTerminated,
380 [type](Parser &parser) { return parser.parseAttribute(type); });
381}
382Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead,
383 bool isKnownNullTerminated) {
384 return parseSymbol<Type>(typeStr, context, numRead, isKnownNullTerminated,
385 [](Parser &parser) { return parser.parseType(); });
386}
return success()
static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState, SymbolAliasMap &aliases, CreateFn &&createSymbol)
Parse an extended dialect symbol.
static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t *numReadOut, bool isKnownNullTerminated, ParserFn &&parserFn)
Parses a symbol, of type 'T', and returns it if parsing was successful.
This class represents state from a parsed MLIR textual format string.
void addTypeAliasUses(StringRef name, SMRange locations)
void addAttrAliasUses(StringRef name, SMRange locations)
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a configuration for the MLIR assembly parser.
Definition AsmState.h:469
This represents a token in the MLIR syntax.
Definition Token.h:20
SMRange getLocRange() const
Definition Token.cpp:30
SMLoc getLoc() const
Definition Token.cpp:24
bool is(Kind k) const
Definition Token.h:38
bool isCodeCompletion() const
Returns true if the current token represents a code completion.
Definition Token.h:62
StringRef getSpelling() const
Definition Token.h:34
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides the implementation of the generic parser methods within AsmParser.
This class implement support for parsing global entities like attributes and types.
Definition Parser.h:27
Type parseType()
Parse an arbitrary type.
Type parseExtendedType()
Parse an extended type.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error and return failure.
Definition Parser.cpp:192
ParserState & state
The Parser is subclassed and reinstantiated.
Definition Parser.h:370
Attribute parseAttribute(Type type={})
Parse an arbitrary attribute with an optional type.
StringRef getTokenSpelling() const
Definition Parser.h:104
void consumeToken()
Advance the current lexer onto the next token.
Definition Parser.h:119
ParseResult parseDialectSymbolBody(StringRef &body, bool &isCodeCompletion)
Parse the body of a dialect symbol, which starts and ends with <>'s, and may be recursive.
InFlightDiagnostic emitWrongTokenError(const Twine &message={})
Emit an error about a "wrong token".
Definition Parser.cpp:215
void resetToken(const char *tokPos)
Reset the parser to the given lexer position.
Definition Parser.h:140
const Token & getToken() const
Return the current token the parser is inspecting.
Definition Parser.h:103
Attribute parseExtendedAttr(Type type)
Parse an extended attribute.
MLIRContext * getContext() const
Definition Parser.h:38
Attribute codeCompleteDialectSymbol(const llvm::StringMap< Attribute > &aliases)
Definition Parser.cpp:555
AttrTypeReplacer.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
This class refers to all of the state maintained globally by the parser, such as the current lexer po...
Definition ParserState.h:51
This class contains record of any parsed top-level symbols.
Definition ParserState.h:28