MLIR 22.0.0git
Parser.cpp
Go to the documentation of this file.
1//===- Parser.cpp ---------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "Lexer.h"
25#include "llvm/ADT/StringExtras.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/FormatVariadic.h"
28#include "llvm/Support/SaveAndRestore.h"
29#include "llvm/Support/ScopedPrinter.h"
30#include "llvm/Support/VirtualFileSystem.h"
31#include "llvm/TableGen/Error.h"
32#include "llvm/TableGen/Parser.h"
33#include <optional>
34#include <string>
35
36using namespace mlir;
37using namespace mlir::pdll;
39//===----------------------------------------------------------------------===//
40// Parser
41//===----------------------------------------------------------------------===//
42
43namespace {
44class Parser {
45public:
46 Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
47 bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
48 : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
49 curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
50 typeTy(ast::TypeType::get(ctx)), valueTy(ast::ValueType::get(ctx)),
51 typeRangeTy(ast::TypeRangeType::get(ctx)),
52 valueRangeTy(ast::ValueRangeType::get(ctx)),
53 attrTy(ast::AttributeType::get(ctx)),
54 codeCompleteContext(codeCompleteContext) {}
55
56 /// Try to parse a new module. Returns failure if the module cannot be parsed.
57 FailureOr<ast::Module *> parseModule();
58
59private:
60 /// The current context of the parser. It allows for the parser to know a bit
61 /// about the construct it is nested within during parsing. This is used
62 /// specifically to provide additional verification during parsing, e.g. to
63 /// prevent using rewrites within a match context, matcher constraints within
64 /// a rewrite section, etc.
65 enum class ParserContext {
66 /// The parser is in the global context.
67 Global,
68 /// The parser is currently within a Constraint, which disallows all types
69 /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
70 Constraint,
71 /// The parser is currently within the matcher portion of a Pattern, which
72 /// allows a terminal operation rewrite statement but no other rewrite
73 /// transformations.
74 PatternMatch,
75 /// The parser is currently within a Rewrite, which disallows calls to
76 /// constraints, requires operation expressions to have names, etc.
77 Rewrite,
78 };
79
80 /// The current specification context of an operations result type. This
81 /// indicates how the result types of an operation may be inferred.
82 enum class OpResultTypeContext {
83 /// The result types of the operation are not known to be inferred.
84 Explicit,
85 /// The result types of the operation are inferred from the root input of a
86 /// `replace` statement.
87 Replacement,
88 /// The result types of the operation are inferred by using the
89 /// `InferTypeOpInterface` interface provided by the operation.
90 Interface,
91 };
93 //===--------------------------------------------------------------------===//
94 // Parsing
95 //===--------------------------------------------------------------------===//
96
97 /// Push a new decl scope onto the lexer.
98 ast::DeclScope *pushDeclScope() {
99 ast::DeclScope *newScope =
100 new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
101 return (curDeclScope = newScope);
102 }
103 void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
104
105 /// Pop the last decl scope from the lexer.
106 void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
107
108 /// Parse the body of an AST module.
109 LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
110
111 /// Try to convert the given expression to `type`. Returns failure and emits
112 /// an error if a conversion is not viable. On failure, `noteAttachFn` is
113 /// invoked to attach notes to the emitted error diagnostic. On success,
114 /// `expr` is updated to the expression used to convert to `type`.
115 LogicalResult convertExpressionTo(
116 ast::Expr *&expr, ast::Type type,
117 function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
118 LogicalResult
119 convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType,
120 ast::Type type,
122 LogicalResult convertTupleExpressionTo(
123 ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
125 function_ref<void(ast::Diagnostic &diag)> noteAttachFn);
126
127 /// Given an operation expression, convert it to a Value or ValueRange
128 /// typed expression.
129 ast::Expr *convertOpToValue(const ast::Expr *opExpr);
130
131 /// Lookup ODS information for the given operation, returns nullptr if no
132 /// information is found.
133 const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
134 return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
135 }
136
137 /// Process the given documentation string, or return an empty string if
138 /// documentation isn't enabled.
139 StringRef processDoc(StringRef doc) {
140 return enableDocumentation ? doc : StringRef();
141 }
142
143 /// Process the given documentation string and format it, or return an empty
144 /// string if documentation isn't enabled.
145 std::string processAndFormatDoc(const Twine &doc) {
146 if (!enableDocumentation)
147 return "";
148 std::string docStr;
149 {
150 llvm::raw_string_ostream docOS(docStr);
151 std::string tmpDocStr = doc.str();
152 raw_indented_ostream(docOS).printReindented(
153 StringRef(tmpDocStr).rtrim(" \t"));
154 }
155 return docStr;
156 }
157
158 //===--------------------------------------------------------------------===//
159 // Directives
160
161 LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
162 LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
163 LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
164 SmallVectorImpl<ast::Decl *> &decls);
165
166 /// Process the records of a parsed tablegen include file.
167 void processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
168 SmallVectorImpl<ast::Decl *> &decls);
169
170 /// Create a user defined native constraint for a constraint imported from
171 /// ODS.
172 template <typename ConstraintT>
173 ast::Decl *
174 createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
175 SMRange loc, ast::Type type,
176 StringRef nativeType, StringRef docString);
177 template <typename ConstraintT>
178 ast::Decl *
179 createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
180 SMRange loc, ast::Type type,
181 StringRef nativeType);
182
183 //===--------------------------------------------------------------------===//
184 // Decls
185
186 /// This structure contains the set of pattern metadata that may be parsed.
187 struct ParsedPatternMetadata {
188 std::optional<uint16_t> benefit;
189 bool hasBoundedRecursion = false;
190 };
191
192 FailureOr<ast::Decl *> parseTopLevelDecl();
193 FailureOr<ast::NamedAttributeDecl *>
194 parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
195
196 /// Parse an argument variable as part of the signature of a
197 /// UserConstraintDecl or UserRewriteDecl.
198 FailureOr<ast::VariableDecl *> parseArgumentDecl();
199
200 /// Parse a result variable as part of the signature of a UserConstraintDecl
201 /// or UserRewriteDecl.
202 FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
203
204 /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
205 /// defined in a non-global context.
206 FailureOr<ast::UserConstraintDecl *>
207 parseUserConstraintDecl(bool isInline = false);
208
209 /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
210 /// non-global context, such as within a Pattern/Constraint/etc.
211 FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
212
213 /// Parse a PDLL (i.e. non-native) UserConstraintDecl whose body is defined using
214 /// PDLL constructs.
215 FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
216 const ast::Name &name, bool isInline,
217 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
218 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
219
220 /// Parse a UserRewriteDecl. `isInline` signals if the rewrite is being
221 /// defined in a non-global context.
222 FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
223
224 /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
225 /// non-global context, such as within a Pattern/Rewrite/etc.
226 FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
227
228 /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
229 /// PDLL constructs.
230 FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
231 const ast::Name &name, bool isInline,
232 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
233 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
234
235 /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
236 /// effectively the same syntax, and only differ on slight semantics (given
237 /// the different parsing contexts).
238 template <typename T, typename ParseUserPDLLDeclFnT>
239 FailureOr<T *> parseUserConstraintOrRewriteDecl(
240 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
241 StringRef anonymousNamePrefix, bool isInline);
242
243 /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
244 /// These decls have effectively the same syntax.
245 template <typename T>
246 FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
247 const ast::Name &name, bool isInline,
248 ArrayRef<ast::VariableDecl *> arguments,
249 ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
250
251 /// Parse the functional signature (i.e. the arguments and results) of a
252 /// UserConstraintDecl or UserRewriteDecl.
253 LogicalResult parseUserConstraintOrRewriteSignature(
254 SmallVectorImpl<ast::VariableDecl *> &arguments,
255 SmallVectorImpl<ast::VariableDecl *> &results,
256 ast::DeclScope *&argumentScope, ast::Type &resultType);
257
258 /// Validate the return (which if present is specified by bodyIt) of a
259 /// UserConstraintDecl or UserRewriteDecl.
260 LogicalResult validateUserConstraintOrRewriteReturn(
261 StringRef declType, ast::CompoundStmt *body,
262 ArrayRef<ast::Stmt *>::iterator bodyIt,
263 ArrayRef<ast::Stmt *>::iterator bodyE,
264 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
265
266 FailureOr<ast::CompoundStmt *>
267 parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
268 bool expectTerminalSemicolon = true);
269 FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
270 FailureOr<ast::Decl *> parsePatternDecl();
271 LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
272
273 /// Check to see if a decl has already been defined with the given name, if
274 /// one has, emit an error and return failure. Returns success otherwise.
275 LogicalResult checkDefineNamedDecl(const ast::Name &name);
276
277 /// Try to define a variable decl with the given components, returns the
278 /// variable on success.
279 FailureOr<ast::VariableDecl *>
280 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
281 ast::Expr *initExpr,
282 ArrayRef<ast::ConstraintRef> constraints);
283 FailureOr<ast::VariableDecl *>
284 defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
285 ArrayRef<ast::ConstraintRef> constraints);
286
287 /// Parse the constraint reference list for a variable decl.
288 LogicalResult parseVariableDeclConstraintList(
289 SmallVectorImpl<ast::ConstraintRef> &constraints);
290
291 /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
292 FailureOr<ast::Expr *> parseTypeConstraintExpr();
293
294 /// Try to parse a single reference to a constraint. `typeConstraint` is the
295 /// location of a previously parsed type constraint for the entity that will
296 /// be constrained by the parsed constraint. `existingConstraints` are any
297 /// existing constraints that have already been parsed for the same entity
298 /// that will be constrained by this constraint. `allowInlineTypeConstraints`
299 /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
300 FailureOr<ast::ConstraintRef>
301 parseConstraint(std::optional<SMRange> &typeConstraint,
302 ArrayRef<ast::ConstraintRef> existingConstraints,
303 bool allowInlineTypeConstraints);
304
305 /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
306 /// argument or result variable. The constraints for these variables do not
307 /// allow inline type constraints, and only permit a single constraint.
308 FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
309
310 //===--------------------------------------------------------------------===//
311 // Exprs
312
313 FailureOr<ast::Expr *> parseExpr();
314
315 /// Identifier expressions.
316 FailureOr<ast::Expr *> parseAttributeExpr();
317 FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr,
318 bool isNegated = false);
319 FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
320 FailureOr<ast::Expr *> parseIdentifierExpr();
321 FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
322 FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
323 FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
324 FailureOr<ast::Expr *> parseNegatedExpr();
325 FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
326 FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
327 FailureOr<ast::Expr *>
328 parseOperationExpr(OpResultTypeContext inputResultTypeContext =
329 OpResultTypeContext::Explicit);
330 FailureOr<ast::Expr *> parseTupleExpr();
331 FailureOr<ast::Expr *> parseTypeExpr();
332 FailureOr<ast::Expr *> parseUnderscoreExpr();
333
334 //===--------------------------------------------------------------------===//
335 // Stmts
336
337 FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
338 FailureOr<ast::CompoundStmt *> parseCompoundStmt();
339 FailureOr<ast::EraseStmt *> parseEraseStmt();
340 FailureOr<ast::LetStmt *> parseLetStmt();
341 FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
342 FailureOr<ast::ReturnStmt *> parseReturnStmt();
343 FailureOr<ast::RewriteStmt *> parseRewriteStmt();
344
345 //===--------------------------------------------------------------------===//
346 // Creation+Analysis
347 //===--------------------------------------------------------------------===//
348
349 //===--------------------------------------------------------------------===//
350 // Decls
351
352 /// Try to extract a callable from the given AST node. Returns nullptr on
353 /// failure.
354 ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
355
356 /// Try to create a pattern decl with the given components, returning the
357 /// Pattern on success.
358 FailureOr<ast::PatternDecl *>
359 createPatternDecl(SMRange loc, const ast::Name *name,
360 const ParsedPatternMetadata &metadata,
361 ast::CompoundStmt *body);
362
363 /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
364 /// of results, defined as part of the signature.
365 ast::Type
366 createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
367
368 /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
369 template <typename T>
370 FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
371 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
372 ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
373 ast::CompoundStmt *body);
374
375 /// Try to create a variable decl with the given components, returning the
376 /// Variable on success.
377 FailureOr<ast::VariableDecl *>
378 createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
379 ArrayRef<ast::ConstraintRef> constraints);
380
381 /// Create a variable for an argument or result defined as part of the
382 /// signature of a UserConstraintDecl/UserRewriteDecl.
383 FailureOr<ast::VariableDecl *>
384 createArgOrResultVariableDecl(StringRef name, SMRange loc,
385 const ast::ConstraintRef &constraint);
386
387 /// Validate the constraints used to constrain a variable decl.
388 /// `inferredType` is the type of the variable inferred by the constraints
389 /// within the list, and is updated to the most refined type as determined by
390 /// the constraints. Returns success if the constraint list is valid, failure
391 /// otherwise.
392 LogicalResult
393 validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
394 ast::Type &inferredType);
395 /// Validate a single reference to a constraint. `inferredType` contains the
396 /// currently inferred variable type and is refined with the type defined
397 /// by the constraint. Returns success if the constraint is valid, failure
398 /// otherwise.
399 LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
400 ast::Type &inferredType);
401 LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
402 LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
403
404 //===--------------------------------------------------------------------===//
405 // Exprs
406
407 FailureOr<ast::CallExpr *>
408 createCallExpr(SMRange loc, ast::Expr *parentExpr,
409 MutableArrayRef<ast::Expr *> arguments,
410 bool isNegated = false);
411 FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
412 FailureOr<ast::DeclRefExpr *>
413 createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
414 ArrayRef<ast::ConstraintRef> constraints);
415 FailureOr<ast::MemberAccessExpr *>
416 createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
417
418 /// Validate the member access `name` into the given parent expression. On
419 /// success, this also returns the type of the member accessed.
420 FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
421 StringRef name, SMRange loc);
422 FailureOr<ast::OperationExpr *>
423 createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
424 OpResultTypeContext resultTypeContext,
425 SmallVectorImpl<ast::Expr *> &operands,
426 MutableArrayRef<ast::NamedAttributeDecl *> attributes,
427 SmallVectorImpl<ast::Expr *> &results);
428 LogicalResult
429 validateOperationOperands(SMRange loc, std::optional<StringRef> name,
430 const ods::Operation *odsOp,
431 SmallVectorImpl<ast::Expr *> &operands);
432 LogicalResult validateOperationResults(SMRange loc,
433 std::optional<StringRef> name,
434 const ods::Operation *odsOp,
435 SmallVectorImpl<ast::Expr *> &results);
436 void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
437 const ods::Operation *odsOp);
438 LogicalResult validateOperationOperandsOrResults(
439 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
440 std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
441 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
442 ast::RangeType rangeTy);
443 FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
444 ArrayRef<ast::Expr *> elements,
445 ArrayRef<StringRef> elementNames);
446
447 //===--------------------------------------------------------------------===//
448 // Stmts
449
450 FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
451 FailureOr<ast::ReplaceStmt *>
452 createReplaceStmt(SMRange loc, ast::Expr *rootOp,
453 MutableArrayRef<ast::Expr *> replValues);
454 FailureOr<ast::RewriteStmt *>
455 createRewriteStmt(SMRange loc, ast::Expr *rootOp,
456 ast::CompoundStmt *rewriteBody);
457
458 //===--------------------------------------------------------------------===//
459 // Code Completion
460 //===--------------------------------------------------------------------===//
461
462 /// The set of various code completion methods. Every completion method
463 /// returns `failure` to stop the parsing process after providing completion
464 /// results.
465
466 LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
467 LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
468 LogicalResult codeCompleteConstraintName(ast::Type inferredType,
469 bool allowInlineTypeConstraints);
470 LogicalResult codeCompleteDialectName();
471 LogicalResult codeCompleteOperationName(StringRef dialectName);
472 LogicalResult codeCompletePatternMetadata();
473 LogicalResult codeCompleteIncludeFilename(StringRef curPath);
474
475 void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
476 void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
477 unsigned currentNumOperands);
478 void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
479 unsigned currentNumResults);
480
481 //===--------------------------------------------------------------------===//
482 // Lexer Utilities
483 //===--------------------------------------------------------------------===//
484
485 /// If the current token has the specified kind, consume it and return true.
486 /// If not, return false.
487 bool consumeIf(Token::Kind kind) {
488 if (curToken.isNot(kind))
489 return false;
490 consumeToken(kind);
491 return true;
492 }
493
494 /// Advance the current lexer onto the next token.
495 void consumeToken() {
496 assert(curToken.isNot(Token::eof, Token::error) &&
497 "shouldn't advance past EOF or errors");
498 curToken = lexer.lexToken();
499 }
500
501 /// Advance the current lexer onto the next token, asserting what the expected
502 /// current token is. This is preferred to the above method because it leads
503 /// to more self-documenting code with better checking.
504 void consumeToken(Token::Kind kind) {
505 assert(curToken.is(kind) && "consumed an unexpected token");
506 consumeToken();
507 }
508
509 /// Reset the lexer to the location at the given position.
510 void resetToken(SMRange tokLoc) {
511 lexer.resetPointer(tokLoc.Start.getPointer());
512 curToken = lexer.lexToken();
513 }
514
515 /// Consume the specified token if present and return success. On failure,
516 /// output a diagnostic and return failure.
517 LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
518 if (curToken.getKind() != kind)
519 return emitError(curToken.getLoc(), msg);
520 consumeToken();
521 return success();
522 }
523 LogicalResult emitError(SMRange loc, const Twine &msg) {
524 lexer.emitError(loc, msg);
525 return failure();
526 }
527 LogicalResult emitError(const Twine &msg) {
528 return emitError(curToken.getLoc(), msg);
529 }
530 LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
531 const Twine &note) {
532 lexer.emitErrorAndNote(loc, msg, noteLoc, note);
533 return failure();
534 }
535
536 //===--------------------------------------------------------------------===//
537 // Fields
538 //===--------------------------------------------------------------------===//
539
540 /// The owning AST context.
541 ast::Context &ctx;
542
543 /// The lexer of this parser.
544 Lexer lexer;
545
546 /// The current token within the lexer.
547 Token curToken;
548
549 /// A flag indicating if the parser should add documentation to AST nodes when
550 /// viable.
551 bool enableDocumentation;
552
553 /// The most recently defined decl scope.
554 ast::DeclScope *curDeclScope = nullptr;
555 llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
556
557 /// The current context of the parser.
558 ParserContext parserContext = ParserContext::Global;
559
560 /// Cached types to simplify verification and expression creation.
561 ast::Type typeTy, valueTy;
562 ast::RangeType typeRangeTy, valueRangeTy;
563 ast::Type attrTy;
564
565 /// A counter used when naming anonymous constraints and rewrites.
566 unsigned anonymousDeclNameCounter = 0;
567
568 /// The optional code completion context.
569 CodeCompleteContext *codeCompleteContext;
570};
571} // namespace
572
573FailureOr<ast::Module *> Parser::parseModule() {
574 SMLoc moduleLoc = curToken.getStartLoc();
575 pushDeclScope();
576
577 // Parse the top-level decls of the module.
578 SmallVector<ast::Decl *> decls;
579 if (failed(parseModuleBody(decls)))
580 return popDeclScope(), failure();
581
582 popDeclScope();
583 return ast::Module::create(ctx, moduleLoc, decls);
584}
585
586LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
587 while (curToken.isNot(Token::eof)) {
588 if (curToken.is(Token::directive)) {
589 if (failed(parseDirective(decls)))
590 return failure();
591 continue;
592 }
593
594 FailureOr<ast::Decl *> decl = parseTopLevelDecl();
595 if (failed(decl))
596 return failure();
597 decls.push_back(*decl);
598 }
599 return success();
600}
601
602ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
603 return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
604 valueRangeTy);
605}
606
607LogicalResult Parser::convertExpressionTo(
608 ast::Expr *&expr, ast::Type type,
609 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
610 ast::Type exprType = expr->getType();
611 if (exprType == type)
612 return success();
613
614 auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
615 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
616 expr->getLoc(), llvm::formatv("unable to convert expression of type "
617 "`{0}` to the expected type of "
618 "`{1}`",
619 exprType, type));
620 if (noteAttachFn)
621 noteAttachFn(*diag);
622 return diag;
623 };
624
625 if (auto exprOpType = dyn_cast<ast::OperationType>(exprType))
626 return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);
627
628 // FIXME: Decide how to allow/support converting a single result to multiple,
629 // and multiple to a single result. For now, we just allow Single->Range,
630 // but this isn't something really supported in the PDL dialect. We should
631 // figure out some way to support both.
632 if ((exprType == valueTy || exprType == valueRangeTy) &&
633 (type == valueTy || type == valueRangeTy))
634 return success();
635 if ((exprType == typeTy || exprType == typeRangeTy) &&
636 (type == typeTy || type == typeRangeTy))
637 return success();
638
639 // Handle tuple types.
640 if (auto exprTupleType = dyn_cast<ast::TupleType>(exprType))
641 return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
642 noteAttachFn);
643
644 return emitConvertError();
645}
646
647LogicalResult Parser::convertOpExpressionTo(
648 ast::Expr *&expr, ast::OperationType exprType, ast::Type type,
649 function_ref<ast::InFlightDiagnostic()> emitErrorFn) {
650 // Two operation types are compatible if they have the same name, or if the
651 // expected type is more general.
652 if (auto opType = dyn_cast<ast::OperationType>(type)) {
653 if (opType.getName())
654 return emitErrorFn();
655 return success();
656 }
657
658 // An operation can always convert to a ValueRange.
659 if (type == valueRangeTy) {
660 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
661 valueRangeTy);
662 return success();
663 }
664
665 // Allow conversion to a single value by constraining the result range.
666 if (type == valueTy) {
667 // If the operation is registered, we can verify if it can ever have a
668 // single result.
669 if (const ods::Operation *odsOp = exprType.getODSOperation()) {
670 if (odsOp->getResults().empty()) {
671 return emitErrorFn()->attachNote(
672 llvm::formatv("see the definition of `{0}`, which was defined "
673 "with zero results",
674 odsOp->getName()),
675 odsOp->getLoc());
676 }
677
678 unsigned numSingleResults = llvm::count_if(
679 odsOp->getResults(), [](const ods::OperandOrResult &result) {
680 return result.getVariableLengthKind() ==
681 ods::VariableLengthKind::Single;
682 });
683 if (numSingleResults > 1) {
684 return emitErrorFn()->attachNote(
685 llvm::formatv("see the definition of `{0}`, which was defined "
686 "with at least {1} results",
687 odsOp->getName(), numSingleResults),
688 odsOp->getLoc());
689 }
690 }
691
692 expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
693 valueTy);
694 return success();
695 }
696 return emitErrorFn();
697}
698
699LogicalResult Parser::convertTupleExpressionTo(
700 ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
701 function_ref<ast::InFlightDiagnostic()> emitErrorFn,
702 function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
703 // Handle conversions between tuples.
704 if (auto tupleType = dyn_cast<ast::TupleType>(type)) {
705 if (tupleType.size() != exprType.size())
706 return emitErrorFn();
707
708 // Build a new tuple expression using each of the elements of the current
709 // tuple.
710 SmallVector<ast::Expr *> newExprs;
711 for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
712 newExprs.push_back(ast::MemberAccessExpr::create(
713 ctx, expr->getLoc(), expr, llvm::to_string(i),
714 exprType.getElementTypes()[i]));
715
716 auto diagFn = [&](ast::Diagnostic &diag) {
717 diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
718 i, exprType));
719 if (noteAttachFn)
720 noteAttachFn(diag);
721 };
722 if (failed(convertExpressionTo(newExprs.back(),
723 tupleType.getElementTypes()[i], diagFn)))
724 return failure();
725 }
726 expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
727 tupleType.getElementNames());
728 return success();
729 }
730
731 // Handle conversion to a range.
732 auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes,
733 ast::RangeType resultTy) -> LogicalResult {
734 // TODO: We currently only allow range conversion within a rewrite context.
735 if (parserContext != ParserContext::Rewrite) {
736 return emitErrorFn()->attachNote("Tuple to Range conversion is currently "
737 "only allowed within a rewrite context");
738 }
739
740 // All of the tuple elements must be allowed types.
741 for (ast::Type elementType : exprType.getElementTypes())
742 if (!llvm::is_contained(allowedElementTypes, elementType))
743 return emitErrorFn();
744
745 // Build a new tuple expression using each of the elements of the current
746 // tuple.
747 SmallVector<ast::Expr *> newExprs;
748 for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
749 newExprs.push_back(ast::MemberAccessExpr::create(
750 ctx, expr->getLoc(), expr, llvm::to_string(i),
751 exprType.getElementTypes()[i]));
752 }
753 expr = ast::RangeExpr::create(ctx, expr->getLoc(), newExprs, resultTy);
754 return success();
755 };
756 if (type == valueRangeTy)
757 return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
758 if (type == typeRangeTy)
759 return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
760
761 return emitErrorFn();
762}
763
764//===----------------------------------------------------------------------===//
765// Directives
766//===----------------------------------------------------------------------===//
767
768LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
769 StringRef directive = curToken.getSpelling();
770 if (directive == "#include")
771 return parseInclude(decls);
772
773 return emitError("unknown directive `" + directive + "`");
774}
775
776LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
777 SMRange loc = curToken.getLoc();
778 consumeToken(Token::directive);
779
780 // Handle code completion of the include file path.
781 if (curToken.is(Token::code_complete_string))
782 return codeCompleteIncludeFilename(curToken.getStringValue());
783
784 // Parse the file being included.
785 if (!curToken.isString())
786 return emitError(loc,
787 "expected string file name after `include` directive");
788 SMRange fileLoc = curToken.getLoc();
789 std::string filenameStr = curToken.getStringValue();
790 StringRef filename = filenameStr;
791 consumeToken();
792
793 // Check the type of include. If ending with `.pdll`, this is another pdl file
794 // to be parsed along with the current module.
795 if (filename.ends_with(".pdll")) {
796 if (failed(lexer.pushInclude(filename, fileLoc)))
797 return emitError(fileLoc,
798 "unable to open include file `" + filename + "`");
799
800 // If we added the include successfully, parse it into the current module.
801 // Make sure to update to the next token after we finish parsing the nested
802 // file.
803 curToken = lexer.lexToken();
804 LogicalResult result = parseModuleBody(decls);
805 curToken = lexer.lexToken();
806 return result;
807 }
808
809 // Otherwise, this must be a `.td` include.
810 if (filename.ends_with(".td"))
811 return parseTdInclude(filename, fileLoc, decls);
812
813 return emitError(fileLoc,
814 "expected include filename to end with `.pdll` or `.td`");
815}
816
817LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
818 SmallVectorImpl<ast::Decl *> &decls) {
819 llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
820
821 // Use the source manager to open the file, but don't yet add it.
822 std::string includedFile;
823 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
824 parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
825 if (!includeBuffer)
826 return emitError(fileLoc, "unable to open include file `" + filename + "`");
827
828 // Setup the source manager for parsing the tablegen file.
829 llvm::SourceMgr tdSrcMgr;
830 tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
831 tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
832 tdSrcMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
833
834 // This class provides a context argument for the llvm::SourceMgr diagnostic
835 // handler.
836 struct DiagHandlerContext {
837 Parser &parser;
838 StringRef filename;
839 llvm::SMRange loc;
840 } handlerContext{*this, filename, fileLoc};
841
842 // Set the diagnostic handler for the tablegen source manager.
843 tdSrcMgr.setDiagHandler(
844 [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
845 auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
846 (void)ctx->parser.emitError(
847 ctx->loc,
848 llvm::formatv("error while processing include file `{0}`: {1}",
849 ctx->filename, diag.getMessage()));
850 },
851 &handlerContext);
852
853 // Parse the tablegen file.
854 llvm::RecordKeeper tdRecords;
855 if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
856 return failure();
857
858 // Process the parsed records.
859 processTdIncludeRecords(tdRecords, decls);
860
861 // After we are done processing, move all of the tablegen source buffers to
862 // the main parser source mgr. This allows for directly using source locations
863 // from the .td files without needing to remap them.
864 parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
865 return success();
866}
867
868void Parser::processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
869 SmallVectorImpl<ast::Decl *> &decls) {
870 // Return the length kind of the given value.
871 auto getLengthKind = [](const auto &value) {
872 if (value.isOptional())
873 return ods::VariableLengthKind::Optional;
874 return value.isVariadic() ? ods::VariableLengthKind::Variadic
875 : ods::VariableLengthKind::Single;
876 };
877
878 // Insert a type constraint into the ODS context.
879 ods::Context &odsContext = ctx.getODSContext();
880 auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
881 -> const ods::TypeConstraint & {
882 return odsContext.insertTypeConstraint(
883 cst.constraint.getUniqueDefName(),
884 processDoc(cst.constraint.getSummary()), cst.constraint.getCppType());
885 };
886 auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
887 return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
888 };
889
890 // Process the parsed tablegen records to build ODS information.
891 /// Operations.
892 for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
893 tblgen::Operator op(def);
894
895 // Check to see if this operation is known to support type inferrence.
896 bool supportsResultTypeInferrence =
897 op.getTrait("::mlir::InferTypeOpInterface::Trait");
898
899 auto [odsOp, inserted] = odsContext.insertOperation(
900 op.getOperationName(), processDoc(op.getSummary()),
901 processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
902 supportsResultTypeInferrence, op.getLoc().front());
903
904 // Ignore operations that have already been added.
905 if (!inserted)
906 continue;
907
908 for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
909 odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
910 odsContext.insertAttributeConstraint(
911 attr.attr.getUniqueDefName(),
912 processDoc(attr.attr.getSummary()),
913 attr.attr.getStorageType()));
914 }
915 for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
916 odsOp->appendOperand(operand.name, getLengthKind(operand),
917 addTypeConstraint(operand));
918 }
919 for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
920 odsOp->appendResult(result.name, getLengthKind(result),
921 addTypeConstraint(result));
922 }
923 }
924
925 auto shouldBeSkipped = [this](const llvm::Record *def) {
926 return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
927 def->isSubClassOf("DeclareInterfaceMethods");
928 };
929
930 /// Attr constraints.
931 for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
932 if (shouldBeSkipped(def))
933 continue;
934
935 tblgen::Attribute constraint(def);
936 decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
937 constraint, convertLocToRange(def->getLoc().front()), attrTy,
938 constraint.getStorageType()));
939 }
940 /// Type constraints.
941 for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
942 if (shouldBeSkipped(def))
943 continue;
944
945 tblgen::TypeConstraint constraint(def);
946 decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
947 constraint, convertLocToRange(def->getLoc().front()), typeTy,
948 constraint.getCppType()));
949 }
950 /// OpInterfaces.
951 ast::Type opTy = ast::OperationType::get(ctx);
952 for (const llvm::Record *def :
953 tdRecords.getAllDerivedDefinitions("OpInterface")) {
954 if (shouldBeSkipped(def))
955 continue;
956
957 SMRange loc = convertLocToRange(def->getLoc().front());
958
959 std::string cppClassName =
960 llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
961 def->getValueAsString("cppInterfaceName"))
962 .str();
963 std::string codeBlock =
964 llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
965 cppClassName)
966 .str();
967
968 std::string desc =
969 processAndFormatDoc(def->getValueAsString("description"));
970 decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
971 def->getName(), codeBlock, loc, opTy, cppClassName, desc));
972 }
973}
974
975template <typename ConstraintT>
976ast::Decl *Parser::createODSNativePDLLConstraintDecl(
977 StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
978 StringRef nativeType, StringRef docString) {
979 // Build the single input parameter.
980 ast::DeclScope *argScope = pushDeclScope();
981 auto *paramVar = ast::VariableDecl::create(
982 ctx, ast::Name::create(ctx, "self", loc), type,
983 /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
984 argScope->add(paramVar);
985 popDeclScope();
986
987 // Build the native constraint.
988 auto *constraintDecl = ast::UserConstraintDecl::createNative(
989 ctx, ast::Name::create(ctx, name, loc), paramVar,
990 /*results=*/{}, codeBlock, ast::TupleType::get(ctx), nativeType);
991 constraintDecl->setDocComment(ctx, docString);
992 curDeclScope->add(constraintDecl);
993 return constraintDecl;
994}
995
996template <typename ConstraintT>
997ast::Decl *
998Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
999 SMRange loc, ast::Type type,
1000 StringRef nativeType) {
1001 // Format the condition template.
1002 tblgen::FmtContext fmtContext;
1003 fmtContext.withSelf("self");
1004 std::string codeBlock = tblgen::tgfmt(
1005 "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
1006 &fmtContext);
1007
1008 // If documentation was enabled, build the doc string for the generated
1009 // constraint. It would be nice to do this lazily, but TableGen information is
1010 // destroyed after we finish parsing the file.
1011 std::string docString;
1012 if (enableDocumentation) {
1013 StringRef desc = constraint.getDescription();
1014 docString = processAndFormatDoc(
1015 constraint.getSummary() +
1016 (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
1017 }
1018
1019 return createODSNativePDLLConstraintDecl<ConstraintT>(
1020 constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
1021 docString);
1022}
1023
1024//===----------------------------------------------------------------------===//
1025// Decls
1026//===----------------------------------------------------------------------===//
1027
1028FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
1029 FailureOr<ast::Decl *> decl;
1030 switch (curToken.getKind()) {
1032 decl = parseUserConstraintDecl();
1033 break;
1034 case Token::kw_Pattern:
1035 decl = parsePatternDecl();
1036 break;
1037 case Token::kw_Rewrite:
1038 decl = parseUserRewriteDecl();
1039 break;
1040 default:
1041 return emitError("expected top-level declaration, such as a `Pattern`");
1042 }
1043 if (failed(decl))
1044 return failure();
1045
1046 // If the decl has a name, add it to the current scope.
1047 if (const ast::Name *name = (*decl)->getName()) {
1048 if (failed(checkDefineNamedDecl(*name)))
1049 return failure();
1050 curDeclScope->add(*decl);
1051 }
1052 return decl;
1053}
1054
1055FailureOr<ast::NamedAttributeDecl *>
1056Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1057 // Check for name code completion.
1058 if (curToken.is(Token::code_complete))
1059 return codeCompleteAttributeName(parentOpName);
1060
1061 std::string attrNameStr;
1062 if (curToken.isString())
1063 attrNameStr = curToken.getStringValue();
1064 else if (curToken.is(Token::identifier) || curToken.isKeyword())
1065 attrNameStr = curToken.getSpelling().str();
1066 else
1067 return emitError("expected identifier or string attribute name");
1068 const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
1069 consumeToken();
1070
1071 // Check for a value of the attribute.
1072 ast::Expr *attrValue = nullptr;
1073 if (consumeIf(Token::equal)) {
1074 FailureOr<ast::Expr *> attrExpr = parseExpr();
1075 if (failed(attrExpr))
1076 return failure();
1077 attrValue = *attrExpr;
1078 } else {
1079 // If there isn't a concrete value, create an expression representing a
1080 // UnitAttr.
1081 attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
1082 }
1083
1084 return ast::NamedAttributeDecl::create(ctx, name, attrValue);
1085}
1086
1087FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1088 function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1089 bool expectTerminalSemicolon) {
1090 consumeToken(Token::equal_arrow);
1091
1092 // Parse the single statement of the lambda body.
1093 SMLoc bodyStartLoc = curToken.getStartLoc();
1094 pushDeclScope();
1095 FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1096 bool failedToParse =
1097 failed(singleStatement) || failed(processStatementFn(*singleStatement));
1098 popDeclScope();
1099 if (failedToParse)
1100 return failure();
1101
1102 SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1103 return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
1104}
1105
1106FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1107 // Ensure that the argument is named.
1108 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
1109 return emitError("expected identifier argument name");
1110
1111 // Parse the argument similarly to a normal variable.
1112 StringRef name = curToken.getSpelling();
1113 SMRange nameLoc = curToken.getLoc();
1114 consumeToken();
1115
1116 if (failed(
1117 parseToken(Token::colon, "expected `:` before argument constraint")))
1118 return failure();
1119
1120 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1121 if (failed(cst))
1122 return failure();
1123
1124 return createArgOrResultVariableDecl(name, nameLoc, *cst);
1125}
1126
1127FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1128 // Check to see if this result is named.
1129 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1130 // Check to see if this name actually refers to a Constraint.
1131 if (!curDeclScope->lookup<ast::ConstraintDecl>(curToken.getSpelling())) {
1132 // If it wasn't a constraint, parse the result similarly to a variable. If
1133 // there is already an existing decl, we will emit an error when defining
1134 // this variable later.
1135 StringRef name = curToken.getSpelling();
1136 SMRange nameLoc = curToken.getLoc();
1137 consumeToken();
1138
1139 if (failed(parseToken(Token::colon,
1140 "expected `:` before result constraint")))
1141 return failure();
1142
1143 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1144 if (failed(cst))
1145 return failure();
1146
1147 return createArgOrResultVariableDecl(name, nameLoc, *cst);
1148 }
1149 }
1150
1151 // If it isn't named, we parse the constraint directly and create an unnamed
1152 // result variable.
1153 FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1154 if (failed(cst))
1155 return failure();
1156
1157 return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1158}
1159
1160FailureOr<ast::UserConstraintDecl *>
1161Parser::parseUserConstraintDecl(bool isInline) {
1162 // Constraints and rewrites have very similar formats, dispatch to a shared
1163 // interface for parsing.
1164 return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1165 [&](auto &&...args) {
1166 return this->parseUserPDLLConstraintDecl(args...);
1167 },
1168 ParserContext::Constraint, "constraint", isInline);
1169}
1170
1171FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1172 FailureOr<ast::UserConstraintDecl *> decl =
1173 parseUserConstraintDecl(/*isInline=*/true);
1174 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1175 return failure();
1176
1177 curDeclScope->add(*decl);
1178 return decl;
1179}
1180
1181FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1182 const ast::Name &name, bool isInline,
1183 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1184 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1185 // Push the argument scope back onto the list, so that the body can
1186 // reference arguments.
1187 pushDeclScope(argumentScope);
1188
1189 // Parse the body of the constraint. The body is either defined as a compound
1190 // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1191 ast::CompoundStmt *body;
1192 if (curToken.is(Token::equal_arrow)) {
1193 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1194 [&](ast::Stmt *&stmt) -> LogicalResult {
1195 ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1196 if (!stmtExpr) {
1197 return emitError(stmt->getLoc(),
1198 "expected `Constraint` lambda body to contain a "
1199 "single expression");
1200 }
1201 stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1202 return success();
1203 },
1204 /*expectTerminalSemicolon=*/!isInline);
1205 if (failed(bodyResult))
1206 return failure();
1207 body = *bodyResult;
1208 } else {
1209 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1210 if (failed(bodyResult))
1211 return failure();
1212 body = *bodyResult;
1213
1214 // Verify the structure of the body.
1215 auto bodyIt = body->begin(), bodyE = body->end();
1216 for (; bodyIt != bodyE; ++bodyIt)
1217 if (isa<ast::ReturnStmt>(*bodyIt))
1218 break;
1219 if (failed(validateUserConstraintOrRewriteReturn(
1220 "Constraint", body, bodyIt, bodyE, results, resultType)))
1221 return failure();
1222 }
1223 popDeclScope();
1224
1225 return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1226 name, arguments, results, resultType, body);
1227}
1228
1229FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1230 // Constraints and rewrites have very similar formats, dispatch to a shared
1231 // interface for parsing.
1232 return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1233 [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1234 ParserContext::Rewrite, "rewrite", isInline);
1235}
1236
1237FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1238 FailureOr<ast::UserRewriteDecl *> decl =
1239 parseUserRewriteDecl(/*isInline=*/true);
1240 if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1241 return failure();
1242
1243 curDeclScope->add(*decl);
1244 return decl;
1245}
1246
1247FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1248 const ast::Name &name, bool isInline,
1249 ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1250 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1251 // Push the argument scope back onto the list, so that the body can
1252 // reference arguments.
1253 curDeclScope = argumentScope;
1254 ast::CompoundStmt *body;
1255 if (curToken.is(Token::equal_arrow)) {
1256 FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1257 [&](ast::Stmt *&statement) -> LogicalResult {
1258 if (isa<ast::OpRewriteStmt>(statement))
1259 return success();
1260
1261 ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1262 if (!statementExpr) {
1263 return emitError(
1264 statement->getLoc(),
1265 "expected `Rewrite` lambda body to contain a single expression "
1266 "or an operation rewrite statement; such as `erase`, "
1267 "`replace`, or `rewrite`");
1268 }
1269 statement =
1270 ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1271 return success();
1272 },
1273 /*expectTerminalSemicolon=*/!isInline);
1274 if (failed(bodyResult))
1275 return failure();
1276 body = *bodyResult;
1277 } else {
1278 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1279 if (failed(bodyResult))
1280 return failure();
1281 body = *bodyResult;
1282 }
1283 popDeclScope();
1284
1285 // Verify the structure of the body.
1286 auto bodyIt = body->begin(), bodyE = body->end();
1287 for (; bodyIt != bodyE; ++bodyIt)
1288 if (isa<ast::ReturnStmt>(*bodyIt))
1289 break;
1290 if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1291 bodyE, results, resultType)))
1292 return failure();
1293 return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1294 name, arguments, results, resultType, body);
1295}
1296
1297template <typename T, typename ParseUserPDLLDeclFnT>
1298FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1299 ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1300 StringRef anonymousNamePrefix, bool isInline) {
1301 SMRange loc = curToken.getLoc();
1302 consumeToken();
1303 llvm::SaveAndRestore saveCtx(parserContext, declContext);
1304
1305 // Parse the name of the decl.
1306 const ast::Name *name = nullptr;
1307 if (curToken.isNot(Token::identifier)) {
1308 // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1309 // in C++, so being unnamed is fine.
1310 if (!isInline)
1311 return emitError("expected identifier name");
1312
1313 // Create a unique anonymous name to use, as the name for this decl is not
1314 // important.
1315 std::string anonName =
1316 llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1317 anonymousDeclNameCounter++)
1318 .str();
1319 name = &ast::Name::create(ctx, anonName, loc);
1320 } else {
1321 // If a name was provided, we can use it directly.
1322 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1323 consumeToken(Token::identifier);
1324 }
1325
1326 // Parse the functional signature of the decl.
1327 SmallVector<ast::VariableDecl *> arguments, results;
1328 ast::DeclScope *argumentScope;
1329 ast::Type resultType;
1330 if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1331 argumentScope, resultType)))
1332 return failure();
1333
1334 // Check to see which type of constraint this is. If the constraint contains a
1335 // compound body, this is a PDLL decl.
1337 return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1338 resultType);
1339
1340 // Otherwise, this is a native decl.
1341 return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1342 results, resultType);
1343}
1344
1345template <typename T>
1346FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1347 const ast::Name &name, bool isInline,
1348 ArrayRef<ast::VariableDecl *> arguments,
1349 ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1350 // If followed by a string, the native code body has also been specified.
1351 std::string codeStrStorage;
1352 std::optional<StringRef> optCodeStr;
1353 if (curToken.isString()) {
1354 codeStrStorage = curToken.getStringValue();
1355 optCodeStr = codeStrStorage;
1356 consumeToken();
1357 } else if (isInline) {
1358 return emitError(name.getLoc(),
1359 "external declarations must be declared in global scope");
1360 } else if (curToken.is(Token::error)) {
1361 return failure();
1362 }
1363 if (failed(parseToken(Token::semicolon,
1364 "expected `;` after native declaration")))
1365 return failure();
1366 return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1367}
1368
1369LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1370 SmallVectorImpl<ast::VariableDecl *> &arguments,
1371 SmallVectorImpl<ast::VariableDecl *> &results,
1372 ast::DeclScope *&argumentScope, ast::Type &resultType) {
1373 // Parse the argument list of the decl.
1374 if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
1375 return failure();
1376
1377 argumentScope = pushDeclScope();
1378 if (curToken.isNot(Token::r_paren)) {
1379 do {
1380 FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1381 if (failed(argument))
1382 return failure();
1383 arguments.emplace_back(*argument);
1384 } while (consumeIf(Token::comma));
1385 }
1386 popDeclScope();
1387 if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
1388 return failure();
1389
1390 // Parse the results of the decl.
1391 pushDeclScope();
1392 if (consumeIf(Token::arrow)) {
1393 auto parseResultFn = [&]() -> LogicalResult {
1394 FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1395 if (failed(result))
1396 return failure();
1397 results.emplace_back(*result);
1398 return success();
1399 };
1400
1401 // Check for a list of results.
1402 if (consumeIf(Token::l_paren)) {
1403 do {
1404 if (failed(parseResultFn()))
1405 return failure();
1406 } while (consumeIf(Token::comma));
1407 if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1408 return failure();
1409
1410 // Otherwise, there is only one result.
1411 } else if (failed(parseResultFn())) {
1412 return failure();
1413 }
1414 }
1415 popDeclScope();
1416
1417 // Compute the result type of the decl.
1418 resultType = createUserConstraintRewriteResultType(results);
1419
1420 // Verify that results are only named if there are more than one.
1421 if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1422 return emitError(
1423 results.front()->getLoc(),
1424 "cannot create a single-element tuple with an element label");
1425 }
1426 return success();
1427}
1428
1429LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1430 StringRef declType, ast::CompoundStmt *body,
1431 ArrayRef<ast::Stmt *>::iterator bodyIt,
1432 ArrayRef<ast::Stmt *>::iterator bodyE,
1433 ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1434 // Handle if a `return` was provided.
1435 if (bodyIt != bodyE) {
1436 // Emit an error if we have trailing statements after the return.
1437 if (std::next(bodyIt) != bodyE) {
1438 return emitError(
1439 (*std::next(bodyIt))->getLoc(),
1440 llvm::formatv("`return` terminated the `{0}` body, but found "
1441 "trailing statements afterwards",
1442 declType));
1443 }
1444
1445 // Otherwise if a return wasn't provided, check that no results are
1446 // expected.
1447 } else if (!results.empty()) {
1448 return emitError(
1449 {body->getLoc().End, body->getLoc().End},
1450 llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1451 declType, resultType));
1452 }
1453 return success();
1454}
1455
1456FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1457 return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1458 if (isa<ast::OpRewriteStmt>(statement))
1459 return success();
1460 return emitError(
1461 statement->getLoc(),
1462 "expected Pattern lambda body to contain a single operation "
1463 "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1464 });
1465}
1466
1467FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1468 SMRange loc = curToken.getLoc();
1469 consumeToken(Token::kw_Pattern);
1470 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1471
1472 // Check for an optional identifier for the pattern name.
1473 const ast::Name *name = nullptr;
1474 if (curToken.is(Token::identifier)) {
1475 name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1476 consumeToken(Token::identifier);
1477 }
1478
1479 // Parse any pattern metadata.
1480 ParsedPatternMetadata metadata;
1481 if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1482 return failure();
1483
1484 // Parse the pattern body.
1485 ast::CompoundStmt *body;
1486
1487 // Handle a lambda body.
1488 if (curToken.is(Token::equal_arrow)) {
1489 FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1490 if (failed(bodyResult))
1491 return failure();
1492 body = *bodyResult;
1493 } else {
1494 if (curToken.isNot(Token::l_brace))
1495 return emitError("expected `{` or `=>` to start pattern body");
1496 FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1497 if (failed(bodyResult))
1498 return failure();
1499 body = *bodyResult;
1500
1501 // Verify the body of the pattern.
1502 auto bodyIt = body->begin(), bodyE = body->end();
1503 for (; bodyIt != bodyE; ++bodyIt) {
1504 if (isa<ast::ReturnStmt>(*bodyIt)) {
1505 return emitError((*bodyIt)->getLoc(),
1506 "`return` statements are only permitted within a "
1507 "`Constraint` or `Rewrite` body");
1508 }
1509 // Break when we've found the rewrite statement.
1510 if (isa<ast::OpRewriteStmt>(*bodyIt))
1511 break;
1512 }
1513 if (bodyIt == bodyE) {
1514 return emitError(loc,
1515 "expected Pattern body to terminate with an operation "
1516 "rewrite statement, such as `erase`");
1517 }
1518 if (std::next(bodyIt) != bodyE) {
1519 return emitError((*std::next(bodyIt))->getLoc(),
1520 "Pattern body was terminated by an operation "
1521 "rewrite statement, but found trailing statements");
1522 }
1523 }
1524
1525 return createPatternDecl(loc, name, metadata, body);
1526}
1527
1528LogicalResult
1529Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1530 std::optional<SMRange> benefitLoc;
1531 std::optional<SMRange> hasBoundedRecursionLoc;
1532
1533 do {
1534 // Handle metadata code completion.
1535 if (curToken.is(Token::code_complete))
1536 return codeCompletePatternMetadata();
1537
1538 if (curToken.isNot(Token::identifier))
1539 return emitError("expected pattern metadata identifier");
1540 StringRef metadataStr = curToken.getSpelling();
1541 SMRange metadataLoc = curToken.getLoc();
1542 consumeToken(Token::identifier);
1543
1544 // Parse the benefit metadata: benefit(<integer-value>)
1545 if (metadataStr == "benefit") {
1546 if (benefitLoc) {
1547 return emitErrorAndNote(metadataLoc,
1548 "pattern benefit has already been specified",
1549 *benefitLoc, "see previous definition here");
1550 }
1551 if (failed(parseToken(Token::l_paren,
1552 "expected `(` before pattern benefit")))
1553 return failure();
1554
1555 uint16_t benefitValue = 0;
1556 if (curToken.isNot(Token::integer))
1557 return emitError("expected integral pattern benefit");
1558 if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1559 return emitError(
1560 "expected pattern benefit to fit within a 16-bit integer");
1561 consumeToken(Token::integer);
1562
1563 metadata.benefit = benefitValue;
1564 benefitLoc = metadataLoc;
1565
1566 if (failed(
1567 parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1568 return failure();
1569 continue;
1570 }
1571
1572 // Parse the bounded recursion metadata: recursion
1573 if (metadataStr == "recursion") {
1574 if (hasBoundedRecursionLoc) {
1575 return emitErrorAndNote(
1576 metadataLoc,
1577 "pattern recursion metadata has already been specified",
1578 *hasBoundedRecursionLoc, "see previous definition here");
1579 }
1580 metadata.hasBoundedRecursion = true;
1581 hasBoundedRecursionLoc = metadataLoc;
1582 continue;
1583 }
1584
1585 return emitError(metadataLoc, "unknown pattern metadata");
1586 } while (consumeIf(Token::comma));
1587
1588 return success();
1589}
1590
1591FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1592 consumeToken(Token::less);
1593
1594 FailureOr<ast::Expr *> typeExpr = parseExpr();
1595 if (failed(typeExpr) ||
1596 failed(parseToken(Token::greater,
1597 "expected `>` after variable type constraint")))
1598 return failure();
1599 return typeExpr;
1600}
1601
1602LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1603 assert(curDeclScope && "defining decl outside of a decl scope");
1604 if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1605 return emitErrorAndNote(
1606 name.getLoc(), "`" + name.getName() + "` has already been defined",
1607 lastDecl->getName()->getLoc(), "see previous definition here");
1608 }
1609 return success();
1610}
1611
1612FailureOr<ast::VariableDecl *>
1613Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1614 ast::Expr *initExpr,
1615 ArrayRef<ast::ConstraintRef> constraints) {
1616 assert(curDeclScope && "defining variable outside of decl scope");
1617 const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1618
1619 // If the name of the variable indicates a special variable, we don't add it
1620 // to the scope. This variable is local to the definition point.
1621 if (name.empty() || name == "_") {
1622 return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1623 constraints);
1624 }
1625 if (failed(checkDefineNamedDecl(nameDecl)))
1626 return failure();
1627
1628 auto *varDecl =
1629 ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1630 curDeclScope->add(varDecl);
1631 return varDecl;
1632}
1633
1634FailureOr<ast::VariableDecl *>
1635Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1636 ArrayRef<ast::ConstraintRef> constraints) {
1637 return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1638 constraints);
1639}
1640
1641LogicalResult Parser::parseVariableDeclConstraintList(
1642 SmallVectorImpl<ast::ConstraintRef> &constraints) {
1643 std::optional<SMRange> typeConstraint;
1644 auto parseSingleConstraint = [&] {
1645 FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1646 typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
1647 if (failed(constraint))
1648 return failure();
1649 constraints.push_back(*constraint);
1650 return success();
1651 };
1652
1653 // Check to see if this is a single constraint, or a list.
1654 if (!consumeIf(Token::l_square))
1655 return parseSingleConstraint();
1656
1657 do {
1658 if (failed(parseSingleConstraint()))
1659 return failure();
1660 } while (consumeIf(Token::comma));
1661 return parseToken(Token::r_square, "expected `]` after constraint list");
1662}
1663
1664FailureOr<ast::ConstraintRef>
1665Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
1666 ArrayRef<ast::ConstraintRef> existingConstraints,
1667 bool allowInlineTypeConstraints) {
1668 auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1669 if (!allowInlineTypeConstraints) {
1670 return emitError(
1671 curToken.getLoc(),
1672 "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1673 "permitted on arguments or results");
1674 }
1675 if (typeConstraint)
1676 return emitErrorAndNote(
1677 curToken.getLoc(),
1678 "the type of this variable has already been constrained",
1679 *typeConstraint, "see previous constraint location here");
1680 FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1681 if (failed(constraintExpr))
1682 return failure();
1683 typeExpr = *constraintExpr;
1684 typeConstraint = typeExpr->getLoc();
1685 return success();
1686 };
1687
1688 SMRange loc = curToken.getLoc();
1689 switch (curToken.getKind()) {
1690 case Token::kw_Attr: {
1691 consumeToken(Token::kw_Attr);
1692
1693 // Check for a type constraint.
1694 ast::Expr *typeExpr = nullptr;
1695 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1696 return failure();
1697 return ast::ConstraintRef(
1698 ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1699 }
1700 case Token::kw_Op: {
1701 consumeToken(Token::kw_Op);
1702
1703 // Parse an optional operation name. If the name isn't provided, this refers
1704 // to "any" operation.
1705 FailureOr<ast::OpNameDecl *> opName =
1706 parseWrappedOperationName(/*allowEmptyName=*/true);
1707 if (failed(opName))
1708 return failure();
1709
1710 return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
1711 loc);
1712 }
1713 case Token::kw_Type:
1714 consumeToken(Token::kw_Type);
1715 return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1717 consumeToken(Token::kw_TypeRange);
1718 return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
1719 loc);
1720 case Token::kw_Value: {
1721 consumeToken(Token::kw_Value);
1722
1723 // Check for a type constraint.
1724 ast::Expr *typeExpr = nullptr;
1725 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1726 return failure();
1727
1728 return ast::ConstraintRef(
1729 ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1730 }
1731 case Token::kw_ValueRange: {
1732 consumeToken(Token::kw_ValueRange);
1733
1734 // Check for a type constraint.
1735 ast::Expr *typeExpr = nullptr;
1736 if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1737 return failure();
1738
1739 return ast::ConstraintRef(
1740 ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1741 }
1742
1743 case Token::kw_Constraint: {
1744 // Handle an inline constraint.
1745 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1746 if (failed(decl))
1747 return failure();
1748 return ast::ConstraintRef(*decl, loc);
1749 }
1750 case Token::identifier: {
1751 StringRef constraintName = curToken.getSpelling();
1752 consumeToken(Token::identifier);
1753
1754 // Lookup the referenced constraint.
1755 ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
1756 if (!cstDecl) {
1757 return emitError(loc, "unknown reference to constraint `" +
1758 constraintName + "`");
1759 }
1760
1761 // Handle a reference to a proper constraint.
1762 if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1763 return ast::ConstraintRef(cst, loc);
1764
1765 return emitErrorAndNote(
1766 loc, "invalid reference to non-constraint", cstDecl->getLoc(),
1767 "see the definition of `" + constraintName + "` here");
1768 }
1769 // Handle single entity constraint code completion.
1770 case Token::code_complete: {
1771 // Try to infer the current type for use by code completion.
1772 ast::Type inferredType;
1773 if (failed(validateVariableConstraints(existingConstraints, inferredType)))
1774 return failure();
1775
1776 return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
1777 }
1778 default:
1779 break;
1780 }
1781 return emitError(loc, "expected identifier constraint");
1782}
1783
1784FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1785 std::optional<SMRange> typeConstraint;
1786 return parseConstraint(typeConstraint, /*existingConstraints=*/{},
1787 /*allowInlineTypeConstraints=*/false);
1788}
1789
1790//===----------------------------------------------------------------------===//
1791// Exprs
1792//===----------------------------------------------------------------------===//
1793
1794FailureOr<ast::Expr *> Parser::parseExpr() {
1795 if (curToken.is(Token::underscore))
1796 return parseUnderscoreExpr();
1797
1798 // Parse the LHS expression.
1799 FailureOr<ast::Expr *> lhsExpr;
1800 switch (curToken.getKind()) {
1801 case Token::kw_attr:
1802 lhsExpr = parseAttributeExpr();
1803 break;
1805 lhsExpr = parseInlineConstraintLambdaExpr();
1806 break;
1807 case Token::kw_not:
1808 lhsExpr = parseNegatedExpr();
1809 break;
1810 case Token::identifier:
1811 lhsExpr = parseIdentifierExpr();
1812 break;
1813 case Token::kw_op:
1814 lhsExpr = parseOperationExpr();
1815 break;
1816 case Token::kw_Rewrite:
1817 lhsExpr = parseInlineRewriteLambdaExpr();
1818 break;
1819 case Token::kw_type:
1820 lhsExpr = parseTypeExpr();
1821 break;
1822 case Token::l_paren:
1823 lhsExpr = parseTupleExpr();
1824 break;
1825 default:
1826 return emitError("expected expression");
1827 }
1828 if (failed(lhsExpr))
1829 return failure();
1830
1831 // Check for an operator expression.
1832 while (true) {
1833 switch (curToken.getKind()) {
1834 case Token::dot:
1835 lhsExpr = parseMemberAccessExpr(*lhsExpr);
1836 break;
1837 case Token::l_paren:
1838 lhsExpr = parseCallExpr(*lhsExpr);
1839 break;
1840 default:
1841 return lhsExpr;
1842 }
1843 if (failed(lhsExpr))
1844 return failure();
1845 }
1846}
1847
1848FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1849 SMRange loc = curToken.getLoc();
1850 consumeToken(Token::kw_attr);
1851
1852 // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1853 // identifier.
1854 if (!consumeIf(Token::less)) {
1855 resetToken(loc);
1856 return parseIdentifierExpr();
1857 }
1858
1859 if (!curToken.isString())
1860 return emitError("expected string literal containing MLIR attribute");
1861 std::string attrExpr = curToken.getStringValue();
1862 consumeToken();
1863
1864 loc.End = curToken.getEndLoc();
1865 if (failed(
1866 parseToken(Token::greater, "expected `>` after attribute literal")))
1867 return failure();
1868 return ast::AttributeExpr::create(ctx, loc, attrExpr);
1869}
1870
1871FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr,
1872 bool isNegated) {
1873 consumeToken(Token::l_paren);
1874
1875 // Parse the arguments of the call.
1876 SmallVector<ast::Expr *> arguments;
1877 if (curToken.isNot(Token::r_paren)) {
1878 do {
1879 // Handle code completion for the call arguments.
1880 if (curToken.is(Token::code_complete)) {
1881 codeCompleteCallSignature(parentExpr, arguments.size());
1882 return failure();
1883 }
1884
1885 FailureOr<ast::Expr *> argument = parseExpr();
1886 if (failed(argument))
1887 return failure();
1888 arguments.push_back(*argument);
1889 } while (consumeIf(Token::comma));
1890 }
1891
1892 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1893 if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1894 return failure();
1895
1896 return createCallExpr(loc, parentExpr, arguments, isNegated);
1897}
1898
1899FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1900 ast::Decl *decl = curDeclScope->lookup(name);
1901 if (!decl)
1902 return emitError(loc, "undefined reference to `" + name + "`");
1903
1904 return createDeclRefExpr(loc, decl);
1905}
1906
1907FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1908 StringRef name = curToken.getSpelling();
1909 SMRange nameLoc = curToken.getLoc();
1910 consumeToken();
1911
1912 // Check to see if this is a decl ref expression that defines a variable
1913 // inline.
1914 if (consumeIf(Token::colon)) {
1915 SmallVector<ast::ConstraintRef> constraints;
1916 if (failed(parseVariableDeclConstraintList(constraints)))
1917 return failure();
1918 ast::Type type;
1919 if (failed(validateVariableConstraints(constraints, type)))
1920 return failure();
1921 return createInlineVariableExpr(type, name, nameLoc, constraints);
1922 }
1923
1924 return parseDeclRefExpr(name, nameLoc);
1925}
1926
1927FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1928 FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1929 if (failed(decl))
1930 return failure();
1931
1932 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1934}
1935
1936FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1937 FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1938 if (failed(decl))
1939 return failure();
1940
1941 return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1943}
1944
1945FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1946 SMRange dotLoc = curToken.getLoc();
1947 consumeToken(Token::dot);
1948
1949 // Check for code completion of the member name.
1950 if (curToken.is(Token::code_complete))
1951 return codeCompleteMemberAccess(parentExpr);
1952
1953 // Parse the member name.
1954 Token memberNameTok = curToken;
1955 if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1956 !memberNameTok.isKeyword())
1957 return emitError(dotLoc, "expected identifier or numeric member name");
1958 StringRef memberName = memberNameTok.getSpelling();
1959 SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1960 consumeToken();
1961
1962 return createMemberAccessExpr(parentExpr, memberName, loc);
1963}
1964
1965FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
1966 consumeToken(Token::kw_not);
1967 // Only native constraints are supported after negation
1968 if (!curToken.is(Token::identifier))
1969 return emitError("expected native constraint");
1970 FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
1971 if (failed(identifierExpr))
1972 return failure();
1973 if (!curToken.is(Token::l_paren))
1974 return emitError("expected `(` after function name");
1975 return parseCallExpr(*identifierExpr, /*isNegated = */ true);
1976}
1977
1978FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
1979 SMRange loc = curToken.getLoc();
1980
1981 // Check for code completion for the dialect name.
1982 if (curToken.is(Token::code_complete))
1983 return codeCompleteDialectName();
1984
1985 // Handle the case of an no operation name.
1986 if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
1987 if (allowEmptyName)
1988 return ast::OpNameDecl::create(ctx, SMRange());
1989 return emitError("expected dialect namespace");
1990 }
1991 StringRef name = curToken.getSpelling();
1992 consumeToken();
1993
1994 // Otherwise, this is a literal operation name.
1995 if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
1996 return failure();
1997
1998 // Check for code completion for the operation name.
1999 if (curToken.is(Token::code_complete))
2000 return codeCompleteOperationName(name);
2001
2002 if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
2003 return emitError("expected operation name after dialect namespace");
2004
2005 name = StringRef(name.data(), name.size() + 1);
2006 do {
2007 name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
2008 loc.End = curToken.getEndLoc();
2009 consumeToken();
2010 } while (curToken.isAny(Token::identifier, Token::dot) ||
2011 curToken.isKeyword());
2012 return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
2013}
2014
2015FailureOr<ast::OpNameDecl *>
2016Parser::parseWrappedOperationName(bool allowEmptyName) {
2017 if (!consumeIf(Token::less))
2018 return ast::OpNameDecl::create(ctx, SMRange());
2019
2020 FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
2021 if (failed(opNameDecl))
2022 return failure();
2023
2024 if (failed(parseToken(Token::greater, "expected `>` after operation name")))
2025 return failure();
2026 return opNameDecl;
2027}
2028
2029FailureOr<ast::Expr *>
2030Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
2031 SMRange loc = curToken.getLoc();
2032 consumeToken(Token::kw_op);
2033
2034 // If it isn't followed by a `<`, the `op` keyword is treated as a normal
2035 // identifier.
2036 if (curToken.isNot(Token::less)) {
2037 resetToken(loc);
2038 return parseIdentifierExpr();
2039 }
2040
2041 // Parse the operation name. The name may be elided, in which case the
2042 // operation refers to "any" operation(i.e. a difference between `MyOp` and
2043 // `Operation*`). Operation names within a rewrite context must be named.
2044 bool allowEmptyName = parserContext != ParserContext::Rewrite;
2045 FailureOr<ast::OpNameDecl *> opNameDecl =
2046 parseWrappedOperationName(allowEmptyName);
2047 if (failed(opNameDecl))
2048 return failure();
2049 std::optional<StringRef> opName = (*opNameDecl)->getName();
2050
2051 // Functor used to create an implicit range variable, used for implicit "all"
2052 // operand or results variables.
2053 auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
2054 FailureOr<ast::VariableDecl *> rangeVar =
2055 defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
2056 assert(succeeded(rangeVar) && "expected range variable to be valid");
2057 return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
2058 };
2059
2060 // Check for the optional list of operands.
2061 SmallVector<ast::Expr *> operands;
2062 if (!consumeIf(Token::l_paren)) {
2063 // If the operand list isn't specified and we are in a match context, define
2064 // an inplace unconstrained operand range corresponding to all of the
2065 // operands of the operation. This avoids treating zero operands the same
2066 // way as "unconstrained operands".
2067 if (parserContext != ParserContext::Rewrite) {
2068 operands.push_back(createImplicitRangeVar(
2069 ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
2070 }
2071 } else if (!consumeIf(Token::r_paren)) {
2072 // If the operand list was specified and non-empty, parse the operands.
2073 do {
2074 // Check for operand signature code completion.
2075 if (curToken.is(Token::code_complete)) {
2076 codeCompleteOperationOperandsSignature(opName, operands.size());
2077 return failure();
2078 }
2079
2080 FailureOr<ast::Expr *> operand = parseExpr();
2081 if (failed(operand))
2082 return failure();
2083 operands.push_back(*operand);
2084 } while (consumeIf(Token::comma));
2085
2086 if (failed(parseToken(Token::r_paren,
2087 "expected `)` after operation operand list")))
2088 return failure();
2089 }
2090
2091 // Check for the optional list of attributes.
2092 SmallVector<ast::NamedAttributeDecl *> attributes;
2093 if (consumeIf(Token::l_brace)) {
2094 do {
2095 FailureOr<ast::NamedAttributeDecl *> decl =
2096 parseNamedAttributeDecl(opName);
2097 if (failed(decl))
2098 return failure();
2099 attributes.emplace_back(*decl);
2100 } while (consumeIf(Token::comma));
2101
2102 if (failed(parseToken(Token::r_brace,
2103 "expected `}` after operation attribute list")))
2104 return failure();
2105 }
2106
2107 // Handle the result types of the operation.
2108 SmallVector<ast::Expr *> resultTypes;
2109 OpResultTypeContext resultTypeContext = inputResultTypeContext;
2110
2111 // Check for an explicit list of result types.
2112 if (consumeIf(Token::arrow)) {
2113 if (failed(parseToken(Token::l_paren,
2114 "expected `(` before operation result type list")))
2115 return failure();
2116
2117 // If result types are provided, initially assume that the operation does
2118 // not rely on type inferrence. We don't assert that it isn't, because we
2119 // may be inferring the value of some type/type range variables, but given
2120 // that these variables may be defined in calls we can't always discern when
2121 // this is the case.
2122 resultTypeContext = OpResultTypeContext::Explicit;
2123
2124 // Handle the case of an empty result list.
2125 if (!consumeIf(Token::r_paren)) {
2126 do {
2127 // Check for result signature code completion.
2128 if (curToken.is(Token::code_complete)) {
2129 codeCompleteOperationResultsSignature(opName, resultTypes.size());
2130 return failure();
2131 }
2132
2133 FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2134 if (failed(resultTypeExpr))
2135 return failure();
2136 resultTypes.push_back(*resultTypeExpr);
2137 } while (consumeIf(Token::comma));
2138
2139 if (failed(parseToken(Token::r_paren,
2140 "expected `)` after operation result type list")))
2141 return failure();
2142 }
2143 } else if (parserContext != ParserContext::Rewrite) {
2144 // If the result list isn't specified and we are in a match context, define
2145 // an inplace unconstrained result range corresponding to all of the results
2146 // of the operation. This avoids treating zero results the same way as
2147 // "unconstrained results".
2148 resultTypes.push_back(createImplicitRangeVar(
2149 ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
2150 } else if (resultTypeContext == OpResultTypeContext::Explicit) {
2151 // If the result list isn't specified and we are in a rewrite, try to infer
2152 // them at runtime instead.
2153 resultTypeContext = OpResultTypeContext::Interface;
2154 }
2155
2156 return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2157 attributes, resultTypes);
2158}
2159
2160FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2161 SMRange loc = curToken.getLoc();
2162 consumeToken(Token::l_paren);
2163
2165 SmallVector<StringRef> elementNames;
2166 SmallVector<ast::Expr *> elements;
2167 if (curToken.isNot(Token::r_paren)) {
2168 do {
2169 // Check for the optional element name assignment before the value.
2170 StringRef elementName;
2171 if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
2172 Token elementNameTok = curToken;
2173 consumeToken();
2174
2175 // The element name is only present if followed by an `=`.
2176 if (consumeIf(Token::equal)) {
2177 elementName = elementNameTok.getSpelling();
2178
2179 // Check to see if this name is already used.
2180 auto elementNameIt =
2181 usedNames.try_emplace(elementName, elementNameTok.getLoc());
2182 if (!elementNameIt.second) {
2183 return emitErrorAndNote(
2184 elementNameTok.getLoc(),
2185 llvm::formatv("duplicate tuple element label `{0}`",
2186 elementName),
2187 elementNameIt.first->getSecond(),
2188 "see previous label use here");
2189 }
2190 } else {
2191 // Otherwise, we treat this as part of an expression so reset the
2192 // lexer.
2193 resetToken(elementNameTok.getLoc());
2194 }
2195 }
2196 elementNames.push_back(elementName);
2197
2198 // Parse the tuple element value.
2199 FailureOr<ast::Expr *> element = parseExpr();
2200 if (failed(element))
2201 return failure();
2202 elements.push_back(*element);
2203 } while (consumeIf(Token::comma));
2204 }
2205 loc.End = curToken.getEndLoc();
2206 if (failed(
2207 parseToken(Token::r_paren, "expected `)` after tuple element list")))
2208 return failure();
2209 return createTupleExpr(loc, elements, elementNames);
2210}
2211
2212FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2213 SMRange loc = curToken.getLoc();
2214 consumeToken(Token::kw_type);
2215
2216 // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2217 // identifier.
2218 if (!consumeIf(Token::less)) {
2219 resetToken(loc);
2220 return parseIdentifierExpr();
2221 }
2222
2223 if (!curToken.isString())
2224 return emitError("expected string literal containing MLIR type");
2225 std::string attrExpr = curToken.getStringValue();
2226 consumeToken();
2227
2228 loc.End = curToken.getEndLoc();
2229 if (failed(parseToken(Token::greater, "expected `>` after type literal")))
2230 return failure();
2231 return ast::TypeExpr::create(ctx, loc, attrExpr);
2232}
2233
2234FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2235 StringRef name = curToken.getSpelling();
2236 SMRange nameLoc = curToken.getLoc();
2237 consumeToken(Token::underscore);
2238
2239 // Underscore expressions require a constraint list.
2240 if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2241 return failure();
2242
2243 // Parse the constraints for the expression.
2244 SmallVector<ast::ConstraintRef> constraints;
2245 if (failed(parseVariableDeclConstraintList(constraints)))
2246 return failure();
2247
2248 ast::Type type;
2249 if (failed(validateVariableConstraints(constraints, type)))
2250 return failure();
2251 return createInlineVariableExpr(type, name, nameLoc, constraints);
2252}
2253
2254//===----------------------------------------------------------------------===//
2255// Stmts
2256//===----------------------------------------------------------------------===//
2257
2258FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2259 FailureOr<ast::Stmt *> stmt;
2260 switch (curToken.getKind()) {
2261 case Token::kw_erase:
2262 stmt = parseEraseStmt();
2263 break;
2264 case Token::kw_let:
2265 stmt = parseLetStmt();
2266 break;
2267 case Token::kw_replace:
2268 stmt = parseReplaceStmt();
2269 break;
2270 case Token::kw_return:
2271 stmt = parseReturnStmt();
2272 break;
2273 case Token::kw_rewrite:
2274 stmt = parseRewriteStmt();
2275 break;
2276 default:
2277 stmt = parseExpr();
2278 break;
2279 }
2280 if (failed(stmt) ||
2281 (expectTerminalSemicolon &&
2282 failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2283 return failure();
2284 return stmt;
2285}
2286
2287FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2288 SMLoc startLoc = curToken.getStartLoc();
2289 consumeToken(Token::l_brace);
2290
2291 // Push a new block scope and parse any nested statements.
2292 pushDeclScope();
2293 SmallVector<ast::Stmt *> statements;
2294 while (curToken.isNot(Token::r_brace)) {
2295 FailureOr<ast::Stmt *> statement = parseStmt();
2296 if (failed(statement))
2297 return popDeclScope(), failure();
2298 statements.push_back(*statement);
2299 }
2300 popDeclScope();
2301
2302 // Consume the end brace.
2303 SMRange location(startLoc, curToken.getEndLoc());
2304 consumeToken(Token::r_brace);
2305
2306 return ast::CompoundStmt::create(ctx, location, statements);
2307}
2308
2309FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2310 if (parserContext == ParserContext::Constraint)
2311 return emitError("`erase` cannot be used within a Constraint");
2312 SMRange loc = curToken.getLoc();
2313 consumeToken(Token::kw_erase);
2314
2315 // Parse the root operation expression.
2316 FailureOr<ast::Expr *> rootOp = parseExpr();
2317 if (failed(rootOp))
2318 return failure();
2319
2320 return createEraseStmt(loc, *rootOp);
2321}
2322
2323FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2324 SMRange loc = curToken.getLoc();
2325 consumeToken(Token::kw_let);
2326
2327 // Parse the name of the new variable.
2328 SMRange varLoc = curToken.getLoc();
2329 if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
2330 // `_` is a reserved variable name.
2331 if (curToken.is(Token::underscore)) {
2332 return emitError(varLoc,
2333 "`_` may only be used to define \"inline\" variables");
2334 }
2335 return emitError(varLoc,
2336 "expected identifier after `let` to name a new variable");
2337 }
2338 StringRef varName = curToken.getSpelling();
2339 consumeToken();
2340
2341 // Parse the optional set of constraints.
2342 SmallVector<ast::ConstraintRef> constraints;
2343 if (consumeIf(Token::colon) &&
2344 failed(parseVariableDeclConstraintList(constraints)))
2345 return failure();
2346
2347 // Parse the optional initializer expression.
2348 ast::Expr *initializer = nullptr;
2349 if (consumeIf(Token::equal)) {
2350 FailureOr<ast::Expr *> initOrFailure = parseExpr();
2351 if (failed(initOrFailure))
2352 return failure();
2353 initializer = *initOrFailure;
2354
2355 // Check that the constraints are compatible with having an initializer,
2356 // e.g. type constraints cannot be used with initializers.
2357 for (ast::ConstraintRef constraint : constraints) {
2358 LogicalResult result =
2360 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
2361 ast::ValueRangeConstraintDecl>([&](const auto *cst) {
2362 if (cst->getTypeExpr()) {
2363 return this->emitError(
2364 constraint.referenceLoc,
2365 "type constraints are not permitted on variables with "
2366 "initializers");
2367 }
2368 return success();
2369 })
2370 .Default(success());
2371 if (failed(result))
2372 return failure();
2373 }
2374 }
2375
2376 FailureOr<ast::VariableDecl *> varDecl =
2377 createVariableDecl(varName, varLoc, initializer, constraints);
2378 if (failed(varDecl))
2379 return failure();
2380 return ast::LetStmt::create(ctx, loc, *varDecl);
2381}
2382
2383FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2384 if (parserContext == ParserContext::Constraint)
2385 return emitError("`replace` cannot be used within a Constraint");
2386 SMRange loc = curToken.getLoc();
2387 consumeToken(Token::kw_replace);
2388
2389 // Parse the root operation expression.
2390 FailureOr<ast::Expr *> rootOp = parseExpr();
2391 if (failed(rootOp))
2392 return failure();
2393
2394 if (failed(
2395 parseToken(Token::kw_with, "expected `with` after root operation")))
2396 return failure();
2397
2398 // The replacement portion of this statement is within a rewrite context.
2399 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2400
2401 // Parse the replacement values.
2402 SmallVector<ast::Expr *> replValues;
2403 if (consumeIf(Token::l_paren)) {
2404 if (consumeIf(Token::r_paren)) {
2405 return emitError(
2406 loc, "expected at least one replacement value, consider using "
2407 "`erase` if no replacement values are desired");
2408 }
2409
2410 do {
2411 FailureOr<ast::Expr *> replExpr = parseExpr();
2412 if (failed(replExpr))
2413 return failure();
2414 replValues.emplace_back(*replExpr);
2415 } while (consumeIf(Token::comma));
2416
2417 if (failed(parseToken(Token::r_paren,
2418 "expected `)` after replacement values")))
2419 return failure();
2420 } else {
2421 // Handle replacement with an operation uniquely, as the replacement
2422 // operation supports type inferrence from the root operation.
2423 FailureOr<ast::Expr *> replExpr;
2424 if (curToken.is(Token::kw_op))
2425 replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2426 else
2427 replExpr = parseExpr();
2428 if (failed(replExpr))
2429 return failure();
2430 replValues.emplace_back(*replExpr);
2431 }
2432
2433 return createReplaceStmt(loc, *rootOp, replValues);
2434}
2435
2436FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2437 SMRange loc = curToken.getLoc();
2438 consumeToken(Token::kw_return);
2439
2440 // Parse the result value.
2441 FailureOr<ast::Expr *> resultExpr = parseExpr();
2442 if (failed(resultExpr))
2443 return failure();
2444
2445 return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2446}
2447
2448FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2449 if (parserContext == ParserContext::Constraint)
2450 return emitError("`rewrite` cannot be used within a Constraint");
2451 SMRange loc = curToken.getLoc();
2452 consumeToken(Token::kw_rewrite);
2453
2454 // Parse the root operation.
2455 FailureOr<ast::Expr *> rootOp = parseExpr();
2456 if (failed(rootOp))
2457 return failure();
2458
2459 if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2460 return failure();
2461
2462 if (curToken.isNot(Token::l_brace))
2463 return emitError("expected `{` to start rewrite body");
2464
2465 // The rewrite body of this statement is within a rewrite context.
2466 llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2467
2468 FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2469 if (failed(rewriteBody))
2470 return failure();
2471
2472 // Verify the rewrite body.
2473 for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2474 if (isa<ast::ReturnStmt>(stmt)) {
2475 return emitError(stmt->getLoc(),
2476 "`return` statements are only permitted within a "
2477 "`Constraint` or `Rewrite` body");
2478 }
2479 }
2480
2481 return createRewriteStmt(loc, *rootOp, *rewriteBody);
2482}
2483
2484//===----------------------------------------------------------------------===//
2485// Creation+Analysis
2486//===----------------------------------------------------------------------===//
2487
2488//===----------------------------------------------------------------------===//
2489// Decls
2490//===----------------------------------------------------------------------===//
2491
2492ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2493 // Unwrap reference expressions.
2494 if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2495 node = init->getDecl();
2496 return dyn_cast<ast::CallableDecl>(node);
2497}
2498
2499FailureOr<ast::PatternDecl *>
2500Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2501 const ParsedPatternMetadata &metadata,
2502 ast::CompoundStmt *body) {
2503 return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2504 metadata.hasBoundedRecursion, body);
2505}
2506
2507ast::Type Parser::createUserConstraintRewriteResultType(
2508 ArrayRef<ast::VariableDecl *> results) {
2509 // Single result decls use the type of the single result.
2510 if (results.size() == 1)
2511 return results[0]->getType();
2512
2513 // Multiple results use a tuple type, with the types and names grabbed from
2514 // the result variable decls.
2515 auto resultTypes = llvm::map_range(
2516 results, [&](const auto *result) { return result->getType(); });
2517 auto resultNames = llvm::map_range(
2518 results, [&](const auto *result) { return result->getName().getName(); });
2519 return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2520 llvm::to_vector(resultNames));
2521}
2522
2523template <typename T>
2524FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2525 const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2526 ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2527 ast::CompoundStmt *body) {
2528 if (!body->getChildren().empty()) {
2529 if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2530 ast::Expr *resultExpr = retStmt->getResultExpr();
2531
2532 // Process the result of the decl. If no explicit signature results
2533 // were provided, check for return type inference. Otherwise, check that
2534 // the return expression can be converted to the expected type.
2535 if (results.empty())
2536 resultType = resultExpr->getType();
2537 else if (failed(convertExpressionTo(resultExpr, resultType)))
2538 return failure();
2539 else
2540 retStmt->setResultExpr(resultExpr);
2541 }
2542 }
2543 return T::createPDLL(ctx, name, arguments, results, body, resultType);
2544}
2545
2546FailureOr<ast::VariableDecl *>
2547Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2548 ArrayRef<ast::ConstraintRef> constraints) {
2549 // The type of the variable, which is expected to be inferred by either a
2550 // constraint or an initializer expression.
2551 ast::Type type;
2552 if (failed(validateVariableConstraints(constraints, type)))
2553 return failure();
2554
2555 if (initializer) {
2556 // Update the variable type based on the initializer, or try to convert the
2557 // initializer to the existing type.
2558 if (!type)
2559 type = initializer->getType();
2560 else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2561 type = mergedType;
2562 else if (failed(convertExpressionTo(initializer, type)))
2563 return failure();
2564
2565 // Otherwise, if there is no initializer check that the type has already
2566 // been resolved from the constraint list.
2567 } else if (!type) {
2568 return emitErrorAndNote(
2569 loc, "unable to infer type for variable `" + name + "`", loc,
2570 "the type of a variable must be inferable from the constraint "
2571 "list or the initializer");
2572 }
2573
2574 // Constraint types cannot be used when defining variables.
2575 if (isa<ast::ConstraintType, ast::RewriteType>(type)) {
2576 return emitError(
2577 loc, llvm::formatv("unable to define variable of `{0}` type", type));
2578 }
2579
2580 // Try to define a variable with the given name.
2581 FailureOr<ast::VariableDecl *> varDecl =
2582 defineVariableDecl(name, loc, type, initializer, constraints);
2583 if (failed(varDecl))
2584 return failure();
2585
2586 return *varDecl;
2587}
2588
2589FailureOr<ast::VariableDecl *>
2590Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2591 const ast::ConstraintRef &constraint) {
2592 ast::Type argType;
2593 if (failed(validateVariableConstraint(constraint, argType)))
2594 return failure();
2595 return defineVariableDecl(name, loc, argType, constraint);
2596}
2597
2598LogicalResult
2599Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2600 ast::Type &inferredType) {
2601 for (const ast::ConstraintRef &ref : constraints)
2602 if (failed(validateVariableConstraint(ref, inferredType)))
2603 return failure();
2604 return success();
2605}
2606
2607LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2608 ast::Type &inferredType) {
2609 ast::Type constraintType;
2610 if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
2611 if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2612 if (failed(validateTypeConstraintExpr(typeExpr)))
2613 return failure();
2614 }
2615 constraintType = ast::AttributeType::get(ctx);
2616 } else if (const auto *cst =
2617 dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
2618 constraintType = ast::OperationType::get(
2619 ctx, cst->getName(), lookupODSOperation(cst->getName()));
2620 } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
2621 constraintType = typeTy;
2622 } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
2623 constraintType = typeRangeTy;
2624 } else if (const auto *cst =
2625 dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
2626 if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2627 if (failed(validateTypeConstraintExpr(typeExpr)))
2628 return failure();
2629 }
2630 constraintType = valueTy;
2631 } else if (const auto *cst =
2632 dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
2633 if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2634 if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2635 return failure();
2636 }
2637 constraintType = valueRangeTy;
2638 } else if (const auto *cst =
2639 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
2640 ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2641 if (inputs.size() != 1) {
2642 return emitErrorAndNote(ref.referenceLoc,
2643 "`Constraint`s applied via a variable constraint "
2644 "list must take a single input, but got " +
2645 Twine(inputs.size()),
2646 cst->getLoc(),
2647 "see definition of constraint here");
2648 }
2649 constraintType = inputs.front()->getType();
2650 } else {
2651 llvm_unreachable("unknown constraint type");
2652 }
2653
2654 // Check that the constraint type is compatible with the current inferred
2655 // type.
2656 if (!inferredType) {
2657 inferredType = constraintType;
2658 } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
2659 inferredType = mergedTy;
2660 } else {
2661 return emitError(ref.referenceLoc,
2662 llvm::formatv("constraint type `{0}` is incompatible "
2663 "with the previously inferred type `{1}`",
2664 constraintType, inferredType));
2665 }
2666 return success();
2667}
2668
2669LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2670 ast::Type typeExprType = typeExpr->getType();
2671 if (typeExprType != typeTy) {
2672 return emitError(typeExpr->getLoc(),
2673 "expected expression of `Type` in type constraint");
2674 }
2675 return success();
2676}
2677
2678LogicalResult
2679Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2680 ast::Type typeExprType = typeExpr->getType();
2681 if (typeExprType != typeRangeTy) {
2682 return emitError(typeExpr->getLoc(),
2683 "expected expression of `TypeRange` in type constraint");
2684 }
2685 return success();
2686}
2687
2688//===----------------------------------------------------------------------===//
2689// Exprs
2690//===----------------------------------------------------------------------===//
2691
2692FailureOr<ast::CallExpr *>
2693Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2694 MutableArrayRef<ast::Expr *> arguments, bool isNegated) {
2695 ast::Type parentType = parentExpr->getType();
2696
2697 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2698 if (!callableDecl) {
2699 return emitError(loc,
2700 llvm::formatv("expected a reference to a callable "
2701 "`Constraint` or `Rewrite`, but got: `{0}`",
2702 parentType));
2703 }
2704 if (parserContext == ParserContext::Rewrite) {
2705 if (isa<ast::UserConstraintDecl>(callableDecl))
2706 return emitError(
2707 loc, "unable to invoke `Constraint` within a rewrite section");
2708 if (isNegated)
2709 return emitError(loc, "unable to negate a Rewrite");
2710 } else {
2711 if (isa<ast::UserRewriteDecl>(callableDecl))
2712 return emitError(loc,
2713 "unable to invoke `Rewrite` within a match section");
2714 if (isNegated && cast<ast::UserConstraintDecl>(callableDecl)->getBody())
2715 return emitError(loc, "unable to negate non native constraints");
2716 }
2717
2718 // Verify the arguments of the call.
2719 /// Handle size mismatch.
2720 ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2721 if (callArgs.size() != arguments.size()) {
2722 return emitErrorAndNote(
2723 loc,
2724 llvm::formatv("invalid number of arguments for {0} call; expected "
2725 "{1}, but got {2}",
2726 callableDecl->getCallableType(), callArgs.size(),
2727 arguments.size()),
2728 callableDecl->getLoc(),
2729 llvm::formatv("see the definition of {0} here",
2730 callableDecl->getName()->getName()));
2731 }
2732
2733 /// Handle argument type mismatch.
2734 auto attachDiagFn = [&](ast::Diagnostic &diag) {
2735 diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2736 callableDecl->getName()->getName()),
2737 callableDecl->getLoc());
2738 };
2739 for (auto it : llvm::zip(callArgs, arguments)) {
2740 if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2741 attachDiagFn)))
2742 return failure();
2743 }
2744
2745 return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2746 callableDecl->getResultType(), isNegated);
2747}
2748
2749FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2750 ast::Decl *decl) {
2751 // Check the type of decl being referenced.
2752 ast::Type declType;
2753 if (isa<ast::ConstraintDecl>(decl))
2754 declType = ast::ConstraintType::get(ctx);
2755 else if (isa<ast::UserRewriteDecl>(decl))
2756 declType = ast::RewriteType::get(ctx);
2757 else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2758 declType = varDecl->getType();
2759 else
2760 return emitError(loc, "invalid reference to `" +
2761 decl->getName()->getName() + "`");
2762
2763 return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2764}
2765
2766FailureOr<ast::DeclRefExpr *>
2767Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2768 ArrayRef<ast::ConstraintRef> constraints) {
2769 FailureOr<ast::VariableDecl *> decl =
2770 defineVariableDecl(name, loc, type, constraints);
2771 if (failed(decl))
2772 return failure();
2773 return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2774}
2775
2776FailureOr<ast::MemberAccessExpr *>
2777Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2778 SMRange loc) {
2779 // Validate the member name for the given parent expression.
2780 FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2781 if (failed(memberType))
2782 return failure();
2783
2784 return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2785}
2786
2787FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2788 StringRef name, SMRange loc) {
2789 ast::Type parentType = parentExpr->getType();
2790 if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) {
2792 return valueRangeTy;
2793
2794 // Verify member access based on the operation type.
2795 if (const ods::Operation *odsOp = opType.getODSOperation()) {
2796 auto results = odsOp->getResults();
2797
2798 // Handle indexed results.
2799 unsigned index = 0;
2800 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2801 index < results.size()) {
2802 return results[index].isVariadic() ? valueRangeTy : valueTy;
2803 }
2804
2805 // Handle named results.
2806 const auto *it = llvm::find_if(results, [&](const auto &result) {
2807 return result.getName() == name;
2808 });
2809 if (it != results.end())
2810 return it->isVariadic() ? valueRangeTy : valueTy;
2811 } else if (llvm::isDigit(name[0])) {
2812 // Allow unchecked numeric indexing of the results of unregistered
2813 // operations. It returns a single value.
2814 return valueTy;
2815 }
2816 } else if (auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
2817 // Handle indexed results.
2818 unsigned index = 0;
2819 if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2820 index < tupleType.size()) {
2821 return tupleType.getElementTypes()[index];
2822 }
2823
2824 // Handle named results.
2825 auto elementNames = tupleType.getElementNames();
2826 const auto *it = llvm::find(elementNames, name);
2827 if (it != elementNames.end())
2828 return tupleType.getElementTypes()[it - elementNames.begin()];
2829 }
2830 return emitError(
2831 loc,
2832 llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2833 name, parentType));
2834}
2835
2836FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2837 SMRange loc, const ast::OpNameDecl *name,
2838 OpResultTypeContext resultTypeContext,
2839 SmallVectorImpl<ast::Expr *> &operands,
2840 MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2841 SmallVectorImpl<ast::Expr *> &results) {
2842 std::optional<StringRef> opNameRef = name->getName();
2843 const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2844
2845 // Verify the inputs operands.
2846 if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2847 return failure();
2848
2849 // Verify the attribute list.
2850 for (ast::NamedAttributeDecl *attr : attributes) {
2851 // Check for an attribute type, or a type awaiting resolution.
2852 ast::Type attrType = attr->getValue()->getType();
2853 if (!isa<ast::AttributeType>(attrType)) {
2854 return emitError(
2855 attr->getValue()->getLoc(),
2856 llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2857 }
2858 }
2859
2860 assert(
2861 (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2862 "unexpected inferrence when results were explicitly specified");
2863
2864 // If we aren't relying on type inferrence, or explicit results were provided,
2865 // validate them.
2866 if (resultTypeContext == OpResultTypeContext::Explicit) {
2867 if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2868 return failure();
2869
2870 // Validate the use of interface based type inferrence for this operation.
2871 } else if (resultTypeContext == OpResultTypeContext::Interface) {
2872 assert(opNameRef &&
2873 "expected valid operation name when inferring operation results");
2874 checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2875 }
2876
2877 return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
2878 attributes);
2879}
2880
2881LogicalResult
2882Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2883 const ods::Operation *odsOp,
2884 SmallVectorImpl<ast::Expr *> &operands) {
2885 return validateOperationOperandsOrResults(
2886 "operand", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2887 operands,
2888 odsOp ? odsOp->getOperands() : ArrayRef<pdll::ods::OperandOrResult>(),
2889 valueTy, valueRangeTy);
2890}
2891
2892LogicalResult
2893Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2894 const ods::Operation *odsOp,
2895 SmallVectorImpl<ast::Expr *> &results) {
2896 return validateOperationOperandsOrResults(
2897 "result", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2898 results,
2899 odsOp ? odsOp->getResults() : ArrayRef<pdll::ods::OperandOrResult>(),
2900 typeTy, typeRangeTy);
2901}
2902
2903void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2904 const ods::Operation *odsOp) {
2905 // If the operation might not have inferrence support, emit a warning to the
2906 // user. We don't emit an error because the interface might be added to the
2907 // operation at runtime. It's rare, but it could still happen. We emit a
2908 // warning here instead.
2909
2910 // Handle inferrence warnings for unknown operations.
2911 if (!odsOp) {
2913 loc, llvm::formatv(
2914 "operation result types are marked to be inferred, but "
2915 "`{0}` is unknown. Ensure that `{0}` supports zero "
2916 "results or implements `InferTypeOpInterface`. Include "
2917 "the ODS definition of this operation to remove this warning.",
2918 opName));
2919 return;
2920 }
2921
2922 // Handle inferrence warnings for known operations that expected at least one
2923 // result, but don't have inference support. An elided results list can mean
2924 // "zero-results", and we don't want to warn when that is the expected
2925 // behavior.
2926 bool requiresInferrence =
2927 llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
2928 return !result.isVariableLength();
2929 });
2930 if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2931 ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning(
2932 loc,
2933 llvm::formatv("operation result types are marked to be inferred, but "
2934 "`{0}` does not provide an implementation of "
2935 "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2936 "`InferTypeOpInterface` at runtime, or add support to "
2937 "the ODS definition to remove this warning.",
2938 opName));
2939 diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
2940 odsOp->getLoc());
2941 return;
2942 }
2943}
2944
2945LogicalResult Parser::validateOperationOperandsOrResults(
2946 StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2947 std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
2948 ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2949 ast::RangeType rangeTy) {
2950 // All operation types accept a single range parameter.
2951 if (values.size() == 1) {
2952 if (failed(convertExpressionTo(values[0], rangeTy)))
2953 return failure();
2954 return success();
2955 }
2956
2957 /// If the operation has ODS information, we can more accurately verify the
2958 /// values.
2959 if (odsOpLoc) {
2960 auto emitSizeMismatchError = [&] {
2961 return emitErrorAndNote(
2962 loc,
2963 llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2964 "{2}, but got {3}",
2965 groupName, *name, odsValues.size(), values.size()),
2966 *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
2967 };
2968
2969 // Handle the case where no values were provided.
2970 if (values.empty()) {
2971 // If we don't expect any on the ODS side, we are done.
2972 if (odsValues.empty())
2973 return success();
2974
2975 // If we do, check if we actually need to provide values (i.e. if any of
2976 // the values are actually required).
2977 unsigned numVariadic = 0;
2978 for (const auto &odsValue : odsValues) {
2979 if (!odsValue.isVariableLength())
2980 return emitSizeMismatchError();
2981 ++numVariadic;
2982 }
2983
2984 // If we are in a non-rewrite context, we don't need to do anything more.
2985 // Zero-values is a valid constraint on the operation.
2986 if (parserContext != ParserContext::Rewrite)
2987 return success();
2988
2989 // Otherwise, when in a rewrite we may need to provide values to match the
2990 // ODS signature of the operation to create.
2991
2992 // If we only have one variadic value, just use an empty list.
2993 if (numVariadic == 1)
2994 return success();
2995
2996 // Otherwise, create dummy values for each of the entries so that we
2997 // adhere to the ODS signature.
2998 for (unsigned i = 0, e = odsValues.size(); i < e; ++i) {
2999 values.push_back(
3000 ast::RangeExpr::create(ctx, loc, /*elements=*/{}, rangeTy));
3001 }
3002 return success();
3003 }
3004
3005 // Verify that the number of values provided matches the number of value
3006 // groups ODS expects.
3007 if (odsValues.size() != values.size())
3008 return emitSizeMismatchError();
3009
3010 auto diagFn = [&](ast::Diagnostic &diag) {
3011 diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
3012 *odsOpLoc);
3013 };
3014 for (unsigned i = 0, e = values.size(); i < e; ++i) {
3015 ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
3016 if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
3017 return failure();
3018 }
3019 return success();
3020 }
3021
3022 // Otherwise, accept the value groups as they have been defined and just
3023 // ensure they are one of the expected types.
3024 for (ast::Expr *&valueExpr : values) {
3025 ast::Type valueExprType = valueExpr->getType();
3026
3027 // Check if this is one of the expected types.
3028 if (valueExprType == rangeTy || valueExprType == singleTy)
3029 continue;
3030
3031 // If the operand is an Operation, allow converting to a Value or
3032 // ValueRange. This situations arises quite often with nested operation
3033 // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
3034 if (singleTy == valueTy) {
3035 if (isa<ast::OperationType>(valueExprType)) {
3036 valueExpr = convertOpToValue(valueExpr);
3037 continue;
3038 }
3039 }
3040
3041 // Otherwise, try to convert the expression to a range.
3042 if (succeeded(convertExpressionTo(valueExpr, rangeTy)))
3043 continue;
3044
3045 return emitError(
3046 valueExpr->getLoc(),
3047 llvm::formatv(
3048 "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3049 singleTy, rangeTy, valueExprType));
3050 }
3051 return success();
3052}
3053
3054FailureOr<ast::TupleExpr *>
3055Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
3056 ArrayRef<StringRef> elementNames) {
3057 for (const ast::Expr *element : elements) {
3058 ast::Type eleTy = element->getType();
3059 if (isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>(eleTy)) {
3060 return emitError(
3061 element->getLoc(),
3062 llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
3063 }
3064 }
3065 return ast::TupleExpr::create(ctx, loc, elements, elementNames);
3066}
3067
3068//===----------------------------------------------------------------------===//
3069// Stmts
3070//===----------------------------------------------------------------------===//
3071
3072FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
3073 ast::Expr *rootOp) {
3074 // Check that root is an Operation.
3075 ast::Type rootType = rootOp->getType();
3076 if (!isa<ast::OperationType>(rootType))
3077 return emitError(rootOp->getLoc(), "expected `Op` expression");
3078
3079 return ast::EraseStmt::create(ctx, loc, rootOp);
3080}
3081
3082FailureOr<ast::ReplaceStmt *>
3083Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
3084 MutableArrayRef<ast::Expr *> replValues) {
3085 // Check that root is an Operation.
3086 ast::Type rootType = rootOp->getType();
3087 if (!isa<ast::OperationType>(rootType)) {
3088 return emitError(
3089 rootOp->getLoc(),
3090 llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3091 }
3092
3093 // If there are multiple replacement values, we implicitly convert any Op
3094 // expressions to the value form.
3095 bool shouldConvertOpToValues = replValues.size() > 1;
3096 for (ast::Expr *&replExpr : replValues) {
3097 ast::Type replType = replExpr->getType();
3098
3099 // Check that replExpr is an Operation, Value, or ValueRange.
3100 if (isa<ast::OperationType>(replType)) {
3101 if (shouldConvertOpToValues)
3102 replExpr = convertOpToValue(replExpr);
3103 continue;
3104 }
3105
3106 if (replType != valueTy && replType != valueRangeTy) {
3107 return emitError(replExpr->getLoc(),
3108 llvm::formatv("expected `Op`, `Value` or `ValueRange` "
3109 "expression, but got `{0}`",
3110 replType));
3111 }
3112 }
3113
3114 return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
3115}
3116
3117FailureOr<ast::RewriteStmt *>
3118Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
3119 ast::CompoundStmt *rewriteBody) {
3120 // Check that root is an Operation.
3121 ast::Type rootType = rootOp->getType();
3122 if (!isa<ast::OperationType>(rootType)) {
3123 return emitError(
3124 rootOp->getLoc(),
3125 llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3126 }
3127
3128 return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
3129}
3130
3131//===----------------------------------------------------------------------===//
3132// Code Completion
3133//===----------------------------------------------------------------------===//
3134
3135LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3136 ast::Type parentType = parentExpr->getType();
3137 if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType))
3138 codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3139 else if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(parentType))
3140 codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3141 return failure();
3142}
3143
3144LogicalResult
3145Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3146 if (opName)
3147 codeCompleteContext->codeCompleteOperationAttributeName(*opName);
3148 return failure();
3149}
3150
3151LogicalResult
3152Parser::codeCompleteConstraintName(ast::Type inferredType,
3153 bool allowInlineTypeConstraints) {
3154 codeCompleteContext->codeCompleteConstraintName(
3155 inferredType, allowInlineTypeConstraints, curDeclScope);
3156 return failure();
3157}
3158
3159LogicalResult Parser::codeCompleteDialectName() {
3160 codeCompleteContext->codeCompleteDialectName();
3161 return failure();
3162}
3163
3164LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3165 codeCompleteContext->codeCompleteOperationName(dialectName);
3166 return failure();
3167}
3168
3169LogicalResult Parser::codeCompletePatternMetadata() {
3170 codeCompleteContext->codeCompletePatternMetadata();
3171 return failure();
3172}
3173
3174LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3175 codeCompleteContext->codeCompleteIncludeFilename(curPath);
3176 return failure();
3177}
3178
3179void Parser::codeCompleteCallSignature(ast::Node *parent,
3180 unsigned currentNumArgs) {
3181 ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent);
3182 if (!callableDecl)
3183 return;
3184
3185 codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
3186}
3187
3188void Parser::codeCompleteOperationOperandsSignature(
3189 std::optional<StringRef> opName, unsigned currentNumOperands) {
3190 codeCompleteContext->codeCompleteOperationOperandsSignature(
3191 opName, currentNumOperands);
3192}
3193
3194void Parser::codeCompleteOperationResultsSignature(
3195 std::optional<StringRef> opName, unsigned currentNumResults) {
3196 codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3197 currentNumResults);
3198}
3199
3200//===----------------------------------------------------------------------===//
3201// Parser
3202//===----------------------------------------------------------------------===//
3203
3204FailureOr<ast::Module *>
3205mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3206 bool enableDocumentation,
3207 CodeCompleteContext *codeCompleteContext) {
3208 Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3209 return parser.parseModule();
3210}
return success()
if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming copies and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method ...
static std::string diag(const llvm::Value &value)
Token lexToken()
Definition Lexer.cpp:85
const llvm::SourceMgr & getSourceMgr()
Definition Lexer.h:28
bool isKeyword() const
Return true if this is one of the keyword token kinds (e.g. kw_if).
Definition Token.cpp:192
SMLoc getLoc() const
Definition Token.cpp:24
bool is(Kind k) const
Definition Token.h:38
std::string getStringValue() const
Given a token containing a string literal, return its value, including removing the quote characters ...
Definition Token.cpp:86
bool isAny(Kind k1, Kind k2) const
Definition Token.h:40
Kind getKind() const
Definition Token.h:37
SMLoc getEndLoc() const
Definition Token.cpp:26
bool isNot(Kind k) const
Definition Token.h:50
StringRef getSpelling() const
Definition Token.h:34
This class provides an abstract interface into the parser for hooking in code completion events.
virtual void codeCompleteConstraintName(ast::Type currentType, bool allowInlineTypeConstraints, const ast::DeclScope *scope)
Signal code completion for a constraint name with an optional decl scope.
virtual void codeCompleteOperationAttributeName(StringRef opName)
Signal code completion for an attribute name of the given operation.
virtual void codeCompleteOperationOperandsSignature(std::optional< StringRef > opName, unsigned currentNumOperands)
Signal code completion for the signature of an operation's operands.
virtual void codeCompleteOperationName(StringRef dialectName)
Signal code completion for an operation name in the given dialect.
virtual void codeCompleteOperationResultsSignature(std::optional< StringRef > opName, unsigned currentNumResults)
Signal code completion for the signature of an operation's results.
virtual void codeCompleteDialectName()
Signal code completion for a dialect name.
virtual void codeCompleteOperationMemberAccess(ast::OperationType opType)
Signal code completion for a member access into the given operation type.
virtual void codeCompleteTupleMemberAccess(ast::TupleType tupleType)
Signal code completion for a member access into the given tuple type.
virtual void codeCompletePatternMetadata()
Signal code completion for Pattern metadata.
virtual void codeCompleteCallSignature(const ast::CallableDecl *callable, unsigned currentNumArgs)
Signal code completion for the signature of a callable.
virtual void codeCompleteIncludeFilename(StringRef curPath)
Signal code completion for an include filename.
@ code_complete_string
Token signifying a code completion location within a string.
Definition Lexer.h:41
@ directive
Tokens.
Definition Lexer.h:93
@ eof
Markers.
Definition Lexer.h:36
@ code_complete
Token signifying a code completion location.
Definition Lexer.h:39
@ arrow
Punctuation.
Definition Lexer.h:74
@ less
Paired punctuation.
Definition Lexer.h:82
@ kw_Attr
General keywords.
Definition Lexer.h:54
static StringRef getMemberName()
Return the member name used for the "all-results" access.
Definition Nodes.h:487
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
Definition Nodes.h:489
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition Nodes.cpp:385
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition Nodes.cpp:259
static AttributeType get(Context &context)
Return an instance of the Attribute type.
Definition Types.cpp:57
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr * > arguments, Type resultType, bool isNegated=false)
Definition Nodes.cpp:269
Type getResultType() const
Return the result type of this decl.
Definition Nodes.h:1212
StringRef getCallableType() const
Return the callable type of this decl.
Definition Nodes.h:1197
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
Definition Nodes.h:1205
ArrayRef< Stmt * >::iterator end() const
Definition Nodes.h:192
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
Definition Nodes.h:185
ArrayRef< Stmt * >::iterator begin() const
Definition Nodes.h:191
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt * > children)
Definition Nodes.cpp:192
static ConstraintType get(Context &context)
Return an instance of the Constraint type.
Definition Types.cpp:65
This class represents the main context of the PDLL AST.
Definition Context.h:25
DiagnosticEngine & getDiagEngine()
Return the diagnostic engine of this context.
Definition Context.h:42
ods::Context & getODSContext()
Return the ODS context used by the AST.
Definition Context.h:38
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
Definition Nodes.cpp:285
This class represents a scope for named AST decls.
Definition Nodes.h:64
Decl * lookup(StringRef name)
Lookup a decl with the given name starting from this scope.
Definition Nodes.cpp:182
void add(Decl *decl)
Add a new decl to the scope.
Definition Nodes.cpp:175
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
Definition Nodes.h:70
void setDocComment(Context &ctx, StringRef comment)
Set the documentation comment for this decl.
Definition Nodes.cpp:377
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition Nodes.h:672
InFlightDiagnostic emitWarning(SMRange loc, const Twine &msg)
Definition Diagnostic.h:149
InFlightDiagnostic emitError(SMRange loc, const Twine &msg)
Emit an error to the diagnostic engine.
Definition Diagnostic.h:145
This class provides a simple implementation of a PDLL diagnostic.
Definition Diagnostic.h:29
Diagnostic & attachNote(const Twine &msg, std::optional< SMRange > noteLoc=std::nullopt)
Attach a note to this diagnostic.
Definition Diagnostic.h:46
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
Definition Nodes.cpp:218
This class represents a base AST Expression node.
Definition Nodes.h:348
Type getType() const
Return the type of this expression.
Definition Nodes.h:351
This class represents a diagnostic that is inflight and set to be reported.
Definition Diagnostic.h:83
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
Definition Nodes.cpp:206
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
Definition Nodes.cpp:295
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl * > children)
Definition Nodes.cpp:566
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
Definition Nodes.cpp:492
SMRange getLoc() const
Return the location of this node.
Definition Nodes.h:131
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
Definition Nodes.cpp:395
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
Definition Nodes.h:1028
static OpNameDecl * create(Context &ctx, const Name &name)
Definition Nodes.cpp:502
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr * > operands, ArrayRef< Expr * > resultTypes, ArrayRef< NamedAttributeDecl * > attributes)
Definition Nodes.cpp:307
This class represents a PDLL type that corresponds to an mlir::Operation.
Definition Types.h:134
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
Definition Types.cpp:87
static OperationType get(Context &context, std::optional< StringRef > name=std::nullopt, const ods::Operation *odsOp=nullptr)
Return an instance of the Operation type with an optional operation name.
Definition Types.cpp:73
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
Definition Nodes.cpp:513
static RangeExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, RangeType type)
Definition Nodes.cpp:335
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr * > replExprs)
Definition Nodes.cpp:226
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
Definition Nodes.cpp:250
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
Definition Nodes.cpp:240
static RewriteType get(Context &context)
Return an instance of the Rewrite type.
Definition Types.cpp:136
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, ArrayRef< StringRef > elementNames)
Definition Nodes.cpp:349
This class represents a PDLL tuple type, i.e. an ordered set of element types with optional names.
Definition Types.h:222
size_t size() const
Return the number of elements within this tuple.
Definition Types.h:238
ArrayRef< Type > getElementTypes() const
Return the element types of this tuple.
Definition Types.cpp:155
static TupleType get(Context &context, ArrayRef< Type > elementTypes, ArrayRef< StringRef > elementNames)
Return an instance of the Tuple type.
Definition Types.cpp:144
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
Definition Nodes.cpp:412
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition Nodes.cpp:368
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
Definition Nodes.cpp:421
static TypeRangeType get(Context &context)
Return an instance of the TypeRange type.
Definition Types.cpp:113
static TypeType get(Context &context)
Return an instance of the Type type.
Definition Types.cpp:167
Type refineWith(Type other) const
Try to refine this type with the one provided.
Definition Types.cpp:33
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
Definition Nodes.h:892
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
Definition Nodes.cpp:431
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition Nodes.cpp:442
static ValueRangeType get(Context &context)
Return an instance of the ValueRange type.
Definition Types.cpp:127
static ValueType get(Context &context)
Return an instance of the Value type.
Definition Types.cpp:175
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
Definition Nodes.cpp:549
std::pair< Operation *, bool > insertOperation(StringRef name, StringRef summary, StringRef desc, StringRef nativeClassName, bool supportsResultTypeInferrence, SMLoc loc)
Insert a new operation with the context.
Definition Context.cpp:63
const TypeConstraint & insertTypeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new type constraint with the context.
Definition Context.cpp:41
const AttributeConstraint & insertAttributeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new attribute constraint with the context.
Definition Context.cpp:27
const Operation * lookupOperation(StringRef name) const
Lookup an operation registered with the given name, or null if no operation with that name is registered.
Definition Context.cpp:72
This class provides an ODS representation of a specific operation.
Definition Operation.h:125
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
Definition Operation.h:165
SMRange getLoc() const
Return the source location of this operation.
Definition Operation.h:128
bool hasResultTypeInferrence() const
Return true if the operation is known to support result type inference.
Definition Operation.h:171
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
Definition Operation.h:168
StringRef getSummary() const
std::string getUniqueDefName() const
Returns a unique name for the TableGen def of this constraint.
StringRef getDescription() const
std::string getConditionTemplate() const
FmtContext & withSelf(Twine subst)
Definition Format.cpp:41
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
Definition Parser.cpp:3205
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals) -> FmtObject< decltype(std::make_tuple(llvm::support::detail::build_format_adapter(std::forward< Ts >(vals))...))>
Formats text by substituting placeholders in format string with replacement parameters.
Definition Format.h:262
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
const ConstraintDecl * constraint
Definition Nodes.h:722
StringRef getName() const
Return the raw string name.
Definition Nodes.h:41
SMRange getLoc() const
Get the location of this name.
Definition Nodes.h:44
static const Name & create(Context &ctx, StringRef name, SMRange location)
Definition Nodes.cpp:33