MLIR  17.0.0git
Parser.cpp
Go to the documentation of this file.
1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "Lexer.h"
13 #include "mlir/TableGen/Argument.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/Operator.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/ManagedStatic.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/ScopedPrinter.h"
32 #include "llvm/TableGen/Error.h"
33 #include "llvm/TableGen/Parser.h"
34 #include <string>
35 #include <optional>
36 
37 using namespace mlir;
38 using namespace mlir::pdll;
39 
40 //===----------------------------------------------------------------------===//
41 // Parser
42 //===----------------------------------------------------------------------===//
43 
44 namespace {
45 class Parser {
46 public:
47  Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
48  bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
49  : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
50  curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
51  typeTy(ast::TypeType::get(ctx)), valueTy(ast::ValueType::get(ctx)),
52  typeRangeTy(ast::TypeRangeType::get(ctx)),
53  valueRangeTy(ast::ValueRangeType::get(ctx)),
54  attrTy(ast::AttributeType::get(ctx)),
55  codeCompleteContext(codeCompleteContext) {}
56 
57  /// Try to parse a new module. Returns nullptr in the case of failure.
58  FailureOr<ast::Module *> parseModule();
59 
60 private:
61  /// The current context of the parser. It allows for the parser to know a bit
62  /// about the construct it is nested within during parsing. This is used
63  /// specifically to provide additional verification during parsing, e.g. to
64  /// prevent using rewrites within a match context, matcher constraints within
65  /// a rewrite section, etc.
66  enum class ParserContext {
67  /// The parser is in the global context.
68  Global,
69  /// The parser is currently within a Constraint, which disallows all types
70  /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
71  Constraint,
72  /// The parser is currently within the matcher portion of a Pattern, which
73  /// is allows a terminal operation rewrite statement but no other rewrite
74  /// transformations.
75  PatternMatch,
76  /// The parser is currently within a Rewrite, which disallows calls to
77  /// constraints, requires operation expressions to have names, etc.
78  Rewrite,
79  };
80 
81  /// The current specification context of an operations result type. This
82  /// indicates how the result types of an operation may be inferred.
83  enum class OpResultTypeContext {
84  /// The result types of the operation are not known to be inferred.
85  Explicit,
86  /// The result types of the operation are inferred from the root input of a
87  /// `replace` statement.
88  Replacement,
89  /// The result types of the operation are inferred by using the
90  /// `InferTypeOpInterface` interface provided by the operation.
91  Interface,
92  };
93 
94  //===--------------------------------------------------------------------===//
95  // Parsing
96  //===--------------------------------------------------------------------===//
97 
98  /// Push a new decl scope onto the lexer.
99  ast::DeclScope *pushDeclScope() {
100  ast::DeclScope *newScope =
101  new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
102  return (curDeclScope = newScope);
103  }
104  void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
105 
106  /// Pop the last decl scope from the lexer.
107  void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
108 
109  /// Parse the body of an AST module.
110  LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
111 
112  /// Try to convert the given expression to `type`. Returns failure and emits
113  /// an error if a conversion is not viable. On failure, `noteAttachFn` is
114  /// invoked to attach notes to the emitted error diagnostic. On success,
115  /// `expr` is updated to the expression used to convert to `type`.
116  LogicalResult convertExpressionTo(
117  ast::Expr *&expr, ast::Type type,
118  function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
120  convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType,
121  ast::Type type,
122  function_ref<ast::InFlightDiagnostic()> emitErrorFn);
123  LogicalResult convertTupleExpressionTo(
124  ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
125  function_ref<ast::InFlightDiagnostic()> emitErrorFn,
126  function_ref<void(ast::Diagnostic &diag)> noteAttachFn);
127 
128  /// Given an operation expression, convert it to a Value or ValueRange
129  /// typed expression.
130  ast::Expr *convertOpToValue(const ast::Expr *opExpr);
131 
132  /// Lookup ODS information for the given operation, returns nullptr if no
133  /// information is found.
134  const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
135  return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
136  }
137 
138  /// Process the given documentation string, or return an empty string if
139  /// documentation isn't enabled.
140  StringRef processDoc(StringRef doc) {
141  return enableDocumentation ? doc : StringRef();
142  }
143 
144  /// Process the given documentation string and format it, or return an empty
145  /// string if documentation isn't enabled.
146  std::string processAndFormatDoc(const Twine &doc) {
147  if (!enableDocumentation)
148  return "";
149  std::string docStr;
150  {
151  llvm::raw_string_ostream docOS(docStr);
152  std::string tmpDocStr = doc.str();
154  StringRef(tmpDocStr).rtrim(" \t"));
155  }
156  return docStr;
157  }
158 
159  //===--------------------------------------------------------------------===//
160  // Directives
161 
162  LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
163  LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
164  LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
166 
167  /// Process the records of a parsed tablegen include file.
168  void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
170 
171  /// Create a user defined native constraint for a constraint imported from
172  /// ODS.
173  template <typename ConstraintT>
174  ast::Decl *
175  createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
176  SMRange loc, ast::Type type,
177  StringRef nativeType, StringRef docString);
178  template <typename ConstraintT>
179  ast::Decl *
180  createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
181  SMRange loc, ast::Type type,
182  StringRef nativeType);
183 
184  //===--------------------------------------------------------------------===//
185  // Decls
186 
187  /// This structure contains the set of pattern metadata that may be parsed.
188  struct ParsedPatternMetadata {
189  std::optional<uint16_t> benefit;
190  bool hasBoundedRecursion = false;
191  };
192 
193  FailureOr<ast::Decl *> parseTopLevelDecl();
195  parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
196 
197  /// Parse an argument variable as part of the signature of a
198  /// UserConstraintDecl or UserRewriteDecl.
199  FailureOr<ast::VariableDecl *> parseArgumentDecl();
200 
201  /// Parse a result variable as part of the signature of a UserConstraintDecl
202  /// or UserRewriteDecl.
203  FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
204 
205  /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
206  /// defined in a non-global context.
208  parseUserConstraintDecl(bool isInline = false);
209 
210  /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
211  /// non-global context, such as within a Pattern/Constraint/etc.
212  FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
213 
214  /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
215  /// PDLL constructs.
216  FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
217  const ast::Name &name, bool isInline,
218  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
219  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
220 
221  /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being
222  /// defined in a non-global context.
223  FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
224 
225  /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
226  /// non-global context, such as within a Pattern/Rewrite/etc.
227  FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
228 
229  /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
230  /// PDLL constructs.
231  FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
232  const ast::Name &name, bool isInline,
233  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
234  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
235 
236  /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
237  /// effectively the same syntax, and only differ on slight semantics (given
238  /// the different parsing contexts).
239  template <typename T, typename ParseUserPDLLDeclFnT>
240  FailureOr<T *> parseUserConstraintOrRewriteDecl(
241  ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
242  StringRef anonymousNamePrefix, bool isInline);
243 
244  /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
245  /// These decls have effectively the same syntax.
246  template <typename T>
247  FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
248  const ast::Name &name, bool isInline,
250  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
251 
252  /// Parse the functional signature (i.e. the arguments and results) of a
253  /// UserConstraintDecl or UserRewriteDecl.
254  LogicalResult parseUserConstraintOrRewriteSignature(
257  ast::DeclScope *&argumentScope, ast::Type &resultType);
258 
259  /// Validate the return (which if present is specified by bodyIt) of a
260  /// UserConstraintDecl or UserRewriteDecl.
261  LogicalResult validateUserConstraintOrRewriteReturn(
262  StringRef declType, ast::CompoundStmt *body,
265  ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
266 
268  parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
269  bool expectTerminalSemicolon = true);
270  FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
271  FailureOr<ast::Decl *> parsePatternDecl();
272  LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
273 
274  /// Check to see if a decl has already been defined with the given name, if
275  /// one has emit and error and return failure. Returns success otherwise.
276  LogicalResult checkDefineNamedDecl(const ast::Name &name);
277 
278  /// Try to define a variable decl with the given components, returns the
279  /// variable on success.
281  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
282  ast::Expr *initExpr,
283  ArrayRef<ast::ConstraintRef> constraints);
285  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
286  ArrayRef<ast::ConstraintRef> constraints);
287 
288  /// Parse the constraint reference list for a variable decl.
289  LogicalResult parseVariableDeclConstraintList(
291 
292  /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
293  FailureOr<ast::Expr *> parseTypeConstraintExpr();
294 
295  /// Try to parse a single reference to a constraint. `typeConstraint` is the
296  /// location of a previously parsed type constraint for the entity that will
297  /// be constrained by the parsed constraint. `existingConstraints` are any
298  /// existing constraints that have already been parsed for the same entity
299  /// that will be constrained by this constraint. `allowInlineTypeConstraints`
300  /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
302  parseConstraint(std::optional<SMRange> &typeConstraint,
303  ArrayRef<ast::ConstraintRef> existingConstraints,
304  bool allowInlineTypeConstraints);
305 
306  /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
307  /// argument or result variable. The constraints for these variables do not
308  /// allow inline type constraints, and only permit a single constraint.
309  FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
310 
311  //===--------------------------------------------------------------------===//
312  // Exprs
313 
314  FailureOr<ast::Expr *> parseExpr();
315 
316  /// Identifier expressions.
317  FailureOr<ast::Expr *> parseAttributeExpr();
318  FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
319  FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
320  FailureOr<ast::Expr *> parseIdentifierExpr();
321  FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
322  FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
323  FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
324  FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
325  FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
327  parseOperationExpr(OpResultTypeContext inputResultTypeContext =
328  OpResultTypeContext::Explicit);
329  FailureOr<ast::Expr *> parseTupleExpr();
330  FailureOr<ast::Expr *> parseTypeExpr();
331  FailureOr<ast::Expr *> parseUnderscoreExpr();
332 
333  //===--------------------------------------------------------------------===//
334  // Stmts
335 
336  FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
337  FailureOr<ast::CompoundStmt *> parseCompoundStmt();
338  FailureOr<ast::EraseStmt *> parseEraseStmt();
339  FailureOr<ast::LetStmt *> parseLetStmt();
340  FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
341  FailureOr<ast::ReturnStmt *> parseReturnStmt();
342  FailureOr<ast::RewriteStmt *> parseRewriteStmt();
343 
344  //===--------------------------------------------------------------------===//
345  // Creation+Analysis
346  //===--------------------------------------------------------------------===//
347 
348  //===--------------------------------------------------------------------===//
349  // Decls
350 
351  /// Try to extract a callable from the given AST node. Returns nullptr on
352  /// failure.
353  ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
354 
355  /// Try to create a pattern decl with the given components, returning the
356  /// Pattern on success.
358  createPatternDecl(SMRange loc, const ast::Name *name,
359  const ParsedPatternMetadata &metadata,
360  ast::CompoundStmt *body);
361 
362  /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
363  /// of results, defined as part of the signature.
364  ast::Type
365  createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
366 
367  /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
368  template <typename T>
369  FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
370  const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
371  ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
372  ast::CompoundStmt *body);
373 
374  /// Try to create a variable decl with the given components, returning the
375  /// Variable on success.
377  createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
378  ArrayRef<ast::ConstraintRef> constraints);
379 
380  /// Create a variable for an argument or result defined as part of the
381  /// signature of a UserConstraintDecl/UserRewriteDecl.
383  createArgOrResultVariableDecl(StringRef name, SMRange loc,
384  const ast::ConstraintRef &constraint);
385 
386  /// Validate the constraints used to constraint a variable decl.
387  /// `inferredType` is the type of the variable inferred by the constraints
388  /// within the list, and is updated to the most refined type as determined by
389  /// the constraints. Returns success if the constraint list is valid, failure
390  /// otherwise.
392  validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
393  ast::Type &inferredType);
394  /// Validate a single reference to a constraint. `inferredType` contains the
395  /// currently inferred variabled type and is refined within the type defined
396  /// by the constraint. Returns success if the constraint is valid, failure
397  /// otherwise.
398  LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
399  ast::Type &inferredType);
400  LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
401  LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
402 
403  //===--------------------------------------------------------------------===//
404  // Exprs
405 
407  createCallExpr(SMRange loc, ast::Expr *parentExpr,
408  MutableArrayRef<ast::Expr *> arguments);
409  FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
411  createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
412  ArrayRef<ast::ConstraintRef> constraints);
414  createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
415 
416  /// Validate the member access `name` into the given parent expression. On
417  /// success, this also returns the type of the member accessed.
418  FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
419  StringRef name, SMRange loc);
421  createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
422  OpResultTypeContext resultTypeContext,
427  validateOperationOperands(SMRange loc, std::optional<StringRef> name,
428  const ods::Operation *odsOp,
429  SmallVectorImpl<ast::Expr *> &operands);
430  LogicalResult validateOperationResults(SMRange loc,
431  std::optional<StringRef> name,
432  const ods::Operation *odsOp,
434  void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
435  const ods::Operation *odsOp);
436  LogicalResult validateOperationOperandsOrResults(
437  StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
438  std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
439  ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
440  ast::RangeType rangeTy);
441  FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
442  ArrayRef<ast::Expr *> elements,
443  ArrayRef<StringRef> elementNames);
444 
445  //===--------------------------------------------------------------------===//
446  // Stmts
447 
448  FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
450  createReplaceStmt(SMRange loc, ast::Expr *rootOp,
451  MutableArrayRef<ast::Expr *> replValues);
453  createRewriteStmt(SMRange loc, ast::Expr *rootOp,
454  ast::CompoundStmt *rewriteBody);
455 
456  //===--------------------------------------------------------------------===//
457  // Code Completion
458  //===--------------------------------------------------------------------===//
459 
460  /// The set of various code completion methods. Every completion method
461  /// returns `failure` to stop the parsing process after providing completion
462  /// results.
463 
464  LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
465  LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
466  LogicalResult codeCompleteConstraintName(ast::Type inferredType,
467  bool allowInlineTypeConstraints);
468  LogicalResult codeCompleteDialectName();
469  LogicalResult codeCompleteOperationName(StringRef dialectName);
470  LogicalResult codeCompletePatternMetadata();
471  LogicalResult codeCompleteIncludeFilename(StringRef curPath);
472 
473  void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
474  void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
475  unsigned currentNumOperands);
476  void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
477  unsigned currentNumResults);
478 
479  //===--------------------------------------------------------------------===//
480  // Lexer Utilities
481  //===--------------------------------------------------------------------===//
482 
483  /// If the current token has the specified kind, consume it and return true.
484  /// If not, return false.
485  bool consumeIf(Token::Kind kind) {
486  if (curToken.isNot(kind))
487  return false;
488  consumeToken(kind);
489  return true;
490  }
491 
492  /// Advance the current lexer onto the next token.
493  void consumeToken() {
494  assert(curToken.isNot(Token::eof, Token::error) &&
495  "shouldn't advance past EOF or errors");
496  curToken = lexer.lexToken();
497  }
498 
499  /// Advance the current lexer onto the next token, asserting what the expected
500  /// current token is. This is preferred to the above method because it leads
501  /// to more self-documenting code with better checking.
502  void consumeToken(Token::Kind kind) {
503  assert(curToken.is(kind) && "consumed an unexpected token");
504  consumeToken();
505  }
506 
507  /// Reset the lexer to the location at the given position.
508  void resetToken(SMRange tokLoc) {
509  lexer.resetPointer(tokLoc.Start.getPointer());
510  curToken = lexer.lexToken();
511  }
512 
513  /// Consume the specified token if present and return success. On failure,
514  /// output a diagnostic and return failure.
515  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
516  if (curToken.getKind() != kind)
517  return emitError(curToken.getLoc(), msg);
518  consumeToken();
519  return success();
520  }
521  LogicalResult emitError(SMRange loc, const Twine &msg) {
522  lexer.emitError(loc, msg);
523  return failure();
524  }
525  LogicalResult emitError(const Twine &msg) {
526  return emitError(curToken.getLoc(), msg);
527  }
528  LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
529  const Twine &note) {
530  lexer.emitErrorAndNote(loc, msg, noteLoc, note);
531  return failure();
532  }
533 
534  //===--------------------------------------------------------------------===//
535  // Fields
536  //===--------------------------------------------------------------------===//
537 
538  /// The owning AST context.
539  ast::Context &ctx;
540 
541  /// The lexer of this parser.
542  Lexer lexer;
543 
544  /// The current token within the lexer.
545  Token curToken;
546 
547  /// A flag indicating if the parser should add documentation to AST nodes when
548  /// viable.
549  bool enableDocumentation;
550 
551  /// The most recently defined decl scope.
552  ast::DeclScope *curDeclScope = nullptr;
553  llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
554 
555  /// The current context of the parser.
556  ParserContext parserContext = ParserContext::Global;
557 
558  /// Cached types to simplify verification and expression creation.
559  ast::Type typeTy, valueTy;
560  ast::RangeType typeRangeTy, valueRangeTy;
561  ast::Type attrTy;
562 
563  /// A counter used when naming anonymous constraints and rewrites.
564  unsigned anonymousDeclNameCounter = 0;
565 
566  /// The optional code completion context.
567  CodeCompleteContext *codeCompleteContext;
568 };
569 } // namespace
570 
571 FailureOr<ast::Module *> Parser::parseModule() {
572  SMLoc moduleLoc = curToken.getStartLoc();
573  pushDeclScope();
574 
575  // Parse the top-level decls of the module.
577  if (failed(parseModuleBody(decls)))
578  return popDeclScope(), failure();
579 
580  popDeclScope();
581  return ast::Module::create(ctx, moduleLoc, decls);
582 }
583 
584 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
585  while (curToken.isNot(Token::eof)) {
586  if (curToken.is(Token::directive)) {
587  if (failed(parseDirective(decls)))
588  return failure();
589  continue;
590  }
591 
592  FailureOr<ast::Decl *> decl = parseTopLevelDecl();
593  if (failed(decl))
594  return failure();
595  decls.push_back(*decl);
596  }
597  return success();
598 }
599 
600 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
601  return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
602  valueRangeTy);
603 }
604 
605 LogicalResult Parser::convertExpressionTo(
606  ast::Expr *&expr, ast::Type type,
607  function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
608  ast::Type exprType = expr->getType();
609  if (exprType == type)
610  return success();
611 
612  auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
613  ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
614  expr->getLoc(), llvm::formatv("unable to convert expression of type "
615  "`{0}` to the expected type of "
616  "`{1}`",
617  exprType, type));
618  if (noteAttachFn)
619  noteAttachFn(*diag);
620  return diag;
621  };
622 
623  if (auto exprOpType = exprType.dyn_cast<ast::OperationType>())
624  return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);
625 
626  // FIXME: Decide how to allow/support converting a single result to multiple,
627  // and multiple to a single result. For now, we just allow Single->Range,
628  // but this isn't something really supported in the PDL dialect. We should
629  // figure out some way to support both.
630  if ((exprType == valueTy || exprType == valueRangeTy) &&
631  (type == valueTy || type == valueRangeTy))
632  return success();
633  if ((exprType == typeTy || exprType == typeRangeTy) &&
634  (type == typeTy || type == typeRangeTy))
635  return success();
636 
637  // Handle tuple types.
638  if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>())
639  return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
640  noteAttachFn);
641 
642  return emitConvertError();
643 }
644 
645 LogicalResult Parser::convertOpExpressionTo(
646  ast::Expr *&expr, ast::OperationType exprType, ast::Type type,
647  function_ref<ast::InFlightDiagnostic()> emitErrorFn) {
648  // Two operation types are compatible if they have the same name, or if the
649  // expected type is more general.
650  if (auto opType = type.dyn_cast<ast::OperationType>()) {
651  if (opType.getName())
652  return emitErrorFn();
653  return success();
654  }
655 
656  // An operation can always convert to a ValueRange.
657  if (type == valueRangeTy) {
658  expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
659  valueRangeTy);
660  return success();
661  }
662 
663  // Allow conversion to a single value by constraining the result range.
664  if (type == valueTy) {
665  // If the operation is registered, we can verify if it can ever have a
666  // single result.
667  if (const ods::Operation *odsOp = exprType.getODSOperation()) {
668  if (odsOp->getResults().empty()) {
669  return emitErrorFn()->attachNote(
670  llvm::formatv("see the definition of `{0}`, which was defined "
671  "with zero results",
672  odsOp->getName()),
673  odsOp->getLoc());
674  }
675 
676  unsigned numSingleResults = llvm::count_if(
677  odsOp->getResults(), [](const ods::OperandOrResult &result) {
678  return result.getVariableLengthKind() ==
679  ods::VariableLengthKind::Single;
680  });
681  if (numSingleResults > 1) {
682  return emitErrorFn()->attachNote(
683  llvm::formatv("see the definition of `{0}`, which was defined "
684  "with at least {1} results",
685  odsOp->getName(), numSingleResults),
686  odsOp->getLoc());
687  }
688  }
689 
690  expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
691  valueTy);
692  return success();
693  }
694  return emitErrorFn();
695 }
696 
697 LogicalResult Parser::convertTupleExpressionTo(
698  ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
699  function_ref<ast::InFlightDiagnostic()> emitErrorFn,
700  function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
701  // Handle conversions between tuples.
702  if (auto tupleType = type.dyn_cast<ast::TupleType>()) {
703  if (tupleType.size() != exprType.size())
704  return emitErrorFn();
705 
706  // Build a new tuple expression using each of the elements of the current
707  // tuple.
708  SmallVector<ast::Expr *> newExprs;
709  for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
710  newExprs.push_back(ast::MemberAccessExpr::create(
711  ctx, expr->getLoc(), expr, llvm::to_string(i),
712  exprType.getElementTypes()[i]));
713 
714  auto diagFn = [&](ast::Diagnostic &diag) {
715  diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
716  i, exprType));
717  if (noteAttachFn)
718  noteAttachFn(diag);
719  };
720  if (failed(convertExpressionTo(newExprs.back(),
721  tupleType.getElementTypes()[i], diagFn)))
722  return failure();
723  }
724  expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
725  tupleType.getElementNames());
726  return success();
727  }
728 
729  // Handle conversion to a range.
730  auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes,
731  ast::RangeType resultTy) -> LogicalResult {
732  // TODO: We currently only allow range conversion within a rewrite context.
733  if (parserContext != ParserContext::Rewrite) {
734  return emitErrorFn()->attachNote("Tuple to Range conversion is currently "
735  "only allowed within a rewrite context");
736  }
737 
738  // All of the tuple elements must be allowed types.
739  for (ast::Type elementType : exprType.getElementTypes())
740  if (!llvm::is_contained(allowedElementTypes, elementType))
741  return emitErrorFn();
742 
743  // Build a new tuple expression using each of the elements of the current
744  // tuple.
745  SmallVector<ast::Expr *> newExprs;
746  for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
747  newExprs.push_back(ast::MemberAccessExpr::create(
748  ctx, expr->getLoc(), expr, llvm::to_string(i),
749  exprType.getElementTypes()[i]));
750  }
751  expr = ast::RangeExpr::create(ctx, expr->getLoc(), newExprs, resultTy);
752  return success();
753  };
754  if (type == valueRangeTy)
755  return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
756  if (type == typeRangeTy)
757  return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
758 
759  return emitErrorFn();
760 }
761 
762 //===----------------------------------------------------------------------===//
763 // Directives
764 
765 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
766  StringRef directive = curToken.getSpelling();
767  if (directive == "#include")
768  return parseInclude(decls);
769 
770  return emitError("unknown directive `" + directive + "`");
771 }
772 
773 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
774  SMRange loc = curToken.getLoc();
775  consumeToken(Token::directive);
776 
777  // Handle code completion of the include file path.
778  if (curToken.is(Token::code_complete_string))
779  return codeCompleteIncludeFilename(curToken.getStringValue());
780 
781  // Parse the file being included.
782  if (!curToken.isString())
783  return emitError(loc,
784  "expected string file name after `include` directive");
785  SMRange fileLoc = curToken.getLoc();
786  std::string filenameStr = curToken.getStringValue();
787  StringRef filename = filenameStr;
788  consumeToken();
789 
790  // Check the type of include. If ending with `.pdll`, this is another pdl file
791  // to be parsed along with the current module.
792  if (filename.endswith(".pdll")) {
793  if (failed(lexer.pushInclude(filename, fileLoc)))
794  return emitError(fileLoc,
795  "unable to open include file `" + filename + "`");
796 
797  // If we added the include successfully, parse it into the current module.
798  // Make sure to update to the next token after we finish parsing the nested
799  // file.
800  curToken = lexer.lexToken();
801  LogicalResult result = parseModuleBody(decls);
802  curToken = lexer.lexToken();
803  return result;
804  }
805 
806  // Otherwise, this must be a `.td` include.
807  if (filename.endswith(".td"))
808  return parseTdInclude(filename, fileLoc, decls);
809 
810  return emitError(fileLoc,
811  "expected include filename to end with `.pdll` or `.td`");
812 }
813 
814 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
816  llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
817 
818  // Use the source manager to open the file, but don't yet add it.
819  std::string includedFile;
820  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
821  parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
822  if (!includeBuffer)
823  return emitError(fileLoc, "unable to open include file `" + filename + "`");
824 
825  // Setup the source manager for parsing the tablegen file.
826  llvm::SourceMgr tdSrcMgr;
827  tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
828  tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
829 
830  // This class provides a context argument for the llvm::SourceMgr diagnostic
831  // handler.
832  struct DiagHandlerContext {
833  Parser &parser;
834  StringRef filename;
835  llvm::SMRange loc;
836  } handlerContext{*this, filename, fileLoc};
837 
838  // Set the diagnostic handler for the tablegen source manager.
839  tdSrcMgr.setDiagHandler(
840  [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
841  auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
842  (void)ctx->parser.emitError(
843  ctx->loc,
844  llvm::formatv("error while processing include file `{0}`: {1}",
845  ctx->filename, diag.getMessage()));
846  },
847  &handlerContext);
848 
849  // Parse the tablegen file.
850  llvm::RecordKeeper tdRecords;
851  if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
852  return failure();
853 
854  // Process the parsed records.
855  processTdIncludeRecords(tdRecords, decls);
856 
857  // After we are done processing, move all of the tablegen source buffers to
858  // the main parser source mgr. This allows for directly using source locations
859  // from the .td files without needing to remap them.
860  parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
861  return success();
862 }
863 
864 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
866  // Return the length kind of the given value.
867  auto getLengthKind = [](const auto &value) {
868  if (value.isOptional())
870  return value.isVariadic() ? ods::VariableLengthKind::Variadic
872  };
873 
874  // Insert a type constraint into the ODS context.
875  ods::Context &odsContext = ctx.getODSContext();
876  auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
877  -> const ods::TypeConstraint & {
878  return odsContext.insertTypeConstraint(
879  cst.constraint.getUniqueDefName(),
880  processDoc(cst.constraint.getSummary()),
881  cst.constraint.getCPPClassName());
882  };
883  auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
884  return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
885  };
886 
887  // Process the parsed tablegen records to build ODS information.
888  /// Operations.
889  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
890  tblgen::Operator op(def);
891 
892  // Check to see if this operation is known to support type inferrence.
893  bool supportsResultTypeInferrence =
894  op.getTrait("::mlir::InferTypeOpInterface::Trait");
895 
896  auto [odsOp, inserted] = odsContext.insertOperation(
897  op.getOperationName(), processDoc(op.getSummary()),
898  processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
899  supportsResultTypeInferrence, op.getLoc().front());
900 
901  // Ignore operations that have already been added.
902  if (!inserted)
903  continue;
904 
905  for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
906  odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
907  odsContext.insertAttributeConstraint(
908  attr.attr.getUniqueDefName(),
909  processDoc(attr.attr.getSummary()),
910  attr.attr.getStorageType()));
911  }
912  for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
913  odsOp->appendOperand(operand.name, getLengthKind(operand),
914  addTypeConstraint(operand));
915  }
916  for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
917  odsOp->appendResult(result.name, getLengthKind(result),
918  addTypeConstraint(result));
919  }
920  }
921 
922  auto shouldBeSkipped = [this](llvm::Record *def) {
923  return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
924  def->isSubClassOf("DeclareInterfaceMethods");
925  };
926 
927  /// Attr constraints.
928  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
929  if (shouldBeSkipped(def))
930  continue;
931 
932  tblgen::Attribute constraint(def);
933  decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
934  constraint, convertLocToRange(def->getLoc().front()), attrTy,
935  constraint.getStorageType()));
936  }
937  /// Type constraints.
938  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
939  if (shouldBeSkipped(def))
940  continue;
941 
942  tblgen::TypeConstraint constraint(def);
943  decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
944  constraint, convertLocToRange(def->getLoc().front()), typeTy,
945  constraint.getCPPClassName()));
946  }
947  /// OpInterfaces.
949  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("OpInterface")) {
950  if (shouldBeSkipped(def))
951  continue;
952 
953  SMRange loc = convertLocToRange(def->getLoc().front());
954 
955  std::string cppClassName =
956  llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
957  def->getValueAsString("cppInterfaceName"))
958  .str();
959  std::string codeBlock =
960  llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
961  cppClassName)
962  .str();
963 
964  std::string desc =
965  processAndFormatDoc(def->getValueAsString("description"));
966  decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
967  def->getName(), codeBlock, loc, opTy, cppClassName, desc));
968  }
969 }
970 
971 template <typename ConstraintT>
972 ast::Decl *Parser::createODSNativePDLLConstraintDecl(
973  StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
974  StringRef nativeType, StringRef docString) {
975  // Build the single input parameter.
976  ast::DeclScope *argScope = pushDeclScope();
977  auto *paramVar = ast::VariableDecl::create(
978  ctx, ast::Name::create(ctx, "self", loc), type,
979  /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
980  argScope->add(paramVar);
981  popDeclScope();
982 
983  // Build the native constraint.
984  auto *constraintDecl = ast::UserConstraintDecl::createNative(
985  ctx, ast::Name::create(ctx, name, loc), paramVar,
986  /*results=*/std::nullopt, codeBlock, ast::TupleType::get(ctx),
987  nativeType);
988  constraintDecl->setDocComment(ctx, docString);
989  curDeclScope->add(constraintDecl);
990  return constraintDecl;
991 }
992 
993 template <typename ConstraintT>
994 ast::Decl *
995 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
996  SMRange loc, ast::Type type,
997  StringRef nativeType) {
998  // Format the condition template.
999  tblgen::FmtContext fmtContext;
1000  fmtContext.withSelf("self");
1001  std::string codeBlock = tblgen::tgfmt(
1002  "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
1003  &fmtContext);
1004 
1005  // If documentation was enabled, build the doc string for the generated
1006  // constraint. It would be nice to do this lazily, but TableGen information is
1007  // destroyed after we finish parsing the file.
1008  std::string docString;
1009  if (enableDocumentation) {
1010  StringRef desc = constraint.getDescription();
1011  docString = processAndFormatDoc(
1012  constraint.getSummary() +
1013  (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
1014  }
1015 
1016  return createODSNativePDLLConstraintDecl<ConstraintT>(
1017  constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
1018  docString);
1019 }
1020 
1021 //===----------------------------------------------------------------------===//
1022 // Decls
1023 
1024 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
1026  switch (curToken.getKind()) {
1027  case Token::kw_Constraint:
1028  decl = parseUserConstraintDecl();
1029  break;
1030  case Token::kw_Pattern:
1031  decl = parsePatternDecl();
1032  break;
1033  case Token::kw_Rewrite:
1034  decl = parseUserRewriteDecl();
1035  break;
1036  default:
1037  return emitError("expected top-level declaration, such as a `Pattern`");
1038  }
1039  if (failed(decl))
1040  return failure();
1041 
1042  // If the decl has a name, add it to the current scope.
1043  if (const ast::Name *name = (*decl)->getName()) {
1044  if (failed(checkDefineNamedDecl(*name)))
1045  return failure();
1046  curDeclScope->add(*decl);
1047  }
1048  return decl;
1049 }
1050 
1052 Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1053  // Check for name code completion.
1054  if (curToken.is(Token::code_complete))
1055  return codeCompleteAttributeName(parentOpName);
1056 
1057  std::string attrNameStr;
1058  if (curToken.isString())
1059  attrNameStr = curToken.getStringValue();
1060  else if (curToken.is(Token::identifier) || curToken.isKeyword())
1061  attrNameStr = curToken.getSpelling().str();
1062  else
1063  return emitError("expected identifier or string attribute name");
1064  const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
1065  consumeToken();
1066 
1067  // Check for a value of the attribute.
1068  ast::Expr *attrValue = nullptr;
1069  if (consumeIf(Token::equal)) {
1070  FailureOr<ast::Expr *> attrExpr = parseExpr();
1071  if (failed(attrExpr))
1072  return failure();
1073  attrValue = *attrExpr;
1074  } else {
1075  // If there isn't a concrete value, create an expression representing a
1076  // UnitAttr.
1077  attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
1078  }
1079 
1080  return ast::NamedAttributeDecl::create(ctx, name, attrValue);
1081 }
1082 
1083 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1084  function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1085  bool expectTerminalSemicolon) {
1086  consumeToken(Token::equal_arrow);
1087 
1088  // Parse the single statement of the lambda body.
1089  SMLoc bodyStartLoc = curToken.getStartLoc();
1090  pushDeclScope();
1091  FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1092  bool failedToParse =
1093  failed(singleStatement) || failed(processStatementFn(*singleStatement));
1094  popDeclScope();
1095  if (failedToParse)
1096  return failure();
1097 
1098  SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1099  return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
1100 }
1101 
1102 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1103  // Ensure that the argument is named.
1104  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
1105  return emitError("expected identifier argument name");
1106 
1107  // Parse the argument similarly to a normal variable.
1108  StringRef name = curToken.getSpelling();
1109  SMRange nameLoc = curToken.getLoc();
1110  consumeToken();
1111 
1112  if (failed(
1113  parseToken(Token::colon, "expected `:` before argument constraint")))
1114  return failure();
1115 
1116  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1117  if (failed(cst))
1118  return failure();
1119 
1120  return createArgOrResultVariableDecl(name, nameLoc, *cst);
1121 }
1122 
1123 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1124  // Check to see if this result is named.
1125  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1126  // Check to see if this name actually refers to a Constraint.
1127  if (!curDeclScope->lookup<ast::ConstraintDecl>(curToken.getSpelling())) {
1128  // If it wasn't a constraint, parse the result similarly to a variable. If
1129  // there is already an existing decl, we will emit an error when defining
1130  // this variable later.
1131  StringRef name = curToken.getSpelling();
1132  SMRange nameLoc = curToken.getLoc();
1133  consumeToken();
1134 
1135  if (failed(parseToken(Token::colon,
1136  "expected `:` before result constraint")))
1137  return failure();
1138 
1139  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1140  if (failed(cst))
1141  return failure();
1142 
1143  return createArgOrResultVariableDecl(name, nameLoc, *cst);
1144  }
1145  }
1146 
1147  // If it isn't named, we parse the constraint directly and create an unnamed
1148  // result variable.
1149  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1150  if (failed(cst))
1151  return failure();
1152 
1153  return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1154 }
1155 
1157 Parser::parseUserConstraintDecl(bool isInline) {
1158  // Constraints and rewrites have very similar formats, dispatch to a shared
1159  // interface for parsing.
1160  return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1161  [&](auto &&...args) {
1162  return this->parseUserPDLLConstraintDecl(args...);
1163  },
1164  ParserContext::Constraint, "constraint", isInline);
1165 }
1166 
1167 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1169  parseUserConstraintDecl(/*isInline=*/true);
1170  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1171  return failure();
1172 
1173  curDeclScope->add(*decl);
1174  return decl;
1175 }
1176 
1177 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1178  const ast::Name &name, bool isInline,
1179  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1180  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1181  // Push the argument scope back onto the list, so that the body can
1182  // reference arguments.
1183  pushDeclScope(argumentScope);
1184 
1185  // Parse the body of the constraint. The body is either defined as a compound
1186  // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1187  ast::CompoundStmt *body;
1188  if (curToken.is(Token::equal_arrow)) {
1189  FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1190  [&](ast::Stmt *&stmt) -> LogicalResult {
1191  ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1192  if (!stmtExpr) {
1193  return emitError(stmt->getLoc(),
1194  "expected `Constraint` lambda body to contain a "
1195  "single expression");
1196  }
1197  stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1198  return success();
1199  },
1200  /*expectTerminalSemicolon=*/!isInline);
1201  if (failed(bodyResult))
1202  return failure();
1203  body = *bodyResult;
1204  } else {
1205  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1206  if (failed(bodyResult))
1207  return failure();
1208  body = *bodyResult;
1209 
1210  // Verify the structure of the body.
1211  auto bodyIt = body->begin(), bodyE = body->end();
1212  for (; bodyIt != bodyE; ++bodyIt)
1213  if (isa<ast::ReturnStmt>(*bodyIt))
1214  break;
1215  if (failed(validateUserConstraintOrRewriteReturn(
1216  "Constraint", body, bodyIt, bodyE, results, resultType)))
1217  return failure();
1218  }
1219  popDeclScope();
1220 
1221  return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1222  name, arguments, results, resultType, body);
1223 }
1224 
1225 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1226  // Constraints and rewrites have very similar formats, dispatch to a shared
1227  // interface for parsing.
1228  return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1229  [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1230  ParserContext::Rewrite, "rewrite", isInline);
1231 }
1232 
1233 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1235  parseUserRewriteDecl(/*isInline=*/true);
1236  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1237  return failure();
1238 
1239  curDeclScope->add(*decl);
1240  return decl;
1241 }
1242 
1243 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1244  const ast::Name &name, bool isInline,
1245  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1246  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1247  // Push the argument scope back onto the list, so that the body can
1248  // reference arguments.
1249  curDeclScope = argumentScope;
1250  ast::CompoundStmt *body;
1251  if (curToken.is(Token::equal_arrow)) {
1252  FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1253  [&](ast::Stmt *&statement) -> LogicalResult {
1254  if (isa<ast::OpRewriteStmt>(statement))
1255  return success();
1256 
1257  ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1258  if (!statementExpr) {
1259  return emitError(
1260  statement->getLoc(),
1261  "expected `Rewrite` lambda body to contain a single expression "
1262  "or an operation rewrite statement; such as `erase`, "
1263  "`replace`, or `rewrite`");
1264  }
1265  statement =
1266  ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1267  return success();
1268  },
1269  /*expectTerminalSemicolon=*/!isInline);
1270  if (failed(bodyResult))
1271  return failure();
1272  body = *bodyResult;
1273  } else {
1274  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1275  if (failed(bodyResult))
1276  return failure();
1277  body = *bodyResult;
1278  }
1279  popDeclScope();
1280 
1281  // Verify the structure of the body.
1282  auto bodyIt = body->begin(), bodyE = body->end();
1283  for (; bodyIt != bodyE; ++bodyIt)
1284  if (isa<ast::ReturnStmt>(*bodyIt))
1285  break;
1286  if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1287  bodyE, results, resultType)))
1288  return failure();
1289  return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1290  name, arguments, results, resultType, body);
1291 }
1292 
1293 template <typename T, typename ParseUserPDLLDeclFnT>
1294 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1295  ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1296  StringRef anonymousNamePrefix, bool isInline) {
1297  SMRange loc = curToken.getLoc();
1298  consumeToken();
1299  llvm::SaveAndRestore saveCtx(parserContext, declContext);
1300 
1301  // Parse the name of the decl.
1302  const ast::Name *name = nullptr;
1303  if (curToken.isNot(Token::identifier)) {
1304  // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1305  // in C++, so being unnamed is fine.
1306  if (!isInline)
1307  return emitError("expected identifier name");
1308 
1309  // Create a unique anonymous name to use, as the name for this decl is not
1310  // important.
1311  std::string anonName =
1312  llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1313  anonymousDeclNameCounter++)
1314  .str();
1315  name = &ast::Name::create(ctx, anonName, loc);
1316  } else {
1317  // If a name was provided, we can use it directly.
1318  name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1319  consumeToken(Token::identifier);
1320  }
1321 
1322  // Parse the functional signature of the decl.
1323  SmallVector<ast::VariableDecl *> arguments, results;
1324  ast::DeclScope *argumentScope;
1325  ast::Type resultType;
1326  if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1327  argumentScope, resultType)))
1328  return failure();
1329 
1330  // Check to see which type of constraint this is. If the constraint contains a
1331  // compound body, this is a PDLL decl.
1332  if (curToken.isAny(Token::l_brace, Token::equal_arrow))
1333  return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1334  resultType);
1335 
1336  // Otherwise, this is a native decl.
1337  return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1338  results, resultType);
1339 }
1340 
1341 template <typename T>
1342 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1343  const ast::Name &name, bool isInline,
1345  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1346  // If followed by a string, the native code body has also been specified.
1347  std::string codeStrStorage;
1348  std::optional<StringRef> optCodeStr;
1349  if (curToken.isString()) {
1350  codeStrStorage = curToken.getStringValue();
1351  optCodeStr = codeStrStorage;
1352  consumeToken();
1353  } else if (isInline) {
1354  return emitError(name.getLoc(),
1355  "external declarations must be declared in global scope");
1356  } else if (curToken.is(Token::error)) {
1357  return failure();
1358  }
1359  if (failed(parseToken(Token::semicolon,
1360  "expected `;` after native declaration")))
1361  return failure();
1362  // TODO: PDL should be able to support constraint results in certain
1363  // situations, we should revise this.
1364  if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
1365  return emitError(
1366  "native Constraints currently do not support returning results");
1367  }
1368  return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1369 }
1370 
1371 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1374  ast::DeclScope *&argumentScope, ast::Type &resultType) {
1375  // Parse the argument list of the decl.
1376  if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
1377  return failure();
1378 
1379  argumentScope = pushDeclScope();
1380  if (curToken.isNot(Token::r_paren)) {
1381  do {
1382  FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1383  if (failed(argument))
1384  return failure();
1385  arguments.emplace_back(*argument);
1386  } while (consumeIf(Token::comma));
1387  }
1388  popDeclScope();
1389  if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
1390  return failure();
1391 
1392  // Parse the results of the decl.
1393  pushDeclScope();
1394  if (consumeIf(Token::arrow)) {
1395  auto parseResultFn = [&]() -> LogicalResult {
1396  FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1397  if (failed(result))
1398  return failure();
1399  results.emplace_back(*result);
1400  return success();
1401  };
1402 
1403  // Check for a list of results.
1404  if (consumeIf(Token::l_paren)) {
1405  do {
1406  if (failed(parseResultFn()))
1407  return failure();
1408  } while (consumeIf(Token::comma));
1409  if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1410  return failure();
1411 
1412  // Otherwise, there is only one result.
1413  } else if (failed(parseResultFn())) {
1414  return failure();
1415  }
1416  }
1417  popDeclScope();
1418 
1419  // Compute the result type of the decl.
1420  resultType = createUserConstraintRewriteResultType(results);
1421 
1422  // Verify that results are only named if there are more than one.
1423  if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1424  return emitError(
1425  results.front()->getLoc(),
1426  "cannot create a single-element tuple with an element label");
1427  }
1428  return success();
1429 }
1430 
1431 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1432  StringRef declType, ast::CompoundStmt *body,
1435  ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1436  // Handle if a `return` was provided.
1437  if (bodyIt != bodyE) {
1438  // Emit an error if we have trailing statements after the return.
1439  if (std::next(bodyIt) != bodyE) {
1440  return emitError(
1441  (*std::next(bodyIt))->getLoc(),
1442  llvm::formatv("`return` terminated the `{0}` body, but found "
1443  "trailing statements afterwards",
1444  declType));
1445  }
1446 
1447  // Otherwise if a return wasn't provided, check that no results are
1448  // expected.
1449  } else if (!results.empty()) {
1450  return emitError(
1451  {body->getLoc().End, body->getLoc().End},
1452  llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1453  declType, resultType));
1454  }
1455  return success();
1456 }
1457 
1458 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1459  return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1460  if (isa<ast::OpRewriteStmt>(statement))
1461  return success();
1462  return emitError(
1463  statement->getLoc(),
1464  "expected Pattern lambda body to contain a single operation "
1465  "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1466  });
1467 }
1468 
1469 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1470  SMRange loc = curToken.getLoc();
1471  consumeToken(Token::kw_Pattern);
1472  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1473 
1474  // Check for an optional identifier for the pattern name.
1475  const ast::Name *name = nullptr;
1476  if (curToken.is(Token::identifier)) {
1477  name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1478  consumeToken(Token::identifier);
1479  }
1480 
1481  // Parse any pattern metadata.
1482  ParsedPatternMetadata metadata;
1483  if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1484  return failure();
1485 
1486  // Parse the pattern body.
1487  ast::CompoundStmt *body;
1488 
1489  // Handle a lambda body.
1490  if (curToken.is(Token::equal_arrow)) {
1491  FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1492  if (failed(bodyResult))
1493  return failure();
1494  body = *bodyResult;
1495  } else {
1496  if (curToken.isNot(Token::l_brace))
1497  return emitError("expected `{` or `=>` to start pattern body");
1498  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1499  if (failed(bodyResult))
1500  return failure();
1501  body = *bodyResult;
1502 
1503  // Verify the body of the pattern.
1504  auto bodyIt = body->begin(), bodyE = body->end();
1505  for (; bodyIt != bodyE; ++bodyIt) {
1506  if (isa<ast::ReturnStmt>(*bodyIt)) {
1507  return emitError((*bodyIt)->getLoc(),
1508  "`return` statements are only permitted within a "
1509  "`Constraint` or `Rewrite` body");
1510  }
1511  // Break when we've found the rewrite statement.
1512  if (isa<ast::OpRewriteStmt>(*bodyIt))
1513  break;
1514  }
1515  if (bodyIt == bodyE) {
1516  return emitError(loc,
1517  "expected Pattern body to terminate with an operation "
1518  "rewrite statement, such as `erase`");
1519  }
1520  if (std::next(bodyIt) != bodyE) {
1521  return emitError((*std::next(bodyIt))->getLoc(),
1522  "Pattern body was terminated by an operation "
1523  "rewrite statement, but found trailing statements");
1524  }
1525  }
1526 
1527  return createPatternDecl(loc, name, metadata, body);
1528 }
1529 
1531 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1532  std::optional<SMRange> benefitLoc;
1533  std::optional<SMRange> hasBoundedRecursionLoc;
1534 
1535  do {
1536  // Handle metadata code completion.
1537  if (curToken.is(Token::code_complete))
1538  return codeCompletePatternMetadata();
1539 
1540  if (curToken.isNot(Token::identifier))
1541  return emitError("expected pattern metadata identifier");
1542  StringRef metadataStr = curToken.getSpelling();
1543  SMRange metadataLoc = curToken.getLoc();
1544  consumeToken(Token::identifier);
1545 
1546  // Parse the benefit metadata: benefit(<integer-value>)
1547  if (metadataStr == "benefit") {
1548  if (benefitLoc) {
1549  return emitErrorAndNote(metadataLoc,
1550  "pattern benefit has already been specified",
1551  *benefitLoc, "see previous definition here");
1552  }
1553  if (failed(parseToken(Token::l_paren,
1554  "expected `(` before pattern benefit")))
1555  return failure();
1556 
1557  uint16_t benefitValue = 0;
1558  if (curToken.isNot(Token::integer))
1559  return emitError("expected integral pattern benefit");
1560  if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1561  return emitError(
1562  "expected pattern benefit to fit within a 16-bit integer");
1563  consumeToken(Token::integer);
1564 
1565  metadata.benefit = benefitValue;
1566  benefitLoc = metadataLoc;
1567 
1568  if (failed(
1569  parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1570  return failure();
1571  continue;
1572  }
1573 
1574  // Parse the bounded recursion metadata: recursion
1575  if (metadataStr == "recursion") {
1576  if (hasBoundedRecursionLoc) {
1577  return emitErrorAndNote(
1578  metadataLoc,
1579  "pattern recursion metadata has already been specified",
1580  *hasBoundedRecursionLoc, "see previous definition here");
1581  }
1582  metadata.hasBoundedRecursion = true;
1583  hasBoundedRecursionLoc = metadataLoc;
1584  continue;
1585  }
1586 
1587  return emitError(metadataLoc, "unknown pattern metadata");
1588  } while (consumeIf(Token::comma));
1589 
1590  return success();
1591 }
1592 
1593 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1594  consumeToken(Token::less);
1595 
1596  FailureOr<ast::Expr *> typeExpr = parseExpr();
1597  if (failed(typeExpr) ||
1598  failed(parseToken(Token::greater,
1599  "expected `>` after variable type constraint")))
1600  return failure();
1601  return typeExpr;
1602 }
1603 
1604 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1605  assert(curDeclScope && "defining decl outside of a decl scope");
1606  if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1607  return emitErrorAndNote(
1608  name.getLoc(), "`" + name.getName() + "` has already been defined",
1609  lastDecl->getName()->getLoc(), "see previous definition here");
1610  }
1611  return success();
1612 }
1613 
1615 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1616  ast::Expr *initExpr,
1617  ArrayRef<ast::ConstraintRef> constraints) {
1618  assert(curDeclScope && "defining variable outside of decl scope");
1619  const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1620 
1621  // If the name of the variable indicates a special variable, we don't add it
1622  // to the scope. This variable is local to the definition point.
1623  if (name.empty() || name == "_") {
1624  return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1625  constraints);
1626  }
1627  if (failed(checkDefineNamedDecl(nameDecl)))
1628  return failure();
1629 
1630  auto *varDecl =
1631  ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1632  curDeclScope->add(varDecl);
1633  return varDecl;
1634 }
1635 
1637 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1638  ArrayRef<ast::ConstraintRef> constraints) {
1639  return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1640  constraints);
1641 }
1642 
1643 LogicalResult Parser::parseVariableDeclConstraintList(
1644  SmallVectorImpl<ast::ConstraintRef> &constraints) {
1645  std::optional<SMRange> typeConstraint;
1646  auto parseSingleConstraint = [&] {
1647  FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1648  typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
1649  if (failed(constraint))
1650  return failure();
1651  constraints.push_back(*constraint);
1652  return success();
1653  };
1654 
1655  // Check to see if this is a single constraint, or a list.
1656  if (!consumeIf(Token::l_square))
1657  return parseSingleConstraint();
1658 
1659  do {
1660  if (failed(parseSingleConstraint()))
1661  return failure();
1662  } while (consumeIf(Token::comma));
1663  return parseToken(Token::r_square, "expected `]` after constraint list");
1664 }
1665 
1667 Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
1668  ArrayRef<ast::ConstraintRef> existingConstraints,
1669  bool allowInlineTypeConstraints) {
1670  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1671  if (!allowInlineTypeConstraints) {
1672  return emitError(
1673  curToken.getLoc(),
1674  "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1675  "permitted on arguments or results");
1676  }
1677  if (typeConstraint)
1678  return emitErrorAndNote(
1679  curToken.getLoc(),
1680  "the type of this variable has already been constrained",
1681  *typeConstraint, "see previous constraint location here");
1682  FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1683  if (failed(constraintExpr))
1684  return failure();
1685  typeExpr = *constraintExpr;
1686  typeConstraint = typeExpr->getLoc();
1687  return success();
1688  };
1689 
1690  SMRange loc = curToken.getLoc();
1691  switch (curToken.getKind()) {
1692  case Token::kw_Attr: {
1693  consumeToken(Token::kw_Attr);
1694 
1695  // Check for a type constraint.
1696  ast::Expr *typeExpr = nullptr;
1697  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1698  return failure();
1699  return ast::ConstraintRef(
1700  ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1701  }
1702  case Token::kw_Op: {
1703  consumeToken(Token::kw_Op);
1704 
1705  // Parse an optional operation name. If the name isn't provided, this refers
1706  // to "any" operation.
1708  parseWrappedOperationName(/*allowEmptyName=*/true);
1709  if (failed(opName))
1710  return failure();
1711 
1712  return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
1713  loc);
1714  }
1715  case Token::kw_Type:
1716  consumeToken(Token::kw_Type);
1717  return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1718  case Token::kw_TypeRange:
1719  consumeToken(Token::kw_TypeRange);
1721  loc);
1722  case Token::kw_Value: {
1723  consumeToken(Token::kw_Value);
1724 
1725  // Check for a type constraint.
1726  ast::Expr *typeExpr = nullptr;
1727  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1728  return failure();
1729 
1730  return ast::ConstraintRef(
1731  ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1732  }
1733  case Token::kw_ValueRange: {
1734  consumeToken(Token::kw_ValueRange);
1735 
1736  // Check for a type constraint.
1737  ast::Expr *typeExpr = nullptr;
1738  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1739  return failure();
1740 
1741  return ast::ConstraintRef(
1742  ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1743  }
1744 
1745  case Token::kw_Constraint: {
1746  // Handle an inline constraint.
1747  FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1748  if (failed(decl))
1749  return failure();
1750  return ast::ConstraintRef(*decl, loc);
1751  }
1752  case Token::identifier: {
1753  StringRef constraintName = curToken.getSpelling();
1754  consumeToken(Token::identifier);
1755 
1756  // Lookup the referenced constraint.
1757  ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
1758  if (!cstDecl) {
1759  return emitError(loc, "unknown reference to constraint `" +
1760  constraintName + "`");
1761  }
1762 
1763  // Handle a reference to a proper constraint.
1764  if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1765  return ast::ConstraintRef(cst, loc);
1766 
1767  return emitErrorAndNote(
1768  loc, "invalid reference to non-constraint", cstDecl->getLoc(),
1769  "see the definition of `" + constraintName + "` here");
1770  }
1771  // Handle single entity constraint code completion.
1772  case Token::code_complete: {
1773  // Try to infer the current type for use by code completion.
1774  ast::Type inferredType;
1775  if (failed(validateVariableConstraints(existingConstraints, inferredType)))
1776  return failure();
1777 
1778  return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
1779  }
1780  default:
1781  break;
1782  }
1783  return emitError(loc, "expected identifier constraint");
1784 }
1785 
1786 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1787  std::optional<SMRange> typeConstraint;
1788  return parseConstraint(typeConstraint, /*existingConstraints=*/std::nullopt,
1789  /*allowInlineTypeConstraints=*/false);
1790 }
1791 
1792 //===----------------------------------------------------------------------===//
1793 // Exprs
1794 
1795 FailureOr<ast::Expr *> Parser::parseExpr() {
1796  if (curToken.is(Token::underscore))
1797  return parseUnderscoreExpr();
1798 
1799  // Parse the LHS expression.
1800  FailureOr<ast::Expr *> lhsExpr;
1801  switch (curToken.getKind()) {
1802  case Token::kw_attr:
1803  lhsExpr = parseAttributeExpr();
1804  break;
1805  case Token::kw_Constraint:
1806  lhsExpr = parseInlineConstraintLambdaExpr();
1807  break;
1808  case Token::identifier:
1809  lhsExpr = parseIdentifierExpr();
1810  break;
1811  case Token::kw_op:
1812  lhsExpr = parseOperationExpr();
1813  break;
1814  case Token::kw_Rewrite:
1815  lhsExpr = parseInlineRewriteLambdaExpr();
1816  break;
1817  case Token::kw_type:
1818  lhsExpr = parseTypeExpr();
1819  break;
1820  case Token::l_paren:
1821  lhsExpr = parseTupleExpr();
1822  break;
1823  default:
1824  return emitError("expected expression");
1825  }
1826  if (failed(lhsExpr))
1827  return failure();
1828 
1829  // Check for an operator expression.
1830  while (true) {
1831  switch (curToken.getKind()) {
1832  case Token::dot:
1833  lhsExpr = parseMemberAccessExpr(*lhsExpr);
1834  break;
1835  case Token::l_paren:
1836  lhsExpr = parseCallExpr(*lhsExpr);
1837  break;
1838  default:
1839  return lhsExpr;
1840  }
1841  if (failed(lhsExpr))
1842  return failure();
1843  }
1844 }
1845 
1846 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1847  SMRange loc = curToken.getLoc();
1848  consumeToken(Token::kw_attr);
1849 
1850  // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1851  // identifier.
1852  if (!consumeIf(Token::less)) {
1853  resetToken(loc);
1854  return parseIdentifierExpr();
1855  }
1856 
1857  if (!curToken.isString())
1858  return emitError("expected string literal containing MLIR attribute");
1859  std::string attrExpr = curToken.getStringValue();
1860  consumeToken();
1861 
1862  loc.End = curToken.getEndLoc();
1863  if (failed(
1864  parseToken(Token::greater, "expected `>` after attribute literal")))
1865  return failure();
1866  return ast::AttributeExpr::create(ctx, loc, attrExpr);
1867 }
1868 
1869 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
1870  consumeToken(Token::l_paren);
1871 
1872  // Parse the arguments of the call.
1873  SmallVector<ast::Expr *> arguments;
1874  if (curToken.isNot(Token::r_paren)) {
1875  do {
1876  // Handle code completion for the call arguments.
1877  if (curToken.is(Token::code_complete)) {
1878  codeCompleteCallSignature(parentExpr, arguments.size());
1879  return failure();
1880  }
1881 
1882  FailureOr<ast::Expr *> argument = parseExpr();
1883  if (failed(argument))
1884  return failure();
1885  arguments.push_back(*argument);
1886  } while (consumeIf(Token::comma));
1887  }
1888 
1889  SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1890  if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1891  return failure();
1892 
1893  return createCallExpr(loc, parentExpr, arguments);
1894 }
1895 
1896 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1897  ast::Decl *decl = curDeclScope->lookup(name);
1898  if (!decl)
1899  return emitError(loc, "undefined reference to `" + name + "`");
1900 
1901  return createDeclRefExpr(loc, decl);
1902 }
1903 
1904 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1905  StringRef name = curToken.getSpelling();
1906  SMRange nameLoc = curToken.getLoc();
1907  consumeToken();
1908 
1909  // Check to see if this is a decl ref expression that defines a variable
1910  // inline.
1911  if (consumeIf(Token::colon)) {
1912  SmallVector<ast::ConstraintRef> constraints;
1913  if (failed(parseVariableDeclConstraintList(constraints)))
1914  return failure();
1915  ast::Type type;
1916  if (failed(validateVariableConstraints(constraints, type)))
1917  return failure();
1918  return createInlineVariableExpr(type, name, nameLoc, constraints);
1919  }
1920 
1921  return parseDeclRefExpr(name, nameLoc);
1922 }
1923 
1924 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1925  FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1926  if (failed(decl))
1927  return failure();
1928 
1929  return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1931 }
1932 
1933 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1934  FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1935  if (failed(decl))
1936  return failure();
1937 
1938  return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1939  ast::RewriteType::get(ctx));
1940 }
1941 
1942 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1943  SMRange dotLoc = curToken.getLoc();
1944  consumeToken(Token::dot);
1945 
1946  // Check for code completion of the member name.
1947  if (curToken.is(Token::code_complete))
1948  return codeCompleteMemberAccess(parentExpr);
1949 
1950  // Parse the member name.
1951  Token memberNameTok = curToken;
1952  if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1953  !memberNameTok.isKeyword())
1954  return emitError(dotLoc, "expected identifier or numeric member name");
1955  StringRef memberName = memberNameTok.getSpelling();
1956  SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1957  consumeToken();
1958 
1959  return createMemberAccessExpr(parentExpr, memberName, loc);
1960 }
1961 
1962 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
1963  SMRange loc = curToken.getLoc();
1964 
1965  // Check for code completion for the dialect name.
1966  if (curToken.is(Token::code_complete))
1967  return codeCompleteDialectName();
1968 
1969  // Handle the case of an no operation name.
1970  if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
1971  if (allowEmptyName)
1972  return ast::OpNameDecl::create(ctx, SMRange());
1973  return emitError("expected dialect namespace");
1974  }
1975  StringRef name = curToken.getSpelling();
1976  consumeToken();
1977 
1978  // Otherwise, this is a literal operation name.
1979  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
1980  return failure();
1981 
1982  // Check for code completion for the operation name.
1983  if (curToken.is(Token::code_complete))
1984  return codeCompleteOperationName(name);
1985 
1986  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
1987  return emitError("expected operation name after dialect namespace");
1988 
1989  name = StringRef(name.data(), name.size() + 1);
1990  do {
1991  name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
1992  loc.End = curToken.getEndLoc();
1993  consumeToken();
1994  } while (curToken.isAny(Token::identifier, Token::dot) ||
1995  curToken.isKeyword());
1996  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
1997 }
1998 
2000 Parser::parseWrappedOperationName(bool allowEmptyName) {
2001  if (!consumeIf(Token::less))
2002  return ast::OpNameDecl::create(ctx, SMRange());
2003 
2004  FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
2005  if (failed(opNameDecl))
2006  return failure();
2007 
2008  if (failed(parseToken(Token::greater, "expected `>` after operation name")))
2009  return failure();
2010  return opNameDecl;
2011 }
2012 
2014 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
2015  SMRange loc = curToken.getLoc();
2016  consumeToken(Token::kw_op);
2017 
2018  // If it isn't followed by a `<`, the `op` keyword is treated as a normal
2019  // identifier.
2020  if (curToken.isNot(Token::less)) {
2021  resetToken(loc);
2022  return parseIdentifierExpr();
2023  }
2024 
2025  // Parse the operation name. The name may be elided, in which case the
2026  // operation refers to "any" operation(i.e. a difference between `MyOp` and
2027  // `Operation*`). Operation names within a rewrite context must be named.
2028  bool allowEmptyName = parserContext != ParserContext::Rewrite;
2029  FailureOr<ast::OpNameDecl *> opNameDecl =
2030  parseWrappedOperationName(allowEmptyName);
2031  if (failed(opNameDecl))
2032  return failure();
2033  std::optional<StringRef> opName = (*opNameDecl)->getName();
2034 
2035  // Functor used to create an implicit range variable, used for implicit "all"
2036  // operand or results variables.
2037  auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
2039  defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
2040  assert(succeeded(rangeVar) && "expected range variable to be valid");
2041  return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
2042  };
2043 
2044  // Check for the optional list of operands.
2045  SmallVector<ast::Expr *> operands;
2046  if (!consumeIf(Token::l_paren)) {
2047  // If the operand list isn't specified and we are in a match context, define
2048  // an inplace unconstrained operand range corresponding to all of the
2049  // operands of the operation. This avoids treating zero operands the same
2050  // way as "unconstrained operands".
2051  if (parserContext != ParserContext::Rewrite) {
2052  operands.push_back(createImplicitRangeVar(
2053  ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
2054  }
2055  } else if (!consumeIf(Token::r_paren)) {
2056  // If the operand list was specified and non-empty, parse the operands.
2057  do {
2058  // Check for operand signature code completion.
2059  if (curToken.is(Token::code_complete)) {
2060  codeCompleteOperationOperandsSignature(opName, operands.size());
2061  return failure();
2062  }
2063 
2064  FailureOr<ast::Expr *> operand = parseExpr();
2065  if (failed(operand))
2066  return failure();
2067  operands.push_back(*operand);
2068  } while (consumeIf(Token::comma));
2069 
2070  if (failed(parseToken(Token::r_paren,
2071  "expected `)` after operation operand list")))
2072  return failure();
2073  }
2074 
2075  // Check for the optional list of attributes.
2077  if (consumeIf(Token::l_brace)) {
2078  do {
2080  parseNamedAttributeDecl(opName);
2081  if (failed(decl))
2082  return failure();
2083  attributes.emplace_back(*decl);
2084  } while (consumeIf(Token::comma));
2085 
2086  if (failed(parseToken(Token::r_brace,
2087  "expected `}` after operation attribute list")))
2088  return failure();
2089  }
2090 
2091  // Handle the result types of the operation.
2092  SmallVector<ast::Expr *> resultTypes;
2093  OpResultTypeContext resultTypeContext = inputResultTypeContext;
2094 
2095  // Check for an explicit list of result types.
2096  if (consumeIf(Token::arrow)) {
2097  if (failed(parseToken(Token::l_paren,
2098  "expected `(` before operation result type list")))
2099  return failure();
2100 
2101  // If result types are provided, initially assume that the operation does
2102  // not rely on type inferrence. We don't assert that it isn't, because we
2103  // may be inferring the value of some type/type range variables, but given
2104  // that these variables may be defined in calls we can't always discern when
2105  // this is the case.
2106  resultTypeContext = OpResultTypeContext::Explicit;
2107 
2108  // Handle the case of an empty result list.
2109  if (!consumeIf(Token::r_paren)) {
2110  do {
2111  // Check for result signature code completion.
2112  if (curToken.is(Token::code_complete)) {
2113  codeCompleteOperationResultsSignature(opName, resultTypes.size());
2114  return failure();
2115  }
2116 
2117  FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2118  if (failed(resultTypeExpr))
2119  return failure();
2120  resultTypes.push_back(*resultTypeExpr);
2121  } while (consumeIf(Token::comma));
2122 
2123  if (failed(parseToken(Token::r_paren,
2124  "expected `)` after operation result type list")))
2125  return failure();
2126  }
2127  } else if (parserContext != ParserContext::Rewrite) {
2128  // If the result list isn't specified and we are in a match context, define
2129  // an inplace unconstrained result range corresponding to all of the results
2130  // of the operation. This avoids treating zero results the same way as
2131  // "unconstrained results".
2132  resultTypes.push_back(createImplicitRangeVar(
2133  ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
2134  } else if (resultTypeContext == OpResultTypeContext::Explicit) {
2135  // If the result list isn't specified and we are in a rewrite, try to infer
2136  // them at runtime instead.
2137  resultTypeContext = OpResultTypeContext::Interface;
2138  }
2139 
2140  return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2141  attributes, resultTypes);
2142 }
2143 
2144 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2145  SMRange loc = curToken.getLoc();
2146  consumeToken(Token::l_paren);
2147 
2148  DenseMap<StringRef, SMRange> usedNames;
2149  SmallVector<StringRef> elementNames;
2150  SmallVector<ast::Expr *> elements;
2151  if (curToken.isNot(Token::r_paren)) {
2152  do {
2153  // Check for the optional element name assignment before the value.
2154  StringRef elementName;
2155  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
2156  Token elementNameTok = curToken;
2157  consumeToken();
2158 
2159  // The element name is only present if followed by an `=`.
2160  if (consumeIf(Token::equal)) {
2161  elementName = elementNameTok.getSpelling();
2162 
2163  // Check to see if this name is already used.
2164  auto elementNameIt =
2165  usedNames.try_emplace(elementName, elementNameTok.getLoc());
2166  if (!elementNameIt.second) {
2167  return emitErrorAndNote(
2168  elementNameTok.getLoc(),
2169  llvm::formatv("duplicate tuple element label `{0}`",
2170  elementName),
2171  elementNameIt.first->getSecond(),
2172  "see previous label use here");
2173  }
2174  } else {
2175  // Otherwise, we treat this as part of an expression so reset the
2176  // lexer.
2177  resetToken(elementNameTok.getLoc());
2178  }
2179  }
2180  elementNames.push_back(elementName);
2181 
2182  // Parse the tuple element value.
2183  FailureOr<ast::Expr *> element = parseExpr();
2184  if (failed(element))
2185  return failure();
2186  elements.push_back(*element);
2187  } while (consumeIf(Token::comma));
2188  }
2189  loc.End = curToken.getEndLoc();
2190  if (failed(
2191  parseToken(Token::r_paren, "expected `)` after tuple element list")))
2192  return failure();
2193  return createTupleExpr(loc, elements, elementNames);
2194 }
2195 
2196 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2197  SMRange loc = curToken.getLoc();
2198  consumeToken(Token::kw_type);
2199 
2200  // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2201  // identifier.
2202  if (!consumeIf(Token::less)) {
2203  resetToken(loc);
2204  return parseIdentifierExpr();
2205  }
2206 
2207  if (!curToken.isString())
2208  return emitError("expected string literal containing MLIR type");
2209  std::string attrExpr = curToken.getStringValue();
2210  consumeToken();
2211 
2212  loc.End = curToken.getEndLoc();
2213  if (failed(parseToken(Token::greater, "expected `>` after type literal")))
2214  return failure();
2215  return ast::TypeExpr::create(ctx, loc, attrExpr);
2216 }
2217 
2218 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2219  StringRef name = curToken.getSpelling();
2220  SMRange nameLoc = curToken.getLoc();
2221  consumeToken(Token::underscore);
2222 
2223  // Underscore expressions require a constraint list.
2224  if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2225  return failure();
2226 
2227  // Parse the constraints for the expression.
2228  SmallVector<ast::ConstraintRef> constraints;
2229  if (failed(parseVariableDeclConstraintList(constraints)))
2230  return failure();
2231 
2232  ast::Type type;
2233  if (failed(validateVariableConstraints(constraints, type)))
2234  return failure();
2235  return createInlineVariableExpr(type, name, nameLoc, constraints);
2236 }
2237 
2238 //===----------------------------------------------------------------------===//
2239 // Stmts
2240 
2241 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2243  switch (curToken.getKind()) {
2244  case Token::kw_erase:
2245  stmt = parseEraseStmt();
2246  break;
2247  case Token::kw_let:
2248  stmt = parseLetStmt();
2249  break;
2250  case Token::kw_replace:
2251  stmt = parseReplaceStmt();
2252  break;
2253  case Token::kw_return:
2254  stmt = parseReturnStmt();
2255  break;
2256  case Token::kw_rewrite:
2257  stmt = parseRewriteStmt();
2258  break;
2259  default:
2260  stmt = parseExpr();
2261  break;
2262  }
2263  if (failed(stmt) ||
2264  (expectTerminalSemicolon &&
2265  failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2266  return failure();
2267  return stmt;
2268 }
2269 
2270 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2271  SMLoc startLoc = curToken.getStartLoc();
2272  consumeToken(Token::l_brace);
2273 
2274  // Push a new block scope and parse any nested statements.
2275  pushDeclScope();
2276  SmallVector<ast::Stmt *> statements;
2277  while (curToken.isNot(Token::r_brace)) {
2278  FailureOr<ast::Stmt *> statement = parseStmt();
2279  if (failed(statement))
2280  return popDeclScope(), failure();
2281  statements.push_back(*statement);
2282  }
2283  popDeclScope();
2284 
2285  // Consume the end brace.
2286  SMRange location(startLoc, curToken.getEndLoc());
2287  consumeToken(Token::r_brace);
2288 
2289  return ast::CompoundStmt::create(ctx, location, statements);
2290 }
2291 
2292 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2293  if (parserContext == ParserContext::Constraint)
2294  return emitError("`erase` cannot be used within a Constraint");
2295  SMRange loc = curToken.getLoc();
2296  consumeToken(Token::kw_erase);
2297 
2298  // Parse the root operation expression.
2299  FailureOr<ast::Expr *> rootOp = parseExpr();
2300  if (failed(rootOp))
2301  return failure();
2302 
2303  return createEraseStmt(loc, *rootOp);
2304 }
2305 
2306 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2307  SMRange loc = curToken.getLoc();
2308  consumeToken(Token::kw_let);
2309 
2310  // Parse the name of the new variable.
2311  SMRange varLoc = curToken.getLoc();
2312  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
2313  // `_` is a reserved variable name.
2314  if (curToken.is(Token::underscore)) {
2315  return emitError(varLoc,
2316  "`_` may only be used to define \"inline\" variables");
2317  }
2318  return emitError(varLoc,
2319  "expected identifier after `let` to name a new variable");
2320  }
2321  StringRef varName = curToken.getSpelling();
2322  consumeToken();
2323 
2324  // Parse the optional set of constraints.
2325  SmallVector<ast::ConstraintRef> constraints;
2326  if (consumeIf(Token::colon) &&
2327  failed(parseVariableDeclConstraintList(constraints)))
2328  return failure();
2329 
2330  // Parse the optional initializer expression.
2331  ast::Expr *initializer = nullptr;
2332  if (consumeIf(Token::equal)) {
2333  FailureOr<ast::Expr *> initOrFailure = parseExpr();
2334  if (failed(initOrFailure))
2335  return failure();
2336  initializer = *initOrFailure;
2337 
2338  // Check that the constraints are compatible with having an initializer,
2339  // e.g. type constraints cannot be used with initializers.
2340  for (ast::ConstraintRef constraint : constraints) {
2341  LogicalResult result =
2342  TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
2344  ast::ValueRangeConstraintDecl>([&](const auto *cst) {
2345  if (auto *typeConstraintExpr = cst->getTypeExpr()) {
2346  return this->emitError(
2347  constraint.referenceLoc,
2348  "type constraints are not permitted on variables with "
2349  "initializers");
2350  }
2351  return success();
2352  })
2353  .Default(success());
2354  if (failed(result))
2355  return failure();
2356  }
2357  }
2358 
2360  createVariableDecl(varName, varLoc, initializer, constraints);
2361  if (failed(varDecl))
2362  return failure();
2363  return ast::LetStmt::create(ctx, loc, *varDecl);
2364 }
2365 
2366 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2367  if (parserContext == ParserContext::Constraint)
2368  return emitError("`replace` cannot be used within a Constraint");
2369  SMRange loc = curToken.getLoc();
2370  consumeToken(Token::kw_replace);
2371 
2372  // Parse the root operation expression.
2373  FailureOr<ast::Expr *> rootOp = parseExpr();
2374  if (failed(rootOp))
2375  return failure();
2376 
2377  if (failed(
2378  parseToken(Token::kw_with, "expected `with` after root operation")))
2379  return failure();
2380 
2381  // The replacement portion of this statement is within a rewrite context.
2382  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2383 
2384  // Parse the replacement values.
2385  SmallVector<ast::Expr *> replValues;
2386  if (consumeIf(Token::l_paren)) {
2387  if (consumeIf(Token::r_paren)) {
2388  return emitError(
2389  loc, "expected at least one replacement value, consider using "
2390  "`erase` if no replacement values are desired");
2391  }
2392 
2393  do {
2394  FailureOr<ast::Expr *> replExpr = parseExpr();
2395  if (failed(replExpr))
2396  return failure();
2397  replValues.emplace_back(*replExpr);
2398  } while (consumeIf(Token::comma));
2399 
2400  if (failed(parseToken(Token::r_paren,
2401  "expected `)` after replacement values")))
2402  return failure();
2403  } else {
2404  // Handle replacement with an operation uniquely, as the replacement
2405  // operation supports type inferrence from the root operation.
2406  FailureOr<ast::Expr *> replExpr;
2407  if (curToken.is(Token::kw_op))
2408  replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2409  else
2410  replExpr = parseExpr();
2411  if (failed(replExpr))
2412  return failure();
2413  replValues.emplace_back(*replExpr);
2414  }
2415 
2416  return createReplaceStmt(loc, *rootOp, replValues);
2417 }
2418 
2419 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2420  SMRange loc = curToken.getLoc();
2421  consumeToken(Token::kw_return);
2422 
2423  // Parse the result value.
2424  FailureOr<ast::Expr *> resultExpr = parseExpr();
2425  if (failed(resultExpr))
2426  return failure();
2427 
2428  return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2429 }
2430 
2431 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2432  if (parserContext == ParserContext::Constraint)
2433  return emitError("`rewrite` cannot be used within a Constraint");
2434  SMRange loc = curToken.getLoc();
2435  consumeToken(Token::kw_rewrite);
2436 
2437  // Parse the root operation.
2438  FailureOr<ast::Expr *> rootOp = parseExpr();
2439  if (failed(rootOp))
2440  return failure();
2441 
2442  if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2443  return failure();
2444 
2445  if (curToken.isNot(Token::l_brace))
2446  return emitError("expected `{` to start rewrite body");
2447 
2448  // The rewrite body of this statement is within a rewrite context.
2449  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2450 
2451  FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2452  if (failed(rewriteBody))
2453  return failure();
2454 
2455  // Verify the rewrite body.
2456  for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2457  if (isa<ast::ReturnStmt>(stmt)) {
2458  return emitError(stmt->getLoc(),
2459  "`return` statements are only permitted within a "
2460  "`Constraint` or `Rewrite` body");
2461  }
2462  }
2463 
2464  return createRewriteStmt(loc, *rootOp, *rewriteBody);
2465 }
2466 
2467 //===----------------------------------------------------------------------===//
2468 // Creation+Analysis
2469 //===----------------------------------------------------------------------===//
2470 
2471 //===----------------------------------------------------------------------===//
2472 // Decls
2473 
2474 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2475  // Unwrap reference expressions.
2476  if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2477  node = init->getDecl();
2478  return dyn_cast<ast::CallableDecl>(node);
2479 }
2480 
2482 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2483  const ParsedPatternMetadata &metadata,
2484  ast::CompoundStmt *body) {
2485  return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2486  metadata.hasBoundedRecursion, body);
2487 }
2488 
2489 ast::Type Parser::createUserConstraintRewriteResultType(
2491  // Single result decls use the type of the single result.
2492  if (results.size() == 1)
2493  return results[0]->getType();
2494 
2495  // Multiple results use a tuple type, with the types and names grabbed from
2496  // the result variable decls.
2497  auto resultTypes = llvm::map_range(
2498  results, [&](const auto *result) { return result->getType(); });
2499  auto resultNames = llvm::map_range(
2500  results, [&](const auto *result) { return result->getName().getName(); });
2501  return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2502  llvm::to_vector(resultNames));
2503 }
2504 
2505 template <typename T>
2506 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2507  const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2508  ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2509  ast::CompoundStmt *body) {
2510  if (!body->getChildren().empty()) {
2511  if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2512  ast::Expr *resultExpr = retStmt->getResultExpr();
2513 
2514  // Process the result of the decl. If no explicit signature results
2515  // were provided, check for return type inference. Otherwise, check that
2516  // the return expression can be converted to the expected type.
2517  if (results.empty())
2518  resultType = resultExpr->getType();
2519  else if (failed(convertExpressionTo(resultExpr, resultType)))
2520  return failure();
2521  else
2522  retStmt->setResultExpr(resultExpr);
2523  }
2524  }
2525  return T::createPDLL(ctx, name, arguments, results, body, resultType);
2526 }
2527 
2529 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2530  ArrayRef<ast::ConstraintRef> constraints) {
2531  // The type of the variable, which is expected to be inferred by either a
2532  // constraint or an initializer expression.
2533  ast::Type type;
2534  if (failed(validateVariableConstraints(constraints, type)))
2535  return failure();
2536 
2537  if (initializer) {
2538  // Update the variable type based on the initializer, or try to convert the
2539  // initializer to the existing type.
2540  if (!type)
2541  type = initializer->getType();
2542  else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2543  type = mergedType;
2544  else if (failed(convertExpressionTo(initializer, type)))
2545  return failure();
2546 
2547  // Otherwise, if there is no initializer check that the type has already
2548  // been resolved from the constraint list.
2549  } else if (!type) {
2550  return emitErrorAndNote(
2551  loc, "unable to infer type for variable `" + name + "`", loc,
2552  "the type of a variable must be inferable from the constraint "
2553  "list or the initializer");
2554  }
2555 
2556  // Constraint types cannot be used when defining variables.
2557  if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
2558  return emitError(
2559  loc, llvm::formatv("unable to define variable of `{0}` type", type));
2560  }
2561 
2562  // Try to define a variable with the given name.
2564  defineVariableDecl(name, loc, type, initializer, constraints);
2565  if (failed(varDecl))
2566  return failure();
2567 
2568  return *varDecl;
2569 }
2570 
2572 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2573  const ast::ConstraintRef &constraint) {
2574  ast::Type argType;
2575  if (failed(validateVariableConstraint(constraint, argType)))
2576  return failure();
2577  return defineVariableDecl(name, loc, argType, constraint);
2578 }
2579 
2581 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2582  ast::Type &inferredType) {
2583  for (const ast::ConstraintRef &ref : constraints)
2584  if (failed(validateVariableConstraint(ref, inferredType)))
2585  return failure();
2586  return success();
2587 }
2588 
2589 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2590  ast::Type &inferredType) {
2591  ast::Type constraintType;
2592  if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
2593  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2594  if (failed(validateTypeConstraintExpr(typeExpr)))
2595  return failure();
2596  }
2597  constraintType = ast::AttributeType::get(ctx);
2598  } else if (const auto *cst =
2599  dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
2600  constraintType = ast::OperationType::get(
2601  ctx, cst->getName(), lookupODSOperation(cst->getName()));
2602  } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
2603  constraintType = typeTy;
2604  } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
2605  constraintType = typeRangeTy;
2606  } else if (const auto *cst =
2607  dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
2608  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2609  if (failed(validateTypeConstraintExpr(typeExpr)))
2610  return failure();
2611  }
2612  constraintType = valueTy;
2613  } else if (const auto *cst =
2614  dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
2615  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2616  if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2617  return failure();
2618  }
2619  constraintType = valueRangeTy;
2620  } else if (const auto *cst =
2621  dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
2622  ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2623  if (inputs.size() != 1) {
2624  return emitErrorAndNote(ref.referenceLoc,
2625  "`Constraint`s applied via a variable constraint "
2626  "list must take a single input, but got " +
2627  Twine(inputs.size()),
2628  cst->getLoc(),
2629  "see definition of constraint here");
2630  }
2631  constraintType = inputs.front()->getType();
2632  } else {
2633  llvm_unreachable("unknown constraint type");
2634  }
2635 
2636  // Check that the constraint type is compatible with the current inferred
2637  // type.
2638  if (!inferredType) {
2639  inferredType = constraintType;
2640  } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
2641  inferredType = mergedTy;
2642  } else {
2643  return emitError(ref.referenceLoc,
2644  llvm::formatv("constraint type `{0}` is incompatible "
2645  "with the previously inferred type `{1}`",
2646  constraintType, inferredType));
2647  }
2648  return success();
2649 }
2650 
2651 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2652  ast::Type typeExprType = typeExpr->getType();
2653  if (typeExprType != typeTy) {
2654  return emitError(typeExpr->getLoc(),
2655  "expected expression of `Type` in type constraint");
2656  }
2657  return success();
2658 }
2659 
2661 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2662  ast::Type typeExprType = typeExpr->getType();
2663  if (typeExprType != typeRangeTy) {
2664  return emitError(typeExpr->getLoc(),
2665  "expected expression of `TypeRange` in type constraint");
2666  }
2667  return success();
2668 }
2669 
2670 //===----------------------------------------------------------------------===//
2671 // Exprs
2672 
2674 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2675  MutableArrayRef<ast::Expr *> arguments) {
2676  ast::Type parentType = parentExpr->getType();
2677 
2678  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2679  if (!callableDecl) {
2680  return emitError(loc,
2681  llvm::formatv("expected a reference to a callable "
2682  "`Constraint` or `Rewrite`, but got: `{0}`",
2683  parentType));
2684  }
2685  if (parserContext == ParserContext::Rewrite) {
2686  if (isa<ast::UserConstraintDecl>(callableDecl))
2687  return emitError(
2688  loc, "unable to invoke `Constraint` within a rewrite section");
2689  } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
2690  return emitError(loc, "unable to invoke `Rewrite` within a match section");
2691  }
2692 
2693  // Verify the arguments of the call.
2694  /// Handle size mismatch.
2695  ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2696  if (callArgs.size() != arguments.size()) {
2697  return emitErrorAndNote(
2698  loc,
2699  llvm::formatv("invalid number of arguments for {0} call; expected "
2700  "{1}, but got {2}",
2701  callableDecl->getCallableType(), callArgs.size(),
2702  arguments.size()),
2703  callableDecl->getLoc(),
2704  llvm::formatv("see the definition of {0} here",
2705  callableDecl->getName()->getName()));
2706  }
2707 
2708  /// Handle argument type mismatch.
2709  auto attachDiagFn = [&](ast::Diagnostic &diag) {
2710  diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2711  callableDecl->getName()->getName()),
2712  callableDecl->getLoc());
2713  };
2714  for (auto it : llvm::zip(callArgs, arguments)) {
2715  if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2716  attachDiagFn)))
2717  return failure();
2718  }
2719 
2720  return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2721  callableDecl->getResultType());
2722 }
2723 
2724 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2725  ast::Decl *decl) {
2726  // Check the type of decl being referenced.
2727  ast::Type declType;
2728  if (isa<ast::ConstraintDecl>(decl))
2729  declType = ast::ConstraintType::get(ctx);
2730  else if (isa<ast::UserRewriteDecl>(decl))
2731  declType = ast::RewriteType::get(ctx);
2732  else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2733  declType = varDecl->getType();
2734  else
2735  return emitError(loc, "invalid reference to `" +
2736  decl->getName()->getName() + "`");
2737 
2738  return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2739 }
2740 
2742 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2743  ArrayRef<ast::ConstraintRef> constraints) {
2745  defineVariableDecl(name, loc, type, constraints);
2746  if (failed(decl))
2747  return failure();
2748  return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2749 }
2750 
2752 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2753  SMRange loc) {
2754  // Validate the member name for the given parent expression.
2755  FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2756  if (failed(memberType))
2757  return failure();
2758 
2759  return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2760 }
2761 
2762 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2763  StringRef name, SMRange loc) {
2764  ast::Type parentType = parentExpr->getType();
2765  if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
2767  return valueRangeTy;
2768 
2769  // Verify member access based on the operation type.
2770  if (const ods::Operation *odsOp = opType.getODSOperation()) {
2771  auto results = odsOp->getResults();
2772 
2773  // Handle indexed results.
2774  unsigned index = 0;
2775  if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2776  index < results.size()) {
2777  return results[index].isVariadic() ? valueRangeTy : valueTy;
2778  }
2779 
2780  // Handle named results.
2781  const auto *it = llvm::find_if(results, [&](const auto &result) {
2782  return result.getName() == name;
2783  });
2784  if (it != results.end())
2785  return it->isVariadic() ? valueRangeTy : valueTy;
2786  } else if (llvm::isDigit(name[0])) {
2787  // Allow unchecked numeric indexing of the results of unregistered
2788  // operations. It returns a single value.
2789  return valueTy;
2790  }
2791  } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
2792  // Handle indexed results.
2793  unsigned index = 0;
2794  if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2795  index < tupleType.size()) {
2796  return tupleType.getElementTypes()[index];
2797  }
2798 
2799  // Handle named results.
2800  auto elementNames = tupleType.getElementNames();
2801  const auto *it = llvm::find(elementNames, name);
2802  if (it != elementNames.end())
2803  return tupleType.getElementTypes()[it - elementNames.begin()];
2804  }
2805  return emitError(
2806  loc,
2807  llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2808  name, parentType));
2809 }
2810 
2811 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2812  SMRange loc, const ast::OpNameDecl *name,
2813  OpResultTypeContext resultTypeContext,
2814  SmallVectorImpl<ast::Expr *> &operands,
2816  SmallVectorImpl<ast::Expr *> &results) {
2817  std::optional<StringRef> opNameRef = name->getName();
2818  const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2819 
2820  // Verify the inputs operands.
2821  if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2822  return failure();
2823 
2824  // Verify the attribute list.
2825  for (ast::NamedAttributeDecl *attr : attributes) {
2826  // Check for an attribute type, or a type awaiting resolution.
2827  ast::Type attrType = attr->getValue()->getType();
2828  if (!attrType.isa<ast::AttributeType>()) {
2829  return emitError(
2830  attr->getValue()->getLoc(),
2831  llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2832  }
2833  }
2834 
2835  assert(
2836  (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2837  "unexpected inferrence when results were explicitly specified");
2838 
2839  // If we aren't relying on type inferrence, or explicit results were provided,
2840  // validate them.
2841  if (resultTypeContext == OpResultTypeContext::Explicit) {
2842  if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2843  return failure();
2844 
2845  // Validate the use of interface based type inferrence for this operation.
2846  } else if (resultTypeContext == OpResultTypeContext::Interface) {
2847  assert(opNameRef &&
2848  "expected valid operation name when inferring operation results");
2849  checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2850  }
2851 
2852  return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
2853  attributes);
2854 }
2855 
2857 Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2858  const ods::Operation *odsOp,
2859  SmallVectorImpl<ast::Expr *> &operands) {
2860  return validateOperationOperandsOrResults(
2861  "operand", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2862  operands, odsOp ? odsOp->getOperands() : std::nullopt, valueTy,
2863  valueRangeTy);
2864 }
2865 
2867 Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2868  const ods::Operation *odsOp,
2869  SmallVectorImpl<ast::Expr *> &results) {
2870  return validateOperationOperandsOrResults(
2871  "result", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2872  results, odsOp ? odsOp->getResults() : std::nullopt, typeTy, typeRangeTy);
2873 }
2874 
2875 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2876  const ods::Operation *odsOp) {
2877  // If the operation might not have inferrence support, emit a warning to the
2878  // user. We don't emit an error because the interface might be added to the
2879  // operation at runtime. It's rare, but it could still happen. We emit a
2880  // warning here instead.
2881 
2882  // Handle inferrence warnings for unknown operations.
2883  if (!odsOp) {
2884  ctx.getDiagEngine().emitWarning(
2885  loc, llvm::formatv(
2886  "operation result types are marked to be inferred, but "
2887  "`{0}` is unknown. Ensure that `{0}` supports zero "
2888  "results or implements `InferTypeOpInterface`. Include "
2889  "the ODS definition of this operation to remove this warning.",
2890  opName));
2891  return;
2892  }
2893 
2894  // Handle inferrence warnings for known operations that expected at least one
2895  // result, but don't have inference support. An elided results list can mean
2896  // "zero-results", and we don't want to warn when that is the expected
2897  // behavior.
2898  bool requiresInferrence =
2899  llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
2900  return !result.isVariableLength();
2901  });
2902  if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2903  ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning(
2904  loc,
2905  llvm::formatv("operation result types are marked to be inferred, but "
2906  "`{0}` does not provide an implementation of "
2907  "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2908  "`InferTypeOpInterface` at runtime, or add support to "
2909  "the ODS definition to remove this warning.",
2910  opName));
2911  diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
2912  odsOp->getLoc());
2913  return;
2914  }
2915 }
2916 
2917 LogicalResult Parser::validateOperationOperandsOrResults(
2918  StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2919  std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
2920  ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2921  ast::RangeType rangeTy) {
2922  // All operation types accept a single range parameter.
2923  if (values.size() == 1) {
2924  if (failed(convertExpressionTo(values[0], rangeTy)))
2925  return failure();
2926  return success();
2927  }
2928 
2929  /// If the operation has ODS information, we can more accurately verify the
2930  /// values.
2931  if (odsOpLoc) {
2932  auto emitSizeMismatchError = [&] {
2933  return emitErrorAndNote(
2934  loc,
2935  llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2936  "{2}, but got {3}",
2937  groupName, *name, odsValues.size(), values.size()),
2938  *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
2939  };
2940 
2941  // Handle the case where no values were provided.
2942  if (values.empty()) {
2943  // If we don't expect any on the ODS side, we are done.
2944  if (odsValues.empty())
2945  return success();
2946 
2947  // If we do, check if we actually need to provide values (i.e. if any of
2948  // the values are actually required).
2949  unsigned numVariadic = 0;
2950  for (const auto &odsValue : odsValues) {
2951  if (!odsValue.isVariableLength())
2952  return emitSizeMismatchError();
2953  ++numVariadic;
2954  }
2955 
2956  // If we are in a non-rewrite context, we don't need to do anything more.
2957  // Zero-values is a valid constraint on the operation.
2958  if (parserContext != ParserContext::Rewrite)
2959  return success();
2960 
2961  // Otherwise, when in a rewrite we may need to provide values to match the
2962  // ODS signature of the operation to create.
2963 
2964  // If we only have one variadic value, just use an empty list.
2965  if (numVariadic == 1)
2966  return success();
2967 
2968  // Otherwise, create dummy values for each of the entries so that we
2969  // adhere to the ODS signature.
2970  for (unsigned i = 0, e = odsValues.size(); i < e; ++i) {
2971  values.push_back(ast::RangeExpr::create(
2972  ctx, loc, /*elements=*/std::nullopt, rangeTy));
2973  }
2974  return success();
2975  }
2976 
2977  // Verify that the number of values provided matches the number of value
2978  // groups ODS expects.
2979  if (odsValues.size() != values.size())
2980  return emitSizeMismatchError();
2981 
2982  auto diagFn = [&](ast::Diagnostic &diag) {
2983  diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
2984  *odsOpLoc);
2985  };
2986  for (unsigned i = 0, e = values.size(); i < e; ++i) {
2987  ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
2988  if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
2989  return failure();
2990  }
2991  return success();
2992  }
2993 
2994  // Otherwise, accept the value groups as they have been defined and just
2995  // ensure they are one of the expected types.
2996  for (ast::Expr *&valueExpr : values) {
2997  ast::Type valueExprType = valueExpr->getType();
2998 
2999  // Check if this is one of the expected types.
3000  if (valueExprType == rangeTy || valueExprType == singleTy)
3001  continue;
3002 
3003  // If the operand is an Operation, allow converting to a Value or
3004  // ValueRange. This situations arises quite often with nested operation
3005  // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
3006  if (singleTy == valueTy) {
3007  if (valueExprType.isa<ast::OperationType>()) {
3008  valueExpr = convertOpToValue(valueExpr);
3009  continue;
3010  }
3011  }
3012 
3013  // Otherwise, try to convert the expression to a range.
3014  if (succeeded(convertExpressionTo(valueExpr, rangeTy)))
3015  continue;
3016 
3017  return emitError(
3018  valueExpr->getLoc(),
3019  llvm::formatv(
3020  "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3021  singleTy, rangeTy, valueExprType));
3022  }
3023  return success();
3024 }
3025 
3027 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
3028  ArrayRef<StringRef> elementNames) {
3029  for (const ast::Expr *element : elements) {
3030  ast::Type eleTy = element->getType();
3032  return emitError(
3033  element->getLoc(),
3034  llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
3035  }
3036  }
3037  return ast::TupleExpr::create(ctx, loc, elements, elementNames);
3038 }
3039 
3040 //===----------------------------------------------------------------------===//
3041 // Stmts
3042 
3043 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
3044  ast::Expr *rootOp) {
3045  // Check that root is an Operation.
3046  ast::Type rootType = rootOp->getType();
3047  if (!rootType.isa<ast::OperationType>())
3048  return emitError(rootOp->getLoc(), "expected `Op` expression");
3049 
3050  return ast::EraseStmt::create(ctx, loc, rootOp);
3051 }
3052 
3054 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
3055  MutableArrayRef<ast::Expr *> replValues) {
3056  // Check that root is an Operation.
3057  ast::Type rootType = rootOp->getType();
3058  if (!rootType.isa<ast::OperationType>()) {
3059  return emitError(
3060  rootOp->getLoc(),
3061  llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3062  }
3063 
3064  // If there are multiple replacement values, we implicitly convert any Op
3065  // expressions to the value form.
3066  bool shouldConvertOpToValues = replValues.size() > 1;
3067  for (ast::Expr *&replExpr : replValues) {
3068  ast::Type replType = replExpr->getType();
3069 
3070  // Check that replExpr is an Operation, Value, or ValueRange.
3071  if (replType.isa<ast::OperationType>()) {
3072  if (shouldConvertOpToValues)
3073  replExpr = convertOpToValue(replExpr);
3074  continue;
3075  }
3076 
3077  if (replType != valueTy && replType != valueRangeTy) {
3078  return emitError(replExpr->getLoc(),
3079  llvm::formatv("expected `Op`, `Value` or `ValueRange` "
3080  "expression, but got `{0}`",
3081  replType));
3082  }
3083  }
3084 
3085  return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
3086 }
3087 
3089 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
3090  ast::CompoundStmt *rewriteBody) {
3091  // Check that root is an Operation.
3092  ast::Type rootType = rootOp->getType();
3093  if (!rootType.isa<ast::OperationType>()) {
3094  return emitError(
3095  rootOp->getLoc(),
3096  llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3097  }
3098 
3099  return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
3100 }
3101 
3102 //===----------------------------------------------------------------------===//
3103 // Code Completion
3104 //===----------------------------------------------------------------------===//
3105 
3106 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3107  ast::Type parentType = parentExpr->getType();
3108  if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>())
3109  codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3110  else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>())
3111  codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3112  return failure();
3113 }
3114 
3116 Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3117  if (opName)
3118  codeCompleteContext->codeCompleteOperationAttributeName(*opName);
3119  return failure();
3120 }
3121 
3123 Parser::codeCompleteConstraintName(ast::Type inferredType,
3124  bool allowInlineTypeConstraints) {
3125  codeCompleteContext->codeCompleteConstraintName(
3126  inferredType, allowInlineTypeConstraints, curDeclScope);
3127  return failure();
3128 }
3129 
3130 LogicalResult Parser::codeCompleteDialectName() {
3131  codeCompleteContext->codeCompleteDialectName();
3132  return failure();
3133 }
3134 
3135 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3136  codeCompleteContext->codeCompleteOperationName(dialectName);
3137  return failure();
3138 }
3139 
3140 LogicalResult Parser::codeCompletePatternMetadata() {
3141  codeCompleteContext->codeCompletePatternMetadata();
3142  return failure();
3143 }
3144 
3145 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3146  codeCompleteContext->codeCompleteIncludeFilename(curPath);
3147  return failure();
3148 }
3149 
3150 void Parser::codeCompleteCallSignature(ast::Node *parent,
3151  unsigned currentNumArgs) {
3152  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent);
3153  if (!callableDecl)
3154  return;
3155 
3156  codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
3157 }
3158 
3159 void Parser::codeCompleteOperationOperandsSignature(
3160  std::optional<StringRef> opName, unsigned currentNumOperands) {
3161  codeCompleteContext->codeCompleteOperationOperandsSignature(
3162  opName, currentNumOperands);
3163 }
3164 
3165 void Parser::codeCompleteOperationResultsSignature(
3166  std::optional<StringRef> opName, unsigned currentNumResults) {
3167  codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3168  currentNumResults);
3169 }
3170 
3171 //===----------------------------------------------------------------------===//
3172 // Parser
3173 //===----------------------------------------------------------------------===//
3174 
3176 mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3177  bool enableDocumentation,
3178  CodeCompleteContext *codeCompleteContext) {
3179  Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3180  return parser.parseModule();
3181 }
static std::string diag(const llvm::Value &value)
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class breaks up the current file into a token stream.
Definition: Lexer.h:23
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
This represents a token in the MLIR syntax.
Definition: Token.h:20
bool isKeyword() const
Return true if this is one of the keyword token kinds (e.g. kw_if).
Definition: Token.cpp:187
SMLoc getLoc() const
Definition: Token.cpp:19
bool isNot(Kind k) const
Definition: Token.h:50
StringRef getSpelling() const
Definition: Token.h:34
This class provides an abstract interface into the parser for hooking in code completion events.
Definition: CodeComplete.h:30
@ code_complete_string
Token signifying a code completion location within a string.
Definition: Lexer.h:43
@ directive
Tokens.
Definition: Lexer.h:94
@ eof
Markers.
Definition: Lexer.h:38
@ code_complete
Token signifying a code completion location.
Definition: Lexer.h:41
@ arrow
Punctuation.
Definition: Lexer.h:75
@ less
Paired punctuation.
Definition: Lexer.h:83
@ kw_Attr
General keywords.
Definition: Lexer.h:56
static StringRef getMemberName()
Return the member name used for the "all-results" access.
Definition: Nodes.h:478
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
Definition: Nodes.h:480
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
Definition: Nodes.h:740
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:389
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:258
This class represents a PDLL type that corresponds to an mlir::Attribute.
Definition: Types.h:131
static AttributeType get(Context &context)
Return an instance of the Attribute type.
Definition: Types.cpp:57
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr * > arguments, Type resultType)
Definition: Nodes.cpp:268
This decl represents a shared interface for all callable decls.
Definition: Nodes.h:1179
Type getResultType() const
Return the result type of this decl.
Definition: Nodes.h:1197
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
Definition: Nodes.h:1190
StringRef getCallableType() const
Return the callable type of this decl.
Definition: Nodes.h:1182
This statement represents a compound statement, which contains a collection of other statements.
Definition: Nodes.h:179
ArrayRef< Stmt * >::iterator begin() const
Definition: Nodes.h:191
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
Definition: Nodes.h:185
ArrayRef< Stmt * >::iterator end() const
Definition: Nodes.h:192
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt * > children)
Definition: Nodes.cpp:192
This class represents the base of all AST Constraint decls.
Definition: Nodes.h:695
This class represents a PDLL type that corresponds to a constraint.
Definition: Types.h:145
static ConstraintType get(Context &context)
Return an instance of the Constraint type.
Definition: Types.cpp:65
This class represents the main context of the PDLL AST.
Definition: Context.h:25
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
Definition: Nodes.cpp:284
This class represents a scope for named AST decls.
Definition: Nodes.h:64
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
Definition: Nodes.h:70
void add(Decl *decl)
Add a new decl to the scope.
Definition: Nodes.cpp:175
This class represents the base Decl node.
Definition: Nodes.h:660
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition: Nodes.h:663
This class provides a simple implementation of a PDLL diagnostic.
Definition: Diagnostic.h:30
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
Definition: Nodes.cpp:218
This class represents a base AST Expression node.
Definition: Nodes.h:345
Type getType() const
Return the type of this expression.
Definition: Nodes.h:348
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostic.h:84
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
Definition: Nodes.cpp:207
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
Definition: Nodes.cpp:294
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl * > children)
Definition: Nodes.cpp:575
This Decl represents a NamedAttribute, and contains a string name and attribute value.
Definition: Nodes.h:983
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
Definition: Nodes.cpp:498
This class represents a base AST node.
Definition: Nodes.h:108
SMRange getLoc() const
Return the location of this node.
Definition: Nodes.h:131
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
Definition: Nodes.cpp:399
This Decl represents an OperationName.
Definition: Nodes.h:1007
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
Definition: Nodes.h:1013
static OpNameDecl * create(Context &ctx, const Name &name)
Definition: Nodes.cpp:508
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr * > operands, ArrayRef< Expr * > resultTypes, ArrayRef< NamedAttributeDecl * > attributes)
Definition: Nodes.cpp:306
This class represents a PDLL type that corresponds to an mlir::Operation.
Definition: Types.h:158
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
Definition: Types.cpp:87
static OperationType get(Context &context, std::optional< StringRef > name=std::nullopt, const ods::Operation *odsOp=nullptr)
Return an instance of the Operation type with an optional operation name.
Definition: Types.cpp:73
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
Definition: Nodes.cpp:519
static RangeExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, RangeType type)
Definition: Nodes.cpp:337
This class represents a PDLL type that corresponds to a range of elements with a given element type.
Definition: Types.h:183
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr * > replExprs)
Definition: Nodes.cpp:225
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
Definition: Nodes.cpp:249
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
Definition: Nodes.cpp:239
This class represents a PDLL type that corresponds to a rewrite reference.
Definition: Types.h:230
static RewriteType get(Context &context)
Return an instance of the Rewrite type.
Definition: Types.cpp:133
This class represents a base AST Statement node.
Definition: Nodes.h:164
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, ArrayRef< StringRef > elementNames)
Definition: Nodes.cpp:352
This class represents a PDLL tuple type, i.e.
Definition: Types.h:244
size_t size() const
Return the number of elements within this tuple.
Definition: Types.h:261
ArrayRef< Type > getElementTypes() const
Return the element types of this tuple.
Definition: Types.cpp:152
static TupleType get(Context &context, ArrayRef< Type > elementTypes, ArrayRef< StringRef > elementNames)
Return an instance of the Tuple type.
Definition: Types.cpp:141
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:416
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:372
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:425
U dyn_cast() const
Definition: Types.h:76
Type refineWith(Type other) const
Try to refine this type with the one provided.
Definition: Types.cpp:33
bool isa() const
Provide type casting support.
Definition: Types.h:67
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
Definition: Nodes.h:877
The class represents a Value constraint, and constrains a variable to be a Value.
Definition: Nodes.h:816
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
Definition: Nodes.cpp:435
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
Definition: Nodes.h:838
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:446
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
Definition: Nodes.cpp:557
This class contains all of the registered ODS operation classes.
Definition: Context.h:32
std::pair< Operation *, bool > insertOperation(StringRef name, StringRef summary, StringRef desc, StringRef nativeClassName, bool supportsResultTypeInferrence, SMLoc loc)
Insert a new operation into the context.
Definition: Context.cpp:64
const TypeConstraint & insertTypeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new type constraint into the context.
Definition: Context.cpp:42
const AttributeConstraint & insertAttributeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new attribute constraint into the context.
Definition: Context.cpp:28
This class provides an ODS representation of a specific operation operand or result.
Definition: Operation.h:74
This class provides an ODS representation of a specific operation.
Definition: Operation.h:125
SMRange getLoc() const
Return the source location of this operation.
Definition: Operation.h:128
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
Definition: Operation.h:165
bool hasResultTypeInferrence() const
Return whether the operation is known to support result type inference.
Definition: Operation.h:171
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
Definition: Operation.h:168
This class represents a generic ODS Type constraint.
Definition: Constraint.h:84
raw_ostream subclass that simplifies indenting a sequence of code.
raw_indented_ostream & printReindented(StringRef str, StringRef extraPrefix="")
Prints a string re-indented to the current indent.
StringRef getSummary() const
Definition: Constraint.cpp:54
std::string getUniqueDefName() const
Returns a unique name for the TableGen def of this constraint.
Definition: Constraint.cpp:71
StringRef getDescription() const
Definition: Constraint.cpp:61
std::string getConditionTemplate() const
Definition: Constraint.cpp:50
Format context containing substitutions for special placeholders.
Definition: Format.h:40
FmtContext & withSelf(Twine subst)
Definition: Format.cpp:41
Wrapper class that contains an MLIR op's information (e.g., operands, attributes) defined in TableGen and provides helper methods for accessing them.
Definition: Operator.h:77
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
Definition: Parser.cpp:3176
auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals) -> FmtObject< decltype(std::make_tuple(llvm::detail::build_format_adapter(std::forward< Ts >(vals))...))>
Formats text by substituting placeholders in format string with replacement parameters.
Definition: Format.h:261
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This class represents a reference to a constraint, and contains a constraint and the location of the reference.
Definition: Nodes.h:707
const ConstraintDecl * constraint
Definition: Nodes.h:713
This class provides a convenient API for interacting with source names.
Definition: Nodes.h:37
StringRef getName() const
Return the raw string name.
Definition: Nodes.h:41
SMRange getLoc() const
Get the location of this name.
Definition: Nodes.h:44
static const Name & create(Context &ctx, StringRef name, SMRange location)
Definition: Nodes.cpp:33