MLIR  19.0.0git
Parser.cpp
Go to the documentation of this file.
1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "Lexer.h"
13 #include "mlir/TableGen/Argument.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/Operator.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/ManagedStatic.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/ScopedPrinter.h"
32 #include "llvm/TableGen/Error.h"
33 #include "llvm/TableGen/Parser.h"
34 #include <optional>
35 #include <string>
36 
37 using namespace mlir;
38 using namespace mlir::pdll;
39 
40 //===----------------------------------------------------------------------===//
41 // Parser
42 //===----------------------------------------------------------------------===//
43 
44 namespace {
45 class Parser {
46 public:
47  Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
48  bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
49  : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
50  curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
51  typeTy(ast::TypeType::get(ctx)), valueTy(ast::ValueType::get(ctx)),
52  typeRangeTy(ast::TypeRangeType::get(ctx)),
53  valueRangeTy(ast::ValueRangeType::get(ctx)),
54  attrTy(ast::AttributeType::get(ctx)),
55  codeCompleteContext(codeCompleteContext) {}
56 
57  /// Try to parse a new module. Returns failure if the module failed to parse.
58  FailureOr<ast::Module *> parseModule();
59 
60 private:
61  /// The current context of the parser. It allows for the parser to know a bit
62  /// about the construct it is nested within during parsing. This is used
63  /// specifically to provide additional verification during parsing, e.g. to
64  /// prevent using rewrites within a match context, matcher constraints within
65  /// a rewrite section, etc.
66  enum class ParserContext {
67  /// The parser is in the global context.
68  Global,
69  /// The parser is currently within a Constraint, which disallows all types
70  /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
71  Constraint,
72  /// The parser is currently within the matcher portion of a Pattern, which
73  /// allows a terminal operation rewrite statement but no other rewrite
74  /// transformations.
75  PatternMatch,
76  /// The parser is currently within a Rewrite, which disallows calls to
77  /// constraints, requires operation expressions to have names, etc.
78  Rewrite,
79  };
80 
81  /// The current specification context of an operation's result type. This
82  /// indicates how the result types of an operation may be inferred.
83  enum class OpResultTypeContext {
84  /// The result types of the operation are not known to be inferred.
85  Explicit,
86  /// The result types of the operation are inferred from the root input of a
87  /// `replace` statement.
88  Replacement,
89  /// The result types of the operation are inferred by using the
90  /// `InferTypeOpInterface` interface provided by the operation.
91  Interface,
92  };
93 
94  //===--------------------------------------------------------------------===//
95  // Parsing
96  //===--------------------------------------------------------------------===//
97 
98  /// Push a new decl scope, parented to the current scope, and make it the
  /// current scope. Returns the newly created scope.
99  ast::DeclScope *pushDeclScope() {
100  ast::DeclScope *newScope =
101  new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
102  return (curDeclScope = newScope);
103  }
  /// Make the given, previously created, `scope` the current decl scope.
104  void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
105 
106  /// Pop the current decl scope, making its parent scope current again.
107  void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
108 
109  /// Parse the body of an AST module, appending parsed decls to `decls`.
110  LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
111 
112  /// Try to convert the given expression to `type`. Returns failure and emits
113  /// an error if a conversion is not viable. On failure, `noteAttachFn` is
114  /// invoked to attach notes to the emitted error diagnostic. On success,
115  /// `expr` is updated to the expression used to convert to `type`.
116  LogicalResult convertExpressionTo(
117  ast::Expr *&expr, ast::Type type,
118  function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
120  convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType,
121  ast::Type type,
122  function_ref<ast::InFlightDiagnostic()> emitErrorFn);
123  LogicalResult convertTupleExpressionTo(
124  ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
125  function_ref<ast::InFlightDiagnostic()> emitErrorFn,
126  function_ref<void(ast::Diagnostic &diag)> noteAttachFn);
127 
128  /// Given an operation expression, convert it to a ValueRange typed
129  /// expression that accesses all of the operation's results.
130  ast::Expr *convertOpToValue(const ast::Expr *opExpr);
131 
132  /// Lookup ODS information for the given operation, returns nullptr if no
133  /// information is found or if no `opName` is provided.
134  const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
135  return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
136  }
137 
138  /// Process the given documentation string, or return an empty string if
139  /// documentation isn't enabled.
140  StringRef processDoc(StringRef doc) {
141  return enableDocumentation ? doc : StringRef();
142  }
143 
144  /// Process the given documentation string and format it, or return an empty
145  /// string if documentation isn't enabled.
146  std::string processAndFormatDoc(const Twine &doc) {
147  if (!enableDocumentation)
148  return "";
149  std::string docStr;
150  {
151  llvm::raw_string_ostream docOS(docStr);
152  std::string tmpDocStr = doc.str();
154  StringRef(tmpDocStr).rtrim(" \t"));
155  }
156  return docStr;
157  }
158 
159  //===--------------------------------------------------------------------===//
160  // Directives
161 
  /// Parse a directive at the current token; `#include` is the only directive
  /// accepted, anything else is reported as an unknown directive.
162  LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
  /// Parse an `#include` directive naming a `.pdll` or `.td` file, adding any
  /// decls produced by the included file to `decls`.
163  LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
164  LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
166 
167  /// Process the records of a parsed tablegen include file.
168  void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
170 
171  /// Create a user defined native constraint for a constraint imported from
172  /// ODS.
173  template <typename ConstraintT>
174  ast::Decl *
175  createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
176  SMRange loc, ast::Type type,
177  StringRef nativeType, StringRef docString);
178  template <typename ConstraintT>
179  ast::Decl *
180  createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
181  SMRange loc, ast::Type type,
182  StringRef nativeType);
183 
184  //===--------------------------------------------------------------------===//
185  // Decls
186 
187  /// This structure contains the set of pattern metadata that may be parsed.
188  struct ParsedPatternMetadata {
189  std::optional<uint16_t> benefit;
190  bool hasBoundedRecursion = false;
191  };
192 
193  FailureOr<ast::Decl *> parseTopLevelDecl();
195  parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
196 
197  /// Parse an argument variable as part of the signature of a
198  /// UserConstraintDecl or UserRewriteDecl.
199  FailureOr<ast::VariableDecl *> parseArgumentDecl();
200 
201  /// Parse a result variable as part of the signature of a UserConstraintDecl
202  /// or UserRewriteDecl.
203  FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
204 
205  /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
206  /// defined in a non-global context.
208  parseUserConstraintDecl(bool isInline = false);
209 
210  /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
211  /// non-global context, such as within a Pattern/Constraint/etc.
212  FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
213 
214  /// Parse a PDLL (i.e. non-native) UserConstraintDecl whose body is defined
215  /// using PDLL constructs.
216  FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
217  const ast::Name &name, bool isInline,
218  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
219  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
220 
221  /// Parse a UserRewriteDecl. `isInline` signals if the rewrite is being
222  /// defined in a non-global context.
223  FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
224 
225  /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
226  /// non-global context, such as within a Pattern/Rewrite/etc.
227  FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
228 
229  /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
230  /// PDLL constructs.
231  FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
232  const ast::Name &name, bool isInline,
233  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
234  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
235 
236  /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
237  /// effectively the same syntax, and only differ on slight semantics (given
238  /// the different parsing contexts).
239  template <typename T, typename ParseUserPDLLDeclFnT>
240  FailureOr<T *> parseUserConstraintOrRewriteDecl(
241  ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
242  StringRef anonymousNamePrefix, bool isInline);
243 
244  /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
245  /// These decls have effectively the same syntax.
246  template <typename T>
247  FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
248  const ast::Name &name, bool isInline,
250  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
251 
252  /// Parse the functional signature (i.e. the arguments and results) of a
253  /// UserConstraintDecl or UserRewriteDecl.
254  LogicalResult parseUserConstraintOrRewriteSignature(
257  ast::DeclScope *&argumentScope, ast::Type &resultType);
258 
259  /// Validate the return (which if present is specified by bodyIt) of a
260  /// UserConstraintDecl or UserRewriteDecl.
261  LogicalResult validateUserConstraintOrRewriteReturn(
262  StringRef declType, ast::CompoundStmt *body,
265  ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
266 
268  parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
269  bool expectTerminalSemicolon = true);
270  FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
271  FailureOr<ast::Decl *> parsePatternDecl();
272  LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
273 
274  /// Check to see if a decl has already been defined with the given name; if
275  /// one has, emit an error and return failure. Returns success otherwise.
276  LogicalResult checkDefineNamedDecl(const ast::Name &name);
277 
278  /// Try to define a variable decl with the given components, returns the
279  /// variable on success.
281  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
282  ast::Expr *initExpr,
283  ArrayRef<ast::ConstraintRef> constraints);
285  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
286  ArrayRef<ast::ConstraintRef> constraints);
287 
288  /// Parse the constraint reference list for a variable decl.
289  LogicalResult parseVariableDeclConstraintList(
291 
292  /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
293  FailureOr<ast::Expr *> parseTypeConstraintExpr();
294 
295  /// Try to parse a single reference to a constraint. `typeConstraint` is the
296  /// location of a previously parsed type constraint for the entity that will
297  /// be constrained by the parsed constraint. `existingConstraints` are any
298  /// existing constraints that have already been parsed for the same entity
299  /// that will be constrained by this constraint. `allowInlineTypeConstraints`
300  /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
302  parseConstraint(std::optional<SMRange> &typeConstraint,
303  ArrayRef<ast::ConstraintRef> existingConstraints,
304  bool allowInlineTypeConstraints);
305 
306  /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
307  /// argument or result variable. The constraints for these variables do not
308  /// allow inline type constraints, and only permit a single constraint.
309  FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
310 
311  //===--------------------------------------------------------------------===//
312  // Exprs
313 
314  FailureOr<ast::Expr *> parseExpr();
315 
316  /// Identifier expressions.
317  FailureOr<ast::Expr *> parseAttributeExpr();
318  FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr,
319  bool isNegated = false);
320  FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
321  FailureOr<ast::Expr *> parseIdentifierExpr();
322  FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
323  FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
324  FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
325  FailureOr<ast::Expr *> parseNegatedExpr();
326  FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
327  FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
329  parseOperationExpr(OpResultTypeContext inputResultTypeContext =
330  OpResultTypeContext::Explicit);
331  FailureOr<ast::Expr *> parseTupleExpr();
332  FailureOr<ast::Expr *> parseTypeExpr();
333  FailureOr<ast::Expr *> parseUnderscoreExpr();
334 
335  //===--------------------------------------------------------------------===//
336  // Stmts
337 
338  FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
339  FailureOr<ast::CompoundStmt *> parseCompoundStmt();
340  FailureOr<ast::EraseStmt *> parseEraseStmt();
341  FailureOr<ast::LetStmt *> parseLetStmt();
342  FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
343  FailureOr<ast::ReturnStmt *> parseReturnStmt();
344  FailureOr<ast::RewriteStmt *> parseRewriteStmt();
345 
346  //===--------------------------------------------------------------------===//
347  // Creation+Analysis
348  //===--------------------------------------------------------------------===//
349 
350  //===--------------------------------------------------------------------===//
351  // Decls
352 
353  /// Try to extract a callable from the given AST node. Returns nullptr on
354  /// failure.
355  ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
356 
357  /// Try to create a pattern decl with the given components, returning the
358  /// Pattern on success.
360  createPatternDecl(SMRange loc, const ast::Name *name,
361  const ParsedPatternMetadata &metadata,
362  ast::CompoundStmt *body);
363 
364  /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
365  /// of results, defined as part of the signature.
366  ast::Type
367  createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
368 
369  /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
370  template <typename T>
371  FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
372  const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
373  ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
374  ast::CompoundStmt *body);
375 
376  /// Try to create a variable decl with the given components, returning the
377  /// Variable on success.
379  createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
380  ArrayRef<ast::ConstraintRef> constraints);
381 
382  /// Create a variable for an argument or result defined as part of the
383  /// signature of a UserConstraintDecl/UserRewriteDecl.
385  createArgOrResultVariableDecl(StringRef name, SMRange loc,
386  const ast::ConstraintRef &constraint);
387 
388  /// Validate the constraints used to constrain a variable decl.
389  /// `inferredType` is the type of the variable inferred by the constraints
390  /// within the list, and is updated to the most refined type as determined by
391  /// the constraints. Returns success if the constraint list is valid, failure
392  /// otherwise.
394  validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
395  ast::Type &inferredType);
396  /// Validate a single reference to a constraint. `inferredType` contains the
397  /// currently inferred variable type and is refined within the type defined
398  /// by the constraint. Returns success if the constraint is valid, failure
399  /// otherwise.
400  LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
401  ast::Type &inferredType);
402  LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
403  LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
404 
405  //===--------------------------------------------------------------------===//
406  // Exprs
407 
409  createCallExpr(SMRange loc, ast::Expr *parentExpr,
411  bool isNegated = false);
412  FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
414  createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
415  ArrayRef<ast::ConstraintRef> constraints);
417  createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
418 
419  /// Validate the member access `name` into the given parent expression. On
420  /// success, this also returns the type of the member accessed.
421  FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
422  StringRef name, SMRange loc);
424  createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
425  OpResultTypeContext resultTypeContext,
430  validateOperationOperands(SMRange loc, std::optional<StringRef> name,
431  const ods::Operation *odsOp,
432  SmallVectorImpl<ast::Expr *> &operands);
433  LogicalResult validateOperationResults(SMRange loc,
434  std::optional<StringRef> name,
435  const ods::Operation *odsOp,
437  void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
438  const ods::Operation *odsOp);
439  LogicalResult validateOperationOperandsOrResults(
440  StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
441  std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
442  ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
443  ast::RangeType rangeTy);
444  FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
445  ArrayRef<ast::Expr *> elements,
446  ArrayRef<StringRef> elementNames);
447 
448  //===--------------------------------------------------------------------===//
449  // Stmts
450 
451  FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
453  createReplaceStmt(SMRange loc, ast::Expr *rootOp,
454  MutableArrayRef<ast::Expr *> replValues);
456  createRewriteStmt(SMRange loc, ast::Expr *rootOp,
457  ast::CompoundStmt *rewriteBody);
458 
459  //===--------------------------------------------------------------------===//
460  // Code Completion
461  //===--------------------------------------------------------------------===//
462 
463  /// The set of various code completion methods. Every completion method
464  /// returns `failure` to stop the parsing process after providing completion
465  /// results.
466 
467  LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
468  LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
469  LogicalResult codeCompleteConstraintName(ast::Type inferredType,
470  bool allowInlineTypeConstraints);
471  LogicalResult codeCompleteDialectName();
472  LogicalResult codeCompleteOperationName(StringRef dialectName);
473  LogicalResult codeCompletePatternMetadata();
474  LogicalResult codeCompleteIncludeFilename(StringRef curPath);
475 
476  void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
477  void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
478  unsigned currentNumOperands);
479  void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
480  unsigned currentNumResults);
481 
482  //===--------------------------------------------------------------------===//
483  // Lexer Utilities
484  //===--------------------------------------------------------------------===//
485 
486  /// If the current token has the specified kind, consume it and return true.
487  /// If not, return false.
488  bool consumeIf(Token::Kind kind) {
489  if (curToken.isNot(kind))
490  return false;
491  consumeToken(kind);
492  return true;
493  }
494 
495  /// Advance the current lexer onto the next token.
496  void consumeToken() {
497  assert(curToken.isNot(Token::eof, Token::error) &&
498  "shouldn't advance past EOF or errors");
499  curToken = lexer.lexToken();
500  }
501 
502  /// Advance the current lexer onto the next token, asserting what the expected
503  /// current token is. This is preferred to the above method because it leads
504  /// to more self-documenting code with better checking.
505  void consumeToken(Token::Kind kind) {
506  assert(curToken.is(kind) && "consumed an unexpected token");
507  consumeToken();
508  }
509 
510  /// Reset the lexer to the location at the given position.
511  void resetToken(SMRange tokLoc) {
512  lexer.resetPointer(tokLoc.Start.getPointer());
513  curToken = lexer.lexToken();
514  }
515 
516  /// Consume the specified token if present and return success. On failure,
517  /// output a diagnostic and return failure.
518  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
519  if (curToken.getKind() != kind)
520  return emitError(curToken.getLoc(), msg);
521  consumeToken();
522  return success();
523  }
  /// Emit an error diagnostic at `loc` through the lexer and return failure.
524  LogicalResult emitError(SMRange loc, const Twine &msg) {
525  lexer.emitError(loc, msg);
526  return failure();
527  }
  /// Emit an error diagnostic at the current token and return failure.
528  LogicalResult emitError(const Twine &msg) {
529  return emitError(curToken.getLoc(), msg);
530  }
  /// Emit an error at `loc` together with a note at `noteLoc` through the
  /// lexer, and return failure.
531  LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
532  const Twine &note) {
533  lexer.emitErrorAndNote(loc, msg, noteLoc, note);
534  return failure();
535  }
536 
537  //===--------------------------------------------------------------------===//
538  // Fields
539  //===--------------------------------------------------------------------===//
540 
541  /// The owning AST context.
542  ast::Context &ctx;
543 
544  /// The lexer of this parser.
545  Lexer lexer;
546 
547  /// The current token within the lexer.
548  Token curToken;
549 
550  /// A flag indicating if the parser should add documentation to AST nodes when
551  /// viable.
552  bool enableDocumentation;
553 
554  /// The most recently defined decl scope.
555  ast::DeclScope *curDeclScope = nullptr;
556  llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
557 
558  /// The current context of the parser.
559  ParserContext parserContext = ParserContext::Global;
560 
561  /// Cached types to simplify verification and expression creation.
562  ast::Type typeTy, valueTy;
563  ast::RangeType typeRangeTy, valueRangeTy;
564  ast::Type attrTy;
565 
566  /// A counter used when naming anonymous constraints and rewrites.
567  unsigned anonymousDeclNameCounter = 0;
568 
569  /// The optional code completion context.
570  CodeCompleteContext *codeCompleteContext;
571 };
572 } // namespace
573 
574 FailureOr<ast::Module *> Parser::parseModule() {
575  SMLoc moduleLoc = curToken.getStartLoc();
576  pushDeclScope();
577 
578  // Parse the top-level decls of the module.
580  if (failed(parseModuleBody(decls)))
581  return popDeclScope(), failure();
582 
583  popDeclScope();
584  return ast::Module::create(ctx, moduleLoc, decls);
585 }
586 
587 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
588  while (curToken.isNot(Token::eof)) {
589  if (curToken.is(Token::directive)) {
590  if (failed(parseDirective(decls)))
591  return failure();
592  continue;
593  }
594 
595  FailureOr<ast::Decl *> decl = parseTopLevelDecl();
596  if (failed(decl))
597  return failure();
598  decls.push_back(*decl);
599  }
600  return success();
601 }
602 
603 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
604  return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
605  valueRangeTy);
606 }
607 
608 LogicalResult Parser::convertExpressionTo(
609  ast::Expr *&expr, ast::Type type,
610  function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
611  ast::Type exprType = expr->getType();
612  if (exprType == type)
613  return success();
614 
615  auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
616  ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
617  expr->getLoc(), llvm::formatv("unable to convert expression of type "
618  "`{0}` to the expected type of "
619  "`{1}`",
620  exprType, type));
621  if (noteAttachFn)
622  noteAttachFn(*diag);
623  return diag;
624  };
625 
626  if (auto exprOpType = dyn_cast<ast::OperationType>(exprType))
627  return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);
628 
629  // FIXME: Decide how to allow/support converting a single result to multiple,
630  // and multiple to a single result. For now, we just allow Single->Range,
631  // but this isn't something really supported in the PDL dialect. We should
632  // figure out some way to support both.
633  if ((exprType == valueTy || exprType == valueRangeTy) &&
634  (type == valueTy || type == valueRangeTy))
635  return success();
636  if ((exprType == typeTy || exprType == typeRangeTy) &&
637  (type == typeTy || type == typeRangeTy))
638  return success();
639 
640  // Handle tuple types.
641  if (auto exprTupleType = dyn_cast<ast::TupleType>(exprType))
642  return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
643  noteAttachFn);
644 
645  return emitConvertError();
646 }
647 
648 LogicalResult Parser::convertOpExpressionTo(
649  ast::Expr *&expr, ast::OperationType exprType, ast::Type type,
650  function_ref<ast::InFlightDiagnostic()> emitErrorFn) {
651  // Two operation types are compatible if they have the same name, or if the
652  // expected type is more general.
653  if (auto opType = dyn_cast<ast::OperationType>(type)) {
654  if (opType.getName())
655  return emitErrorFn();
656  return success();
657  }
658 
659  // An operation can always convert to a ValueRange.
660  if (type == valueRangeTy) {
661  expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
662  valueRangeTy);
663  return success();
664  }
665 
666  // Allow conversion to a single value by constraining the result range.
667  if (type == valueTy) {
668  // If the operation is registered, we can verify if it can ever have a
669  // single result.
670  if (const ods::Operation *odsOp = exprType.getODSOperation()) {
671  if (odsOp->getResults().empty()) {
672  return emitErrorFn()->attachNote(
673  llvm::formatv("see the definition of `{0}`, which was defined "
674  "with zero results",
675  odsOp->getName()),
676  odsOp->getLoc());
677  }
678 
679  unsigned numSingleResults = llvm::count_if(
680  odsOp->getResults(), [](const ods::OperandOrResult &result) {
681  return result.getVariableLengthKind() ==
682  ods::VariableLengthKind::Single;
683  });
684  if (numSingleResults > 1) {
685  return emitErrorFn()->attachNote(
686  llvm::formatv("see the definition of `{0}`, which was defined "
687  "with at least {1} results",
688  odsOp->getName(), numSingleResults),
689  odsOp->getLoc());
690  }
691  }
692 
693  expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
694  valueTy);
695  return success();
696  }
697  return emitErrorFn();
698 }
699 
700 LogicalResult Parser::convertTupleExpressionTo(
701  ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
702  function_ref<ast::InFlightDiagnostic()> emitErrorFn,
703  function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
704  // Handle conversions between tuples.
705  if (auto tupleType = dyn_cast<ast::TupleType>(type)) {
706  if (tupleType.size() != exprType.size())
707  return emitErrorFn();
708 
709  // Build a new tuple expression using each of the elements of the current
710  // tuple.
711  SmallVector<ast::Expr *> newExprs;
712  for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
713  newExprs.push_back(ast::MemberAccessExpr::create(
714  ctx, expr->getLoc(), expr, llvm::to_string(i),
715  exprType.getElementTypes()[i]));
716 
717  auto diagFn = [&](ast::Diagnostic &diag) {
718  diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
719  i, exprType));
720  if (noteAttachFn)
721  noteAttachFn(diag);
722  };
723  if (failed(convertExpressionTo(newExprs.back(),
724  tupleType.getElementTypes()[i], diagFn)))
725  return failure();
726  }
727  expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
728  tupleType.getElementNames());
729  return success();
730  }
731 
732  // Handle conversion to a range.
733  auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes,
734  ast::RangeType resultTy) -> LogicalResult {
735  // TODO: We currently only allow range conversion within a rewrite context.
736  if (parserContext != ParserContext::Rewrite) {
737  return emitErrorFn()->attachNote("Tuple to Range conversion is currently "
738  "only allowed within a rewrite context");
739  }
740 
741  // All of the tuple elements must be allowed types.
742  for (ast::Type elementType : exprType.getElementTypes())
743  if (!llvm::is_contained(allowedElementTypes, elementType))
744  return emitErrorFn();
745 
746  // Build a new tuple expression using each of the elements of the current
747  // tuple.
748  SmallVector<ast::Expr *> newExprs;
749  for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
750  newExprs.push_back(ast::MemberAccessExpr::create(
751  ctx, expr->getLoc(), expr, llvm::to_string(i),
752  exprType.getElementTypes()[i]));
753  }
754  expr = ast::RangeExpr::create(ctx, expr->getLoc(), newExprs, resultTy);
755  return success();
756  };
757  if (type == valueRangeTy)
758  return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
759  if (type == typeRangeTy)
760  return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
761 
762  return emitErrorFn();
763 }
764 
765 //===----------------------------------------------------------------------===//
766 // Directives
767 
768 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
769  StringRef directive = curToken.getSpelling();
770  if (directive == "#include")
771  return parseInclude(decls);
772 
773  return emitError("unknown directive `" + directive + "`");
774 }
775 
776 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
777  SMRange loc = curToken.getLoc();
778  consumeToken(Token::directive);
779 
780  // Handle code completion of the include file path.
781  if (curToken.is(Token::code_complete_string))
782  return codeCompleteIncludeFilename(curToken.getStringValue());
783 
784  // Parse the file being included.
785  if (!curToken.isString())
786  return emitError(loc,
787  "expected string file name after `include` directive");
788  SMRange fileLoc = curToken.getLoc();
789  std::string filenameStr = curToken.getStringValue();
790  StringRef filename = filenameStr;
791  consumeToken();
792 
793  // Check the type of include. If ending with `.pdll`, this is another pdl file
794  // to be parsed along with the current module.
795  if (filename.ends_with(".pdll")) {
796  if (failed(lexer.pushInclude(filename, fileLoc)))
797  return emitError(fileLoc,
798  "unable to open include file `" + filename + "`");
799 
800  // If we added the include successfully, parse it into the current module.
801  // Make sure to update to the next token after we finish parsing the nested
802  // file.
803  curToken = lexer.lexToken();
804  LogicalResult result = parseModuleBody(decls);
805  curToken = lexer.lexToken();
806  return result;
807  }
808 
809  // Otherwise, this must be a `.td` include.
810  if (filename.ends_with(".td"))
811  return parseTdInclude(filename, fileLoc, decls);
812 
813  return emitError(fileLoc,
814  "expected include filename to end with `.pdll` or `.td`");
815 }
816 
817 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
819  llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
820 
821  // Use the source manager to open the file, but don't yet add it.
822  std::string includedFile;
823  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
824  parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
825  if (!includeBuffer)
826  return emitError(fileLoc, "unable to open include file `" + filename + "`");
827 
828  // Setup the source manager for parsing the tablegen file.
829  llvm::SourceMgr tdSrcMgr;
830  tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
831  tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
832 
833  // This class provides a context argument for the llvm::SourceMgr diagnostic
834  // handler.
835  struct DiagHandlerContext {
836  Parser &parser;
837  StringRef filename;
838  llvm::SMRange loc;
839  } handlerContext{*this, filename, fileLoc};
840 
841  // Set the diagnostic handler for the tablegen source manager.
842  tdSrcMgr.setDiagHandler(
843  [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
844  auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
845  (void)ctx->parser.emitError(
846  ctx->loc,
847  llvm::formatv("error while processing include file `{0}`: {1}",
848  ctx->filename, diag.getMessage()));
849  },
850  &handlerContext);
851 
852  // Parse the tablegen file.
853  llvm::RecordKeeper tdRecords;
854  if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
855  return failure();
856 
857  // Process the parsed records.
858  processTdIncludeRecords(tdRecords, decls);
859 
860  // After we are done processing, move all of the tablegen source buffers to
861  // the main parser source mgr. This allows for directly using source locations
862  // from the .td files without needing to remap them.
863  parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
864  return success();
865 }
866 
867 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
869  // Return the length kind of the given value.
870  auto getLengthKind = [](const auto &value) {
871  if (value.isOptional())
873  return value.isVariadic() ? ods::VariableLengthKind::Variadic
875  };
876 
877  // Insert a type constraint into the ODS context.
878  ods::Context &odsContext = ctx.getODSContext();
879  auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
880  -> const ods::TypeConstraint & {
881  return odsContext.insertTypeConstraint(
882  cst.constraint.getUniqueDefName(),
883  processDoc(cst.constraint.getSummary()),
884  cst.constraint.getCPPClassName());
885  };
886  auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
887  return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
888  };
889 
890  // Process the parsed tablegen records to build ODS information.
891  /// Operations.
892  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
893  tblgen::Operator op(def);
894 
895  // Check to see if this operation is known to support type inferrence.
896  bool supportsResultTypeInferrence =
897  op.getTrait("::mlir::InferTypeOpInterface::Trait");
898 
899  auto [odsOp, inserted] = odsContext.insertOperation(
900  op.getOperationName(), processDoc(op.getSummary()),
901  processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
902  supportsResultTypeInferrence, op.getLoc().front());
903 
904  // Ignore operations that have already been added.
905  if (!inserted)
906  continue;
907 
908  for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
909  odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
910  odsContext.insertAttributeConstraint(
911  attr.attr.getUniqueDefName(),
912  processDoc(attr.attr.getSummary()),
913  attr.attr.getStorageType()));
914  }
915  for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
916  odsOp->appendOperand(operand.name, getLengthKind(operand),
917  addTypeConstraint(operand));
918  }
919  for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
920  odsOp->appendResult(result.name, getLengthKind(result),
921  addTypeConstraint(result));
922  }
923  }
924 
925  auto shouldBeSkipped = [this](llvm::Record *def) {
926  return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
927  def->isSubClassOf("DeclareInterfaceMethods");
928  };
929 
930  /// Attr constraints.
931  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
932  if (shouldBeSkipped(def))
933  continue;
934 
935  tblgen::Attribute constraint(def);
936  decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
937  constraint, convertLocToRange(def->getLoc().front()), attrTy,
938  constraint.getStorageType()));
939  }
940  /// Type constraints.
941  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
942  if (shouldBeSkipped(def))
943  continue;
944 
945  tblgen::TypeConstraint constraint(def);
946  decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
947  constraint, convertLocToRange(def->getLoc().front()), typeTy,
948  constraint.getCPPClassName()));
949  }
950  /// OpInterfaces.
952  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("OpInterface")) {
953  if (shouldBeSkipped(def))
954  continue;
955 
956  SMRange loc = convertLocToRange(def->getLoc().front());
957 
958  std::string cppClassName =
959  llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
960  def->getValueAsString("cppInterfaceName"))
961  .str();
962  std::string codeBlock =
963  llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
964  cppClassName)
965  .str();
966 
967  std::string desc =
968  processAndFormatDoc(def->getValueAsString("description"));
969  decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
970  def->getName(), codeBlock, loc, opTy, cppClassName, desc));
971  }
972 }
973 
974 template <typename ConstraintT>
975 ast::Decl *Parser::createODSNativePDLLConstraintDecl(
976  StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
977  StringRef nativeType, StringRef docString) {
978  // Build the single input parameter.
979  ast::DeclScope *argScope = pushDeclScope();
980  auto *paramVar = ast::VariableDecl::create(
981  ctx, ast::Name::create(ctx, "self", loc), type,
982  /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
983  argScope->add(paramVar);
984  popDeclScope();
985 
986  // Build the native constraint.
987  auto *constraintDecl = ast::UserConstraintDecl::createNative(
988  ctx, ast::Name::create(ctx, name, loc), paramVar,
989  /*results=*/std::nullopt, codeBlock, ast::TupleType::get(ctx),
990  nativeType);
991  constraintDecl->setDocComment(ctx, docString);
992  curDeclScope->add(constraintDecl);
993  return constraintDecl;
994 }
995 
996 template <typename ConstraintT>
997 ast::Decl *
998 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
999  SMRange loc, ast::Type type,
1000  StringRef nativeType) {
1001  // Format the condition template.
1002  tblgen::FmtContext fmtContext;
1003  fmtContext.withSelf("self");
1004  std::string codeBlock = tblgen::tgfmt(
1005  "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
1006  &fmtContext);
1007 
1008  // If documentation was enabled, build the doc string for the generated
1009  // constraint. It would be nice to do this lazily, but TableGen information is
1010  // destroyed after we finish parsing the file.
1011  std::string docString;
1012  if (enableDocumentation) {
1013  StringRef desc = constraint.getDescription();
1014  docString = processAndFormatDoc(
1015  constraint.getSummary() +
1016  (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
1017  }
1018 
1019  return createODSNativePDLLConstraintDecl<ConstraintT>(
1020  constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
1021  docString);
1022 }
1023 
1024 //===----------------------------------------------------------------------===//
1025 // Decls
1026 
1027 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
1029  switch (curToken.getKind()) {
1030  case Token::kw_Constraint:
1031  decl = parseUserConstraintDecl();
1032  break;
1033  case Token::kw_Pattern:
1034  decl = parsePatternDecl();
1035  break;
1036  case Token::kw_Rewrite:
1037  decl = parseUserRewriteDecl();
1038  break;
1039  default:
1040  return emitError("expected top-level declaration, such as a `Pattern`");
1041  }
1042  if (failed(decl))
1043  return failure();
1044 
1045  // If the decl has a name, add it to the current scope.
1046  if (const ast::Name *name = (*decl)->getName()) {
1047  if (failed(checkDefineNamedDecl(*name)))
1048  return failure();
1049  curDeclScope->add(*decl);
1050  }
1051  return decl;
1052 }
1053 
1055 Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1056  // Check for name code completion.
1057  if (curToken.is(Token::code_complete))
1058  return codeCompleteAttributeName(parentOpName);
1059 
1060  std::string attrNameStr;
1061  if (curToken.isString())
1062  attrNameStr = curToken.getStringValue();
1063  else if (curToken.is(Token::identifier) || curToken.isKeyword())
1064  attrNameStr = curToken.getSpelling().str();
1065  else
1066  return emitError("expected identifier or string attribute name");
1067  const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
1068  consumeToken();
1069 
1070  // Check for a value of the attribute.
1071  ast::Expr *attrValue = nullptr;
1072  if (consumeIf(Token::equal)) {
1073  FailureOr<ast::Expr *> attrExpr = parseExpr();
1074  if (failed(attrExpr))
1075  return failure();
1076  attrValue = *attrExpr;
1077  } else {
1078  // If there isn't a concrete value, create an expression representing a
1079  // UnitAttr.
1080  attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
1081  }
1082 
1083  return ast::NamedAttributeDecl::create(ctx, name, attrValue);
1084 }
1085 
1086 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1087  function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1088  bool expectTerminalSemicolon) {
1089  consumeToken(Token::equal_arrow);
1090 
1091  // Parse the single statement of the lambda body.
1092  SMLoc bodyStartLoc = curToken.getStartLoc();
1093  pushDeclScope();
1094  FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1095  bool failedToParse =
1096  failed(singleStatement) || failed(processStatementFn(*singleStatement));
1097  popDeclScope();
1098  if (failedToParse)
1099  return failure();
1100 
1101  SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1102  return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
1103 }
1104 
1105 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1106  // Ensure that the argument is named.
1107  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
1108  return emitError("expected identifier argument name");
1109 
1110  // Parse the argument similarly to a normal variable.
1111  StringRef name = curToken.getSpelling();
1112  SMRange nameLoc = curToken.getLoc();
1113  consumeToken();
1114 
1115  if (failed(
1116  parseToken(Token::colon, "expected `:` before argument constraint")))
1117  return failure();
1118 
1119  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1120  if (failed(cst))
1121  return failure();
1122 
1123  return createArgOrResultVariableDecl(name, nameLoc, *cst);
1124 }
1125 
1126 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1127  // Check to see if this result is named.
1128  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1129  // Check to see if this name actually refers to a Constraint.
1130  if (!curDeclScope->lookup<ast::ConstraintDecl>(curToken.getSpelling())) {
1131  // If it wasn't a constraint, parse the result similarly to a variable. If
1132  // there is already an existing decl, we will emit an error when defining
1133  // this variable later.
1134  StringRef name = curToken.getSpelling();
1135  SMRange nameLoc = curToken.getLoc();
1136  consumeToken();
1137 
1138  if (failed(parseToken(Token::colon,
1139  "expected `:` before result constraint")))
1140  return failure();
1141 
1142  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1143  if (failed(cst))
1144  return failure();
1145 
1146  return createArgOrResultVariableDecl(name, nameLoc, *cst);
1147  }
1148  }
1149 
1150  // If it isn't named, we parse the constraint directly and create an unnamed
1151  // result variable.
1152  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1153  if (failed(cst))
1154  return failure();
1155 
1156  return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1157 }
1158 
1160 Parser::parseUserConstraintDecl(bool isInline) {
1161  // Constraints and rewrites have very similar formats, dispatch to a shared
1162  // interface for parsing.
1163  return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1164  [&](auto &&...args) {
1165  return this->parseUserPDLLConstraintDecl(args...);
1166  },
1167  ParserContext::Constraint, "constraint", isInline);
1168 }
1169 
1170 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1172  parseUserConstraintDecl(/*isInline=*/true);
1173  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1174  return failure();
1175 
1176  curDeclScope->add(*decl);
1177  return decl;
1178 }
1179 
1180 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1181  const ast::Name &name, bool isInline,
1182  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1183  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1184  // Push the argument scope back onto the list, so that the body can
1185  // reference arguments.
1186  pushDeclScope(argumentScope);
1187 
1188  // Parse the body of the constraint. The body is either defined as a compound
1189  // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1190  ast::CompoundStmt *body;
1191  if (curToken.is(Token::equal_arrow)) {
1192  FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1193  [&](ast::Stmt *&stmt) -> LogicalResult {
1194  ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1195  if (!stmtExpr) {
1196  return emitError(stmt->getLoc(),
1197  "expected `Constraint` lambda body to contain a "
1198  "single expression");
1199  }
1200  stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1201  return success();
1202  },
1203  /*expectTerminalSemicolon=*/!isInline);
1204  if (failed(bodyResult))
1205  return failure();
1206  body = *bodyResult;
1207  } else {
1208  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1209  if (failed(bodyResult))
1210  return failure();
1211  body = *bodyResult;
1212 
1213  // Verify the structure of the body.
1214  auto bodyIt = body->begin(), bodyE = body->end();
1215  for (; bodyIt != bodyE; ++bodyIt)
1216  if (isa<ast::ReturnStmt>(*bodyIt))
1217  break;
1218  if (failed(validateUserConstraintOrRewriteReturn(
1219  "Constraint", body, bodyIt, bodyE, results, resultType)))
1220  return failure();
1221  }
1222  popDeclScope();
1223 
1224  return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1225  name, arguments, results, resultType, body);
1226 }
1227 
1228 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1229  // Constraints and rewrites have very similar formats, dispatch to a shared
1230  // interface for parsing.
1231  return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1232  [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1233  ParserContext::Rewrite, "rewrite", isInline);
1234 }
1235 
1236 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1238  parseUserRewriteDecl(/*isInline=*/true);
1239  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1240  return failure();
1241 
1242  curDeclScope->add(*decl);
1243  return decl;
1244 }
1245 
1246 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1247  const ast::Name &name, bool isInline,
1248  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1249  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1250  // Push the argument scope back onto the list, so that the body can
1251  // reference arguments.
1252  curDeclScope = argumentScope;
1253  ast::CompoundStmt *body;
1254  if (curToken.is(Token::equal_arrow)) {
1255  FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1256  [&](ast::Stmt *&statement) -> LogicalResult {
1257  if (isa<ast::OpRewriteStmt>(statement))
1258  return success();
1259 
1260  ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1261  if (!statementExpr) {
1262  return emitError(
1263  statement->getLoc(),
1264  "expected `Rewrite` lambda body to contain a single expression "
1265  "or an operation rewrite statement; such as `erase`, "
1266  "`replace`, or `rewrite`");
1267  }
1268  statement =
1269  ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1270  return success();
1271  },
1272  /*expectTerminalSemicolon=*/!isInline);
1273  if (failed(bodyResult))
1274  return failure();
1275  body = *bodyResult;
1276  } else {
1277  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1278  if (failed(bodyResult))
1279  return failure();
1280  body = *bodyResult;
1281  }
1282  popDeclScope();
1283 
1284  // Verify the structure of the body.
1285  auto bodyIt = body->begin(), bodyE = body->end();
1286  for (; bodyIt != bodyE; ++bodyIt)
1287  if (isa<ast::ReturnStmt>(*bodyIt))
1288  break;
1289  if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1290  bodyE, results, resultType)))
1291  return failure();
1292  return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1293  name, arguments, results, resultType, body);
1294 }
1295 
1296 template <typename T, typename ParseUserPDLLDeclFnT>
1297 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1298  ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1299  StringRef anonymousNamePrefix, bool isInline) {
1300  SMRange loc = curToken.getLoc();
1301  consumeToken();
1302  llvm::SaveAndRestore saveCtx(parserContext, declContext);
1303 
1304  // Parse the name of the decl.
1305  const ast::Name *name = nullptr;
1306  if (curToken.isNot(Token::identifier)) {
1307  // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1308  // in C++, so being unnamed is fine.
1309  if (!isInline)
1310  return emitError("expected identifier name");
1311 
1312  // Create a unique anonymous name to use, as the name for this decl is not
1313  // important.
1314  std::string anonName =
1315  llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1316  anonymousDeclNameCounter++)
1317  .str();
1318  name = &ast::Name::create(ctx, anonName, loc);
1319  } else {
1320  // If a name was provided, we can use it directly.
1321  name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1322  consumeToken(Token::identifier);
1323  }
1324 
1325  // Parse the functional signature of the decl.
1326  SmallVector<ast::VariableDecl *> arguments, results;
1327  ast::DeclScope *argumentScope;
1328  ast::Type resultType;
1329  if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1330  argumentScope, resultType)))
1331  return failure();
1332 
1333  // Check to see which type of constraint this is. If the constraint contains a
1334  // compound body, this is a PDLL decl.
1335  if (curToken.isAny(Token::l_brace, Token::equal_arrow))
1336  return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1337  resultType);
1338 
1339  // Otherwise, this is a native decl.
1340  return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1341  results, resultType);
1342 }
1343 
1344 template <typename T>
1345 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1346  const ast::Name &name, bool isInline,
1348  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1349  // If followed by a string, the native code body has also been specified.
1350  std::string codeStrStorage;
1351  std::optional<StringRef> optCodeStr;
1352  if (curToken.isString()) {
1353  codeStrStorage = curToken.getStringValue();
1354  optCodeStr = codeStrStorage;
1355  consumeToken();
1356  } else if (isInline) {
1357  return emitError(name.getLoc(),
1358  "external declarations must be declared in global scope");
1359  } else if (curToken.is(Token::error)) {
1360  return failure();
1361  }
1362  if (failed(parseToken(Token::semicolon,
1363  "expected `;` after native declaration")))
1364  return failure();
1365  return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1366 }
1367 
1368 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1371  ast::DeclScope *&argumentScope, ast::Type &resultType) {
1372  // Parse the argument list of the decl.
1373  if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
1374  return failure();
1375 
1376  argumentScope = pushDeclScope();
1377  if (curToken.isNot(Token::r_paren)) {
1378  do {
1379  FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1380  if (failed(argument))
1381  return failure();
1382  arguments.emplace_back(*argument);
1383  } while (consumeIf(Token::comma));
1384  }
1385  popDeclScope();
1386  if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
1387  return failure();
1388 
1389  // Parse the results of the decl.
1390  pushDeclScope();
1391  if (consumeIf(Token::arrow)) {
1392  auto parseResultFn = [&]() -> LogicalResult {
1393  FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1394  if (failed(result))
1395  return failure();
1396  results.emplace_back(*result);
1397  return success();
1398  };
1399 
1400  // Check for a list of results.
1401  if (consumeIf(Token::l_paren)) {
1402  do {
1403  if (failed(parseResultFn()))
1404  return failure();
1405  } while (consumeIf(Token::comma));
1406  if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1407  return failure();
1408 
1409  // Otherwise, there is only one result.
1410  } else if (failed(parseResultFn())) {
1411  return failure();
1412  }
1413  }
1414  popDeclScope();
1415 
1416  // Compute the result type of the decl.
1417  resultType = createUserConstraintRewriteResultType(results);
1418 
1419  // Verify that results are only named if there are more than one.
1420  if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1421  return emitError(
1422  results.front()->getLoc(),
1423  "cannot create a single-element tuple with an element label");
1424  }
1425  return success();
1426 }
1427 
1428 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1429  StringRef declType, ast::CompoundStmt *body,
1432  ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1433  // Handle if a `return` was provided.
1434  if (bodyIt != bodyE) {
1435  // Emit an error if we have trailing statements after the return.
1436  if (std::next(bodyIt) != bodyE) {
1437  return emitError(
1438  (*std::next(bodyIt))->getLoc(),
1439  llvm::formatv("`return` terminated the `{0}` body, but found "
1440  "trailing statements afterwards",
1441  declType));
1442  }
1443 
1444  // Otherwise if a return wasn't provided, check that no results are
1445  // expected.
1446  } else if (!results.empty()) {
1447  return emitError(
1448  {body->getLoc().End, body->getLoc().End},
1449  llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1450  declType, resultType));
1451  }
1452  return success();
1453 }
1454 
1455 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1456  return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1457  if (isa<ast::OpRewriteStmt>(statement))
1458  return success();
1459  return emitError(
1460  statement->getLoc(),
1461  "expected Pattern lambda body to contain a single operation "
1462  "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1463  });
1464 }
1465 
1466 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1467  SMRange loc = curToken.getLoc();
1468  consumeToken(Token::kw_Pattern);
1469  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1470 
1471  // Check for an optional identifier for the pattern name.
1472  const ast::Name *name = nullptr;
1473  if (curToken.is(Token::identifier)) {
1474  name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1475  consumeToken(Token::identifier);
1476  }
1477 
1478  // Parse any pattern metadata.
1479  ParsedPatternMetadata metadata;
1480  if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1481  return failure();
1482 
1483  // Parse the pattern body.
1484  ast::CompoundStmt *body;
1485 
1486  // Handle a lambda body.
1487  if (curToken.is(Token::equal_arrow)) {
1488  FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1489  if (failed(bodyResult))
1490  return failure();
1491  body = *bodyResult;
1492  } else {
1493  if (curToken.isNot(Token::l_brace))
1494  return emitError("expected `{` or `=>` to start pattern body");
1495  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1496  if (failed(bodyResult))
1497  return failure();
1498  body = *bodyResult;
1499 
1500  // Verify the body of the pattern.
1501  auto bodyIt = body->begin(), bodyE = body->end();
1502  for (; bodyIt != bodyE; ++bodyIt) {
1503  if (isa<ast::ReturnStmt>(*bodyIt)) {
1504  return emitError((*bodyIt)->getLoc(),
1505  "`return` statements are only permitted within a "
1506  "`Constraint` or `Rewrite` body");
1507  }
1508  // Break when we've found the rewrite statement.
1509  if (isa<ast::OpRewriteStmt>(*bodyIt))
1510  break;
1511  }
1512  if (bodyIt == bodyE) {
1513  return emitError(loc,
1514  "expected Pattern body to terminate with an operation "
1515  "rewrite statement, such as `erase`");
1516  }
1517  if (std::next(bodyIt) != bodyE) {
1518  return emitError((*std::next(bodyIt))->getLoc(),
1519  "Pattern body was terminated by an operation "
1520  "rewrite statement, but found trailing statements");
1521  }
1522  }
1523 
1524  return createPatternDecl(loc, name, metadata, body);
1525 }
1526 
1528 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1529  std::optional<SMRange> benefitLoc;
1530  std::optional<SMRange> hasBoundedRecursionLoc;
1531 
1532  do {
1533  // Handle metadata code completion.
1534  if (curToken.is(Token::code_complete))
1535  return codeCompletePatternMetadata();
1536 
1537  if (curToken.isNot(Token::identifier))
1538  return emitError("expected pattern metadata identifier");
1539  StringRef metadataStr = curToken.getSpelling();
1540  SMRange metadataLoc = curToken.getLoc();
1541  consumeToken(Token::identifier);
1542 
1543  // Parse the benefit metadata: benefit(<integer-value>)
1544  if (metadataStr == "benefit") {
1545  if (benefitLoc) {
1546  return emitErrorAndNote(metadataLoc,
1547  "pattern benefit has already been specified",
1548  *benefitLoc, "see previous definition here");
1549  }
1550  if (failed(parseToken(Token::l_paren,
1551  "expected `(` before pattern benefit")))
1552  return failure();
1553 
1554  uint16_t benefitValue = 0;
1555  if (curToken.isNot(Token::integer))
1556  return emitError("expected integral pattern benefit");
1557  if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1558  return emitError(
1559  "expected pattern benefit to fit within a 16-bit integer");
1560  consumeToken(Token::integer);
1561 
1562  metadata.benefit = benefitValue;
1563  benefitLoc = metadataLoc;
1564 
1565  if (failed(
1566  parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1567  return failure();
1568  continue;
1569  }
1570 
1571  // Parse the bounded recursion metadata: recursion
1572  if (metadataStr == "recursion") {
1573  if (hasBoundedRecursionLoc) {
1574  return emitErrorAndNote(
1575  metadataLoc,
1576  "pattern recursion metadata has already been specified",
1577  *hasBoundedRecursionLoc, "see previous definition here");
1578  }
1579  metadata.hasBoundedRecursion = true;
1580  hasBoundedRecursionLoc = metadataLoc;
1581  continue;
1582  }
1583 
1584  return emitError(metadataLoc, "unknown pattern metadata");
1585  } while (consumeIf(Token::comma));
1586 
1587  return success();
1588 }
1589 
1590 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1591  consumeToken(Token::less);
1592 
1593  FailureOr<ast::Expr *> typeExpr = parseExpr();
1594  if (failed(typeExpr) ||
1595  failed(parseToken(Token::greater,
1596  "expected `>` after variable type constraint")))
1597  return failure();
1598  return typeExpr;
1599 }
1600 
1601 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1602  assert(curDeclScope && "defining decl outside of a decl scope");
1603  if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1604  return emitErrorAndNote(
1605  name.getLoc(), "`" + name.getName() + "` has already been defined",
1606  lastDecl->getName()->getLoc(), "see previous definition here");
1607  }
1608  return success();
1609 }
1610 
1612 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1613  ast::Expr *initExpr,
1614  ArrayRef<ast::ConstraintRef> constraints) {
1615  assert(curDeclScope && "defining variable outside of decl scope");
1616  const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1617 
1618  // If the name of the variable indicates a special variable, we don't add it
1619  // to the scope. This variable is local to the definition point.
1620  if (name.empty() || name == "_") {
1621  return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1622  constraints);
1623  }
1624  if (failed(checkDefineNamedDecl(nameDecl)))
1625  return failure();
1626 
1627  auto *varDecl =
1628  ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1629  curDeclScope->add(varDecl);
1630  return varDecl;
1631 }
1632 
1634 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1635  ArrayRef<ast::ConstraintRef> constraints) {
1636  return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1637  constraints);
1638 }
1639 
1640 LogicalResult Parser::parseVariableDeclConstraintList(
1641  SmallVectorImpl<ast::ConstraintRef> &constraints) {
1642  std::optional<SMRange> typeConstraint;
1643  auto parseSingleConstraint = [&] {
1644  FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1645  typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
1646  if (failed(constraint))
1647  return failure();
1648  constraints.push_back(*constraint);
1649  return success();
1650  };
1651 
1652  // Check to see if this is a single constraint, or a list.
1653  if (!consumeIf(Token::l_square))
1654  return parseSingleConstraint();
1655 
1656  do {
1657  if (failed(parseSingleConstraint()))
1658  return failure();
1659  } while (consumeIf(Token::comma));
1660  return parseToken(Token::r_square, "expected `]` after constraint list");
1661 }
1662 
1664 Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
1665  ArrayRef<ast::ConstraintRef> existingConstraints,
1666  bool allowInlineTypeConstraints) {
1667  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1668  if (!allowInlineTypeConstraints) {
1669  return emitError(
1670  curToken.getLoc(),
1671  "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1672  "permitted on arguments or results");
1673  }
1674  if (typeConstraint)
1675  return emitErrorAndNote(
1676  curToken.getLoc(),
1677  "the type of this variable has already been constrained",
1678  *typeConstraint, "see previous constraint location here");
1679  FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1680  if (failed(constraintExpr))
1681  return failure();
1682  typeExpr = *constraintExpr;
1683  typeConstraint = typeExpr->getLoc();
1684  return success();
1685  };
1686 
1687  SMRange loc = curToken.getLoc();
1688  switch (curToken.getKind()) {
1689  case Token::kw_Attr: {
1690  consumeToken(Token::kw_Attr);
1691 
1692  // Check for a type constraint.
1693  ast::Expr *typeExpr = nullptr;
1694  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1695  return failure();
1696  return ast::ConstraintRef(
1697  ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1698  }
1699  case Token::kw_Op: {
1700  consumeToken(Token::kw_Op);
1701 
1702  // Parse an optional operation name. If the name isn't provided, this refers
1703  // to "any" operation.
1705  parseWrappedOperationName(/*allowEmptyName=*/true);
1706  if (failed(opName))
1707  return failure();
1708 
1709  return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
1710  loc);
1711  }
1712  case Token::kw_Type:
1713  consumeToken(Token::kw_Type);
1714  return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1715  case Token::kw_TypeRange:
1716  consumeToken(Token::kw_TypeRange);
1718  loc);
1719  case Token::kw_Value: {
1720  consumeToken(Token::kw_Value);
1721 
1722  // Check for a type constraint.
1723  ast::Expr *typeExpr = nullptr;
1724  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1725  return failure();
1726 
1727  return ast::ConstraintRef(
1728  ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1729  }
1730  case Token::kw_ValueRange: {
1731  consumeToken(Token::kw_ValueRange);
1732 
1733  // Check for a type constraint.
1734  ast::Expr *typeExpr = nullptr;
1735  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1736  return failure();
1737 
1738  return ast::ConstraintRef(
1739  ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1740  }
1741 
1742  case Token::kw_Constraint: {
1743  // Handle an inline constraint.
1744  FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1745  if (failed(decl))
1746  return failure();
1747  return ast::ConstraintRef(*decl, loc);
1748  }
1749  case Token::identifier: {
1750  StringRef constraintName = curToken.getSpelling();
1751  consumeToken(Token::identifier);
1752 
1753  // Lookup the referenced constraint.
1754  ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
1755  if (!cstDecl) {
1756  return emitError(loc, "unknown reference to constraint `" +
1757  constraintName + "`");
1758  }
1759 
1760  // Handle a reference to a proper constraint.
1761  if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1762  return ast::ConstraintRef(cst, loc);
1763 
1764  return emitErrorAndNote(
1765  loc, "invalid reference to non-constraint", cstDecl->getLoc(),
1766  "see the definition of `" + constraintName + "` here");
1767  }
1768  // Handle single entity constraint code completion.
1769  case Token::code_complete: {
1770  // Try to infer the current type for use by code completion.
1771  ast::Type inferredType;
1772  if (failed(validateVariableConstraints(existingConstraints, inferredType)))
1773  return failure();
1774 
1775  return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
1776  }
1777  default:
1778  break;
1779  }
1780  return emitError(loc, "expected identifier constraint");
1781 }
1782 
1783 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1784  std::optional<SMRange> typeConstraint;
1785  return parseConstraint(typeConstraint, /*existingConstraints=*/std::nullopt,
1786  /*allowInlineTypeConstraints=*/false);
1787 }
1788 
1789 //===----------------------------------------------------------------------===//
1790 // Exprs
1791 
1792 FailureOr<ast::Expr *> Parser::parseExpr() {
1793  if (curToken.is(Token::underscore))
1794  return parseUnderscoreExpr();
1795 
1796  // Parse the LHS expression.
1797  FailureOr<ast::Expr *> lhsExpr;
1798  switch (curToken.getKind()) {
1799  case Token::kw_attr:
1800  lhsExpr = parseAttributeExpr();
1801  break;
1802  case Token::kw_Constraint:
1803  lhsExpr = parseInlineConstraintLambdaExpr();
1804  break;
1805  case Token::kw_not:
1806  lhsExpr = parseNegatedExpr();
1807  break;
1808  case Token::identifier:
1809  lhsExpr = parseIdentifierExpr();
1810  break;
1811  case Token::kw_op:
1812  lhsExpr = parseOperationExpr();
1813  break;
1814  case Token::kw_Rewrite:
1815  lhsExpr = parseInlineRewriteLambdaExpr();
1816  break;
1817  case Token::kw_type:
1818  lhsExpr = parseTypeExpr();
1819  break;
1820  case Token::l_paren:
1821  lhsExpr = parseTupleExpr();
1822  break;
1823  default:
1824  return emitError("expected expression");
1825  }
1826  if (failed(lhsExpr))
1827  return failure();
1828 
1829  // Check for an operator expression.
1830  while (true) {
1831  switch (curToken.getKind()) {
1832  case Token::dot:
1833  lhsExpr = parseMemberAccessExpr(*lhsExpr);
1834  break;
1835  case Token::l_paren:
1836  lhsExpr = parseCallExpr(*lhsExpr);
1837  break;
1838  default:
1839  return lhsExpr;
1840  }
1841  if (failed(lhsExpr))
1842  return failure();
1843  }
1844 }
1845 
1846 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1847  SMRange loc = curToken.getLoc();
1848  consumeToken(Token::kw_attr);
1849 
1850  // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1851  // identifier.
1852  if (!consumeIf(Token::less)) {
1853  resetToken(loc);
1854  return parseIdentifierExpr();
1855  }
1856 
1857  if (!curToken.isString())
1858  return emitError("expected string literal containing MLIR attribute");
1859  std::string attrExpr = curToken.getStringValue();
1860  consumeToken();
1861 
1862  loc.End = curToken.getEndLoc();
1863  if (failed(
1864  parseToken(Token::greater, "expected `>` after attribute literal")))
1865  return failure();
1866  return ast::AttributeExpr::create(ctx, loc, attrExpr);
1867 }
1868 
1869 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr,
1870  bool isNegated) {
1871  consumeToken(Token::l_paren);
1872 
1873  // Parse the arguments of the call.
1874  SmallVector<ast::Expr *> arguments;
1875  if (curToken.isNot(Token::r_paren)) {
1876  do {
1877  // Handle code completion for the call arguments.
1878  if (curToken.is(Token::code_complete)) {
1879  codeCompleteCallSignature(parentExpr, arguments.size());
1880  return failure();
1881  }
1882 
1883  FailureOr<ast::Expr *> argument = parseExpr();
1884  if (failed(argument))
1885  return failure();
1886  arguments.push_back(*argument);
1887  } while (consumeIf(Token::comma));
1888  }
1889 
1890  SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1891  if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1892  return failure();
1893 
1894  return createCallExpr(loc, parentExpr, arguments, isNegated);
1895 }
1896 
1897 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1898  ast::Decl *decl = curDeclScope->lookup(name);
1899  if (!decl)
1900  return emitError(loc, "undefined reference to `" + name + "`");
1901 
1902  return createDeclRefExpr(loc, decl);
1903 }
1904 
1905 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1906  StringRef name = curToken.getSpelling();
1907  SMRange nameLoc = curToken.getLoc();
1908  consumeToken();
1909 
1910  // Check to see if this is a decl ref expression that defines a variable
1911  // inline.
1912  if (consumeIf(Token::colon)) {
1913  SmallVector<ast::ConstraintRef> constraints;
1914  if (failed(parseVariableDeclConstraintList(constraints)))
1915  return failure();
1916  ast::Type type;
1917  if (failed(validateVariableConstraints(constraints, type)))
1918  return failure();
1919  return createInlineVariableExpr(type, name, nameLoc, constraints);
1920  }
1921 
1922  return parseDeclRefExpr(name, nameLoc);
1923 }
1924 
1925 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1926  FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1927  if (failed(decl))
1928  return failure();
1929 
1930  return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1932 }
1933 
1934 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1935  FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1936  if (failed(decl))
1937  return failure();
1938 
1939  return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1940  ast::RewriteType::get(ctx));
1941 }
1942 
1943 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1944  SMRange dotLoc = curToken.getLoc();
1945  consumeToken(Token::dot);
1946 
1947  // Check for code completion of the member name.
1948  if (curToken.is(Token::code_complete))
1949  return codeCompleteMemberAccess(parentExpr);
1950 
1951  // Parse the member name.
1952  Token memberNameTok = curToken;
1953  if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1954  !memberNameTok.isKeyword())
1955  return emitError(dotLoc, "expected identifier or numeric member name");
1956  StringRef memberName = memberNameTok.getSpelling();
1957  SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1958  consumeToken();
1959 
1960  return createMemberAccessExpr(parentExpr, memberName, loc);
1961 }
1962 
1963 FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
1964  consumeToken(Token::kw_not);
1965  // Only native constraints are supported after negation
1966  if (!curToken.is(Token::identifier))
1967  return emitError("expected native constraint");
1968  FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
1969  if (failed(identifierExpr))
1970  return failure();
1971  if (!curToken.is(Token::l_paren))
1972  return emitError("expected `(` after function name");
1973  return parseCallExpr(*identifierExpr, /*isNegated = */ true);
1974 }
1975 
1976 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
1977  SMRange loc = curToken.getLoc();
1978 
1979  // Check for code completion for the dialect name.
1980  if (curToken.is(Token::code_complete))
1981  return codeCompleteDialectName();
1982 
1983  // Handle the case of an no operation name.
1984  if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
1985  if (allowEmptyName)
1986  return ast::OpNameDecl::create(ctx, SMRange());
1987  return emitError("expected dialect namespace");
1988  }
1989  StringRef name = curToken.getSpelling();
1990  consumeToken();
1991 
1992  // Otherwise, this is a literal operation name.
1993  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
1994  return failure();
1995 
1996  // Check for code completion for the operation name.
1997  if (curToken.is(Token::code_complete))
1998  return codeCompleteOperationName(name);
1999 
2000  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
2001  return emitError("expected operation name after dialect namespace");
2002 
2003  name = StringRef(name.data(), name.size() + 1);
2004  do {
2005  name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
2006  loc.End = curToken.getEndLoc();
2007  consumeToken();
2008  } while (curToken.isAny(Token::identifier, Token::dot) ||
2009  curToken.isKeyword());
2010  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
2011 }
2012 
2014 Parser::parseWrappedOperationName(bool allowEmptyName) {
2015  if (!consumeIf(Token::less))
2016  return ast::OpNameDecl::create(ctx, SMRange());
2017 
2018  FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
2019  if (failed(opNameDecl))
2020  return failure();
2021 
2022  if (failed(parseToken(Token::greater, "expected `>` after operation name")))
2023  return failure();
2024  return opNameDecl;
2025 }
2026 
2028 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
2029  SMRange loc = curToken.getLoc();
2030  consumeToken(Token::kw_op);
2031 
2032  // If it isn't followed by a `<`, the `op` keyword is treated as a normal
2033  // identifier.
2034  if (curToken.isNot(Token::less)) {
2035  resetToken(loc);
2036  return parseIdentifierExpr();
2037  }
2038 
2039  // Parse the operation name. The name may be elided, in which case the
2040  // operation refers to "any" operation(i.e. a difference between `MyOp` and
2041  // `Operation*`). Operation names within a rewrite context must be named.
2042  bool allowEmptyName = parserContext != ParserContext::Rewrite;
2043  FailureOr<ast::OpNameDecl *> opNameDecl =
2044  parseWrappedOperationName(allowEmptyName);
2045  if (failed(opNameDecl))
2046  return failure();
2047  std::optional<StringRef> opName = (*opNameDecl)->getName();
2048 
2049  // Functor used to create an implicit range variable, used for implicit "all"
2050  // operand or results variables.
2051  auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
2053  defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
2054  assert(succeeded(rangeVar) && "expected range variable to be valid");
2055  return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
2056  };
2057 
2058  // Check for the optional list of operands.
2059  SmallVector<ast::Expr *> operands;
2060  if (!consumeIf(Token::l_paren)) {
2061  // If the operand list isn't specified and we are in a match context, define
2062  // an inplace unconstrained operand range corresponding to all of the
2063  // operands of the operation. This avoids treating zero operands the same
2064  // way as "unconstrained operands".
2065  if (parserContext != ParserContext::Rewrite) {
2066  operands.push_back(createImplicitRangeVar(
2067  ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
2068  }
2069  } else if (!consumeIf(Token::r_paren)) {
2070  // If the operand list was specified and non-empty, parse the operands.
2071  do {
2072  // Check for operand signature code completion.
2073  if (curToken.is(Token::code_complete)) {
2074  codeCompleteOperationOperandsSignature(opName, operands.size());
2075  return failure();
2076  }
2077 
2078  FailureOr<ast::Expr *> operand = parseExpr();
2079  if (failed(operand))
2080  return failure();
2081  operands.push_back(*operand);
2082  } while (consumeIf(Token::comma));
2083 
2084  if (failed(parseToken(Token::r_paren,
2085  "expected `)` after operation operand list")))
2086  return failure();
2087  }
2088 
2089  // Check for the optional list of attributes.
2091  if (consumeIf(Token::l_brace)) {
2092  do {
2094  parseNamedAttributeDecl(opName);
2095  if (failed(decl))
2096  return failure();
2097  attributes.emplace_back(*decl);
2098  } while (consumeIf(Token::comma));
2099 
2100  if (failed(parseToken(Token::r_brace,
2101  "expected `}` after operation attribute list")))
2102  return failure();
2103  }
2104 
2105  // Handle the result types of the operation.
2106  SmallVector<ast::Expr *> resultTypes;
2107  OpResultTypeContext resultTypeContext = inputResultTypeContext;
2108 
2109  // Check for an explicit list of result types.
2110  if (consumeIf(Token::arrow)) {
2111  if (failed(parseToken(Token::l_paren,
2112  "expected `(` before operation result type list")))
2113  return failure();
2114 
2115  // If result types are provided, initially assume that the operation does
2116  // not rely on type inferrence. We don't assert that it isn't, because we
2117  // may be inferring the value of some type/type range variables, but given
2118  // that these variables may be defined in calls we can't always discern when
2119  // this is the case.
2120  resultTypeContext = OpResultTypeContext::Explicit;
2121 
2122  // Handle the case of an empty result list.
2123  if (!consumeIf(Token::r_paren)) {
2124  do {
2125  // Check for result signature code completion.
2126  if (curToken.is(Token::code_complete)) {
2127  codeCompleteOperationResultsSignature(opName, resultTypes.size());
2128  return failure();
2129  }
2130 
2131  FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2132  if (failed(resultTypeExpr))
2133  return failure();
2134  resultTypes.push_back(*resultTypeExpr);
2135  } while (consumeIf(Token::comma));
2136 
2137  if (failed(parseToken(Token::r_paren,
2138  "expected `)` after operation result type list")))
2139  return failure();
2140  }
2141  } else if (parserContext != ParserContext::Rewrite) {
2142  // If the result list isn't specified and we are in a match context, define
2143  // an inplace unconstrained result range corresponding to all of the results
2144  // of the operation. This avoids treating zero results the same way as
2145  // "unconstrained results".
2146  resultTypes.push_back(createImplicitRangeVar(
2147  ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
2148  } else if (resultTypeContext == OpResultTypeContext::Explicit) {
2149  // If the result list isn't specified and we are in a rewrite, try to infer
2150  // them at runtime instead.
2151  resultTypeContext = OpResultTypeContext::Interface;
2152  }
2153 
2154  return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2155  attributes, resultTypes);
2156 }
2157 
2158 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2159  SMRange loc = curToken.getLoc();
2160  consumeToken(Token::l_paren);
2161 
2162  DenseMap<StringRef, SMRange> usedNames;
2163  SmallVector<StringRef> elementNames;
2164  SmallVector<ast::Expr *> elements;
2165  if (curToken.isNot(Token::r_paren)) {
2166  do {
2167  // Check for the optional element name assignment before the value.
2168  StringRef elementName;
2169  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
2170  Token elementNameTok = curToken;
2171  consumeToken();
2172 
2173  // The element name is only present if followed by an `=`.
2174  if (consumeIf(Token::equal)) {
2175  elementName = elementNameTok.getSpelling();
2176 
2177  // Check to see if this name is already used.
2178  auto elementNameIt =
2179  usedNames.try_emplace(elementName, elementNameTok.getLoc());
2180  if (!elementNameIt.second) {
2181  return emitErrorAndNote(
2182  elementNameTok.getLoc(),
2183  llvm::formatv("duplicate tuple element label `{0}`",
2184  elementName),
2185  elementNameIt.first->getSecond(),
2186  "see previous label use here");
2187  }
2188  } else {
2189  // Otherwise, we treat this as part of an expression so reset the
2190  // lexer.
2191  resetToken(elementNameTok.getLoc());
2192  }
2193  }
2194  elementNames.push_back(elementName);
2195 
2196  // Parse the tuple element value.
2197  FailureOr<ast::Expr *> element = parseExpr();
2198  if (failed(element))
2199  return failure();
2200  elements.push_back(*element);
2201  } while (consumeIf(Token::comma));
2202  }
2203  loc.End = curToken.getEndLoc();
2204  if (failed(
2205  parseToken(Token::r_paren, "expected `)` after tuple element list")))
2206  return failure();
2207  return createTupleExpr(loc, elements, elementNames);
2208 }
2209 
2210 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2211  SMRange loc = curToken.getLoc();
2212  consumeToken(Token::kw_type);
2213 
2214  // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2215  // identifier.
2216  if (!consumeIf(Token::less)) {
2217  resetToken(loc);
2218  return parseIdentifierExpr();
2219  }
2220 
2221  if (!curToken.isString())
2222  return emitError("expected string literal containing MLIR type");
2223  std::string attrExpr = curToken.getStringValue();
2224  consumeToken();
2225 
2226  loc.End = curToken.getEndLoc();
2227  if (failed(parseToken(Token::greater, "expected `>` after type literal")))
2228  return failure();
2229  return ast::TypeExpr::create(ctx, loc, attrExpr);
2230 }
2231 
2232 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2233  StringRef name = curToken.getSpelling();
2234  SMRange nameLoc = curToken.getLoc();
2235  consumeToken(Token::underscore);
2236 
2237  // Underscore expressions require a constraint list.
2238  if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2239  return failure();
2240 
2241  // Parse the constraints for the expression.
2242  SmallVector<ast::ConstraintRef> constraints;
2243  if (failed(parseVariableDeclConstraintList(constraints)))
2244  return failure();
2245 
2246  ast::Type type;
2247  if (failed(validateVariableConstraints(constraints, type)))
2248  return failure();
2249  return createInlineVariableExpr(type, name, nameLoc, constraints);
2250 }
2251 
2252 //===----------------------------------------------------------------------===//
2253 // Stmts
2254 
2255 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2257  switch (curToken.getKind()) {
2258  case Token::kw_erase:
2259  stmt = parseEraseStmt();
2260  break;
2261  case Token::kw_let:
2262  stmt = parseLetStmt();
2263  break;
2264  case Token::kw_replace:
2265  stmt = parseReplaceStmt();
2266  break;
2267  case Token::kw_return:
2268  stmt = parseReturnStmt();
2269  break;
2270  case Token::kw_rewrite:
2271  stmt = parseRewriteStmt();
2272  break;
2273  default:
2274  stmt = parseExpr();
2275  break;
2276  }
2277  if (failed(stmt) ||
2278  (expectTerminalSemicolon &&
2279  failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2280  return failure();
2281  return stmt;
2282 }
2283 
2284 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2285  SMLoc startLoc = curToken.getStartLoc();
2286  consumeToken(Token::l_brace);
2287 
2288  // Push a new block scope and parse any nested statements.
2289  pushDeclScope();
2290  SmallVector<ast::Stmt *> statements;
2291  while (curToken.isNot(Token::r_brace)) {
2292  FailureOr<ast::Stmt *> statement = parseStmt();
2293  if (failed(statement))
2294  return popDeclScope(), failure();
2295  statements.push_back(*statement);
2296  }
2297  popDeclScope();
2298 
2299  // Consume the end brace.
2300  SMRange location(startLoc, curToken.getEndLoc());
2301  consumeToken(Token::r_brace);
2302 
2303  return ast::CompoundStmt::create(ctx, location, statements);
2304 }
2305 
2306 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2307  if (parserContext == ParserContext::Constraint)
2308  return emitError("`erase` cannot be used within a Constraint");
2309  SMRange loc = curToken.getLoc();
2310  consumeToken(Token::kw_erase);
2311 
2312  // Parse the root operation expression.
2313  FailureOr<ast::Expr *> rootOp = parseExpr();
2314  if (failed(rootOp))
2315  return failure();
2316 
2317  return createEraseStmt(loc, *rootOp);
2318 }
2319 
2320 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2321  SMRange loc = curToken.getLoc();
2322  consumeToken(Token::kw_let);
2323 
2324  // Parse the name of the new variable.
2325  SMRange varLoc = curToken.getLoc();
2326  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
2327  // `_` is a reserved variable name.
2328  if (curToken.is(Token::underscore)) {
2329  return emitError(varLoc,
2330  "`_` may only be used to define \"inline\" variables");
2331  }
2332  return emitError(varLoc,
2333  "expected identifier after `let` to name a new variable");
2334  }
2335  StringRef varName = curToken.getSpelling();
2336  consumeToken();
2337 
2338  // Parse the optional set of constraints.
2339  SmallVector<ast::ConstraintRef> constraints;
2340  if (consumeIf(Token::colon) &&
2341  failed(parseVariableDeclConstraintList(constraints)))
2342  return failure();
2343 
2344  // Parse the optional initializer expression.
2345  ast::Expr *initializer = nullptr;
2346  if (consumeIf(Token::equal)) {
2347  FailureOr<ast::Expr *> initOrFailure = parseExpr();
2348  if (failed(initOrFailure))
2349  return failure();
2350  initializer = *initOrFailure;
2351 
2352  // Check that the constraints are compatible with having an initializer,
2353  // e.g. type constraints cannot be used with initializers.
2354  for (ast::ConstraintRef constraint : constraints) {
2355  LogicalResult result =
2356  TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
2358  ast::ValueRangeConstraintDecl>([&](const auto *cst) {
2359  if (cst->getTypeExpr()) {
2360  return this->emitError(
2361  constraint.referenceLoc,
2362  "type constraints are not permitted on variables with "
2363  "initializers");
2364  }
2365  return success();
2366  })
2367  .Default(success());
2368  if (failed(result))
2369  return failure();
2370  }
2371  }
2372 
2374  createVariableDecl(varName, varLoc, initializer, constraints);
2375  if (failed(varDecl))
2376  return failure();
2377  return ast::LetStmt::create(ctx, loc, *varDecl);
2378 }
2379 
2380 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2381  if (parserContext == ParserContext::Constraint)
2382  return emitError("`replace` cannot be used within a Constraint");
2383  SMRange loc = curToken.getLoc();
2384  consumeToken(Token::kw_replace);
2385 
2386  // Parse the root operation expression.
2387  FailureOr<ast::Expr *> rootOp = parseExpr();
2388  if (failed(rootOp))
2389  return failure();
2390 
2391  if (failed(
2392  parseToken(Token::kw_with, "expected `with` after root operation")))
2393  return failure();
2394 
2395  // The replacement portion of this statement is within a rewrite context.
2396  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2397 
2398  // Parse the replacement values.
2399  SmallVector<ast::Expr *> replValues;
2400  if (consumeIf(Token::l_paren)) {
2401  if (consumeIf(Token::r_paren)) {
2402  return emitError(
2403  loc, "expected at least one replacement value, consider using "
2404  "`erase` if no replacement values are desired");
2405  }
2406 
2407  do {
2408  FailureOr<ast::Expr *> replExpr = parseExpr();
2409  if (failed(replExpr))
2410  return failure();
2411  replValues.emplace_back(*replExpr);
2412  } while (consumeIf(Token::comma));
2413 
2414  if (failed(parseToken(Token::r_paren,
2415  "expected `)` after replacement values")))
2416  return failure();
2417  } else {
2418  // Handle replacement with an operation uniquely, as the replacement
2419  // operation supports type inferrence from the root operation.
2420  FailureOr<ast::Expr *> replExpr;
2421  if (curToken.is(Token::kw_op))
2422  replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2423  else
2424  replExpr = parseExpr();
2425  if (failed(replExpr))
2426  return failure();
2427  replValues.emplace_back(*replExpr);
2428  }
2429 
2430  return createReplaceStmt(loc, *rootOp, replValues);
2431 }
2432 
2433 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2434  SMRange loc = curToken.getLoc();
2435  consumeToken(Token::kw_return);
2436 
2437  // Parse the result value.
2438  FailureOr<ast::Expr *> resultExpr = parseExpr();
2439  if (failed(resultExpr))
2440  return failure();
2441 
2442  return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2443 }
2444 
2445 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2446  if (parserContext == ParserContext::Constraint)
2447  return emitError("`rewrite` cannot be used within a Constraint");
2448  SMRange loc = curToken.getLoc();
2449  consumeToken(Token::kw_rewrite);
2450 
2451  // Parse the root operation.
2452  FailureOr<ast::Expr *> rootOp = parseExpr();
2453  if (failed(rootOp))
2454  return failure();
2455 
2456  if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2457  return failure();
2458 
2459  if (curToken.isNot(Token::l_brace))
2460  return emitError("expected `{` to start rewrite body");
2461 
2462  // The rewrite body of this statement is within a rewrite context.
2463  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2464 
2465  FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2466  if (failed(rewriteBody))
2467  return failure();
2468 
2469  // Verify the rewrite body.
2470  for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2471  if (isa<ast::ReturnStmt>(stmt)) {
2472  return emitError(stmt->getLoc(),
2473  "`return` statements are only permitted within a "
2474  "`Constraint` or `Rewrite` body");
2475  }
2476  }
2477 
2478  return createRewriteStmt(loc, *rootOp, *rewriteBody);
2479 }
2480 
2481 //===----------------------------------------------------------------------===//
2482 // Creation+Analysis
2483 //===----------------------------------------------------------------------===//
2484 
2485 //===----------------------------------------------------------------------===//
2486 // Decls
2487 
2488 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2489  // Unwrap reference expressions.
2490  if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2491  node = init->getDecl();
2492  return dyn_cast<ast::CallableDecl>(node);
2493 }
2494 
2496 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2497  const ParsedPatternMetadata &metadata,
2498  ast::CompoundStmt *body) {
2499  return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2500  metadata.hasBoundedRecursion, body);
2501 }
2502 
2503 ast::Type Parser::createUserConstraintRewriteResultType(
2505  // Single result decls use the type of the single result.
2506  if (results.size() == 1)
2507  return results[0]->getType();
2508 
2509  // Multiple results use a tuple type, with the types and names grabbed from
2510  // the result variable decls.
2511  auto resultTypes = llvm::map_range(
2512  results, [&](const auto *result) { return result->getType(); });
2513  auto resultNames = llvm::map_range(
2514  results, [&](const auto *result) { return result->getName().getName(); });
2515  return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2516  llvm::to_vector(resultNames));
2517 }
2518 
2519 template <typename T>
2520 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2521  const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2522  ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2523  ast::CompoundStmt *body) {
2524  if (!body->getChildren().empty()) {
2525  if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2526  ast::Expr *resultExpr = retStmt->getResultExpr();
2527 
2528  // Process the result of the decl. If no explicit signature results
2529  // were provided, check for return type inference. Otherwise, check that
2530  // the return expression can be converted to the expected type.
2531  if (results.empty())
2532  resultType = resultExpr->getType();
2533  else if (failed(convertExpressionTo(resultExpr, resultType)))
2534  return failure();
2535  else
2536  retStmt->setResultExpr(resultExpr);
2537  }
2538  }
2539  return T::createPDLL(ctx, name, arguments, results, body, resultType);
2540 }
2541 
2543 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2544  ArrayRef<ast::ConstraintRef> constraints) {
2545  // The type of the variable, which is expected to be inferred by either a
2546  // constraint or an initializer expression.
2547  ast::Type type;
2548  if (failed(validateVariableConstraints(constraints, type)))
2549  return failure();
2550 
2551  if (initializer) {
2552  // Update the variable type based on the initializer, or try to convert the
2553  // initializer to the existing type.
2554  if (!type)
2555  type = initializer->getType();
2556  else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2557  type = mergedType;
2558  else if (failed(convertExpressionTo(initializer, type)))
2559  return failure();
2560 
2561  // Otherwise, if there is no initializer check that the type has already
2562  // been resolved from the constraint list.
2563  } else if (!type) {
2564  return emitErrorAndNote(
2565  loc, "unable to infer type for variable `" + name + "`", loc,
2566  "the type of a variable must be inferable from the constraint "
2567  "list or the initializer");
2568  }
2569 
2570  // Constraint types cannot be used when defining variables.
2571  if (isa<ast::ConstraintType, ast::RewriteType>(type)) {
2572  return emitError(
2573  loc, llvm::formatv("unable to define variable of `{0}` type", type));
2574  }
2575 
2576  // Try to define a variable with the given name.
2578  defineVariableDecl(name, loc, type, initializer, constraints);
2579  if (failed(varDecl))
2580  return failure();
2581 
2582  return *varDecl;
2583 }
2584 
2586 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2587  const ast::ConstraintRef &constraint) {
2588  ast::Type argType;
2589  if (failed(validateVariableConstraint(constraint, argType)))
2590  return failure();
2591  return defineVariableDecl(name, loc, argType, constraint);
2592 }
2593 
2595 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2596  ast::Type &inferredType) {
2597  for (const ast::ConstraintRef &ref : constraints)
2598  if (failed(validateVariableConstraint(ref, inferredType)))
2599  return failure();
2600  return success();
2601 }
2602 
2603 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2604  ast::Type &inferredType) {
2605  ast::Type constraintType;
2606  if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
2607  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2608  if (failed(validateTypeConstraintExpr(typeExpr)))
2609  return failure();
2610  }
2611  constraintType = ast::AttributeType::get(ctx);
2612  } else if (const auto *cst =
2613  dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
2614  constraintType = ast::OperationType::get(
2615  ctx, cst->getName(), lookupODSOperation(cst->getName()));
2616  } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
2617  constraintType = typeTy;
2618  } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
2619  constraintType = typeRangeTy;
2620  } else if (const auto *cst =
2621  dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
2622  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2623  if (failed(validateTypeConstraintExpr(typeExpr)))
2624  return failure();
2625  }
2626  constraintType = valueTy;
2627  } else if (const auto *cst =
2628  dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
2629  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2630  if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2631  return failure();
2632  }
2633  constraintType = valueRangeTy;
2634  } else if (const auto *cst =
2635  dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
2636  ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2637  if (inputs.size() != 1) {
2638  return emitErrorAndNote(ref.referenceLoc,
2639  "`Constraint`s applied via a variable constraint "
2640  "list must take a single input, but got " +
2641  Twine(inputs.size()),
2642  cst->getLoc(),
2643  "see definition of constraint here");
2644  }
2645  constraintType = inputs.front()->getType();
2646  } else {
2647  llvm_unreachable("unknown constraint type");
2648  }
2649 
2650  // Check that the constraint type is compatible with the current inferred
2651  // type.
2652  if (!inferredType) {
2653  inferredType = constraintType;
2654  } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
2655  inferredType = mergedTy;
2656  } else {
2657  return emitError(ref.referenceLoc,
2658  llvm::formatv("constraint type `{0}` is incompatible "
2659  "with the previously inferred type `{1}`",
2660  constraintType, inferredType));
2661  }
2662  return success();
2663 }
2664 
2665 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2666  ast::Type typeExprType = typeExpr->getType();
2667  if (typeExprType != typeTy) {
2668  return emitError(typeExpr->getLoc(),
2669  "expected expression of `Type` in type constraint");
2670  }
2671  return success();
2672 }
2673 
2675 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2676  ast::Type typeExprType = typeExpr->getType();
2677  if (typeExprType != typeRangeTy) {
2678  return emitError(typeExpr->getLoc(),
2679  "expected expression of `TypeRange` in type constraint");
2680  }
2681  return success();
2682 }
2683 
2684 //===----------------------------------------------------------------------===//
2685 // Exprs
2686 
2688 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2689  MutableArrayRef<ast::Expr *> arguments, bool isNegated) {
2690  ast::Type parentType = parentExpr->getType();
2691 
2692  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2693  if (!callableDecl) {
2694  return emitError(loc,
2695  llvm::formatv("expected a reference to a callable "
2696  "`Constraint` or `Rewrite`, but got: `{0}`",
2697  parentType));
2698  }
2699  if (parserContext == ParserContext::Rewrite) {
2700  if (isa<ast::UserConstraintDecl>(callableDecl))
2701  return emitError(
2702  loc, "unable to invoke `Constraint` within a rewrite section");
2703  if (isNegated)
2704  return emitError(loc, "unable to negate a Rewrite");
2705  } else {
2706  if (isa<ast::UserRewriteDecl>(callableDecl))
2707  return emitError(loc,
2708  "unable to invoke `Rewrite` within a match section");
2709  if (isNegated && cast<ast::UserConstraintDecl>(callableDecl)->getBody())
2710  return emitError(loc, "unable to negate non native constraints");
2711  }
2712 
2713  // Verify the arguments of the call.
2714  /// Handle size mismatch.
2715  ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2716  if (callArgs.size() != arguments.size()) {
2717  return emitErrorAndNote(
2718  loc,
2719  llvm::formatv("invalid number of arguments for {0} call; expected "
2720  "{1}, but got {2}",
2721  callableDecl->getCallableType(), callArgs.size(),
2722  arguments.size()),
2723  callableDecl->getLoc(),
2724  llvm::formatv("see the definition of {0} here",
2725  callableDecl->getName()->getName()));
2726  }
2727 
2728  /// Handle argument type mismatch.
2729  auto attachDiagFn = [&](ast::Diagnostic &diag) {
2730  diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2731  callableDecl->getName()->getName()),
2732  callableDecl->getLoc());
2733  };
2734  for (auto it : llvm::zip(callArgs, arguments)) {
2735  if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2736  attachDiagFn)))
2737  return failure();
2738  }
2739 
2740  return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2741  callableDecl->getResultType(), isNegated);
2742 }
2743 
2744 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2745  ast::Decl *decl) {
2746  // Check the type of decl being referenced.
2747  ast::Type declType;
2748  if (isa<ast::ConstraintDecl>(decl))
2749  declType = ast::ConstraintType::get(ctx);
2750  else if (isa<ast::UserRewriteDecl>(decl))
2751  declType = ast::RewriteType::get(ctx);
2752  else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2753  declType = varDecl->getType();
2754  else
2755  return emitError(loc, "invalid reference to `" +
2756  decl->getName()->getName() + "`");
2757 
2758  return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2759 }
2760 
2762 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2763  ArrayRef<ast::ConstraintRef> constraints) {
2765  defineVariableDecl(name, loc, type, constraints);
2766  if (failed(decl))
2767  return failure();
2768  return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2769 }
2770 
2772 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2773  SMRange loc) {
2774  // Validate the member name for the given parent expression.
2775  FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2776  if (failed(memberType))
2777  return failure();
2778 
2779  return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2780 }
2781 
2782 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2783  StringRef name, SMRange loc) {
2784  ast::Type parentType = parentExpr->getType();
2785  if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) {
2787  return valueRangeTy;
2788 
2789  // Verify member access based on the operation type.
2790  if (const ods::Operation *odsOp = opType.getODSOperation()) {
2791  auto results = odsOp->getResults();
2792 
2793  // Handle indexed results.
2794  unsigned index = 0;
2795  if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2796  index < results.size()) {
2797  return results[index].isVariadic() ? valueRangeTy : valueTy;
2798  }
2799 
2800  // Handle named results.
2801  const auto *it = llvm::find_if(results, [&](const auto &result) {
2802  return result.getName() == name;
2803  });
2804  if (it != results.end())
2805  return it->isVariadic() ? valueRangeTy : valueTy;
2806  } else if (llvm::isDigit(name[0])) {
2807  // Allow unchecked numeric indexing of the results of unregistered
2808  // operations. It returns a single value.
2809  return valueTy;
2810  }
2811  } else if (auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
2812  // Handle indexed results.
2813  unsigned index = 0;
2814  if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2815  index < tupleType.size()) {
2816  return tupleType.getElementTypes()[index];
2817  }
2818 
2819  // Handle named results.
2820  auto elementNames = tupleType.getElementNames();
2821  const auto *it = llvm::find(elementNames, name);
2822  if (it != elementNames.end())
2823  return tupleType.getElementTypes()[it - elementNames.begin()];
2824  }
2825  return emitError(
2826  loc,
2827  llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2828  name, parentType));
2829 }
2830 
2831 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2832  SMRange loc, const ast::OpNameDecl *name,
2833  OpResultTypeContext resultTypeContext,
2834  SmallVectorImpl<ast::Expr *> &operands,
2836  SmallVectorImpl<ast::Expr *> &results) {
2837  std::optional<StringRef> opNameRef = name->getName();
2838  const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2839 
2840  // Verify the inputs operands.
2841  if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2842  return failure();
2843 
2844  // Verify the attribute list.
2845  for (ast::NamedAttributeDecl *attr : attributes) {
2846  // Check for an attribute type, or a type awaiting resolution.
2847  ast::Type attrType = attr->getValue()->getType();
2848  if (!isa<ast::AttributeType>(attrType)) {
2849  return emitError(
2850  attr->getValue()->getLoc(),
2851  llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2852  }
2853  }
2854 
2855  assert(
2856  (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2857  "unexpected inferrence when results were explicitly specified");
2858 
2859  // If we aren't relying on type inferrence, or explicit results were provided,
2860  // validate them.
2861  if (resultTypeContext == OpResultTypeContext::Explicit) {
2862  if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2863  return failure();
2864 
2865  // Validate the use of interface based type inferrence for this operation.
2866  } else if (resultTypeContext == OpResultTypeContext::Interface) {
2867  assert(opNameRef &&
2868  "expected valid operation name when inferring operation results");
2869  checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2870  }
2871 
2872  return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
2873  attributes);
2874 }
2875 
2877 Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2878  const ods::Operation *odsOp,
2879  SmallVectorImpl<ast::Expr *> &operands) {
2880  return validateOperationOperandsOrResults(
2881  "operand", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2882  operands, odsOp ? odsOp->getOperands() : std::nullopt, valueTy,
2883  valueRangeTy);
2884 }
2885 
2887 Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2888  const ods::Operation *odsOp,
2889  SmallVectorImpl<ast::Expr *> &results) {
2890  return validateOperationOperandsOrResults(
2891  "result", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2892  results, odsOp ? odsOp->getResults() : std::nullopt, typeTy, typeRangeTy);
2893 }
2894 
2895 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2896  const ods::Operation *odsOp) {
2897  // If the operation might not have inferrence support, emit a warning to the
2898  // user. We don't emit an error because the interface might be added to the
2899  // operation at runtime. It's rare, but it could still happen. We emit a
2900  // warning here instead.
2901 
2902  // Handle inferrence warnings for unknown operations.
2903  if (!odsOp) {
2904  ctx.getDiagEngine().emitWarning(
2905  loc, llvm::formatv(
2906  "operation result types are marked to be inferred, but "
2907  "`{0}` is unknown. Ensure that `{0}` supports zero "
2908  "results or implements `InferTypeOpInterface`. Include "
2909  "the ODS definition of this operation to remove this warning.",
2910  opName));
2911  return;
2912  }
2913 
2914  // Handle inferrence warnings for known operations that expected at least one
2915  // result, but don't have inference support. An elided results list can mean
2916  // "zero-results", and we don't want to warn when that is the expected
2917  // behavior.
2918  bool requiresInferrence =
2919  llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
2920  return !result.isVariableLength();
2921  });
2922  if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2923  ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning(
2924  loc,
2925  llvm::formatv("operation result types are marked to be inferred, but "
2926  "`{0}` does not provide an implementation of "
2927  "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2928  "`InferTypeOpInterface` at runtime, or add support to "
2929  "the ODS definition to remove this warning.",
2930  opName));
2931  diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
2932  odsOp->getLoc());
2933  return;
2934  }
2935 }
2936 
2937 LogicalResult Parser::validateOperationOperandsOrResults(
2938  StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2939  std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
2940  ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2941  ast::RangeType rangeTy) {
2942  // All operation types accept a single range parameter.
2943  if (values.size() == 1) {
2944  if (failed(convertExpressionTo(values[0], rangeTy)))
2945  return failure();
2946  return success();
2947  }
2948 
2949  /// If the operation has ODS information, we can more accurately verify the
2950  /// values.
2951  if (odsOpLoc) {
2952  auto emitSizeMismatchError = [&] {
2953  return emitErrorAndNote(
2954  loc,
2955  llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2956  "{2}, but got {3}",
2957  groupName, *name, odsValues.size(), values.size()),
2958  *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
2959  };
2960 
2961  // Handle the case where no values were provided.
2962  if (values.empty()) {
2963  // If we don't expect any on the ODS side, we are done.
2964  if (odsValues.empty())
2965  return success();
2966 
2967  // If we do, check if we actually need to provide values (i.e. if any of
2968  // the values are actually required).
2969  unsigned numVariadic = 0;
2970  for (const auto &odsValue : odsValues) {
2971  if (!odsValue.isVariableLength())
2972  return emitSizeMismatchError();
2973  ++numVariadic;
2974  }
2975 
2976  // If we are in a non-rewrite context, we don't need to do anything more.
2977  // Zero-values is a valid constraint on the operation.
2978  if (parserContext != ParserContext::Rewrite)
2979  return success();
2980 
2981  // Otherwise, when in a rewrite we may need to provide values to match the
2982  // ODS signature of the operation to create.
2983 
2984  // If we only have one variadic value, just use an empty list.
2985  if (numVariadic == 1)
2986  return success();
2987 
2988  // Otherwise, create dummy values for each of the entries so that we
2989  // adhere to the ODS signature.
2990  for (unsigned i = 0, e = odsValues.size(); i < e; ++i) {
2991  values.push_back(ast::RangeExpr::create(
2992  ctx, loc, /*elements=*/std::nullopt, rangeTy));
2993  }
2994  return success();
2995  }
2996 
2997  // Verify that the number of values provided matches the number of value
2998  // groups ODS expects.
2999  if (odsValues.size() != values.size())
3000  return emitSizeMismatchError();
3001 
3002  auto diagFn = [&](ast::Diagnostic &diag) {
3003  diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
3004  *odsOpLoc);
3005  };
3006  for (unsigned i = 0, e = values.size(); i < e; ++i) {
3007  ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
3008  if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
3009  return failure();
3010  }
3011  return success();
3012  }
3013 
3014  // Otherwise, accept the value groups as they have been defined and just
3015  // ensure they are one of the expected types.
3016  for (ast::Expr *&valueExpr : values) {
3017  ast::Type valueExprType = valueExpr->getType();
3018 
3019  // Check if this is one of the expected types.
3020  if (valueExprType == rangeTy || valueExprType == singleTy)
3021  continue;
3022 
3023  // If the operand is an Operation, allow converting to a Value or
3024  // ValueRange. This situations arises quite often with nested operation
3025  // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
3026  if (singleTy == valueTy) {
3027  if (isa<ast::OperationType>(valueExprType)) {
3028  valueExpr = convertOpToValue(valueExpr);
3029  continue;
3030  }
3031  }
3032 
3033  // Otherwise, try to convert the expression to a range.
3034  if (succeeded(convertExpressionTo(valueExpr, rangeTy)))
3035  continue;
3036 
3037  return emitError(
3038  valueExpr->getLoc(),
3039  llvm::formatv(
3040  "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3041  singleTy, rangeTy, valueExprType));
3042  }
3043  return success();
3044 }
3045 
3047 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
3048  ArrayRef<StringRef> elementNames) {
3049  for (const ast::Expr *element : elements) {
3050  ast::Type eleTy = element->getType();
3051  if (isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>(eleTy)) {
3052  return emitError(
3053  element->getLoc(),
3054  llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
3055  }
3056  }
3057  return ast::TupleExpr::create(ctx, loc, elements, elementNames);
3058 }
3059 
3060 //===----------------------------------------------------------------------===//
3061 // Stmts
3062 
3063 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
3064  ast::Expr *rootOp) {
3065  // Check that root is an Operation.
3066  ast::Type rootType = rootOp->getType();
3067  if (!isa<ast::OperationType>(rootType))
3068  return emitError(rootOp->getLoc(), "expected `Op` expression");
3069 
3070  return ast::EraseStmt::create(ctx, loc, rootOp);
3071 }
3072 
3074 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
3075  MutableArrayRef<ast::Expr *> replValues) {
3076  // Check that root is an Operation.
3077  ast::Type rootType = rootOp->getType();
3078  if (!isa<ast::OperationType>(rootType)) {
3079  return emitError(
3080  rootOp->getLoc(),
3081  llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3082  }
3083 
3084  // If there are multiple replacement values, we implicitly convert any Op
3085  // expressions to the value form.
3086  bool shouldConvertOpToValues = replValues.size() > 1;
3087  for (ast::Expr *&replExpr : replValues) {
3088  ast::Type replType = replExpr->getType();
3089 
3090  // Check that replExpr is an Operation, Value, or ValueRange.
3091  if (isa<ast::OperationType>(replType)) {
3092  if (shouldConvertOpToValues)
3093  replExpr = convertOpToValue(replExpr);
3094  continue;
3095  }
3096 
3097  if (replType != valueTy && replType != valueRangeTy) {
3098  return emitError(replExpr->getLoc(),
3099  llvm::formatv("expected `Op`, `Value` or `ValueRange` "
3100  "expression, but got `{0}`",
3101  replType));
3102  }
3103  }
3104 
3105  return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
3106 }
3107 
3109 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
3110  ast::CompoundStmt *rewriteBody) {
3111  // Check that root is an Operation.
3112  ast::Type rootType = rootOp->getType();
3113  if (!isa<ast::OperationType>(rootType)) {
3114  return emitError(
3115  rootOp->getLoc(),
3116  llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3117  }
3118 
3119  return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
3120 }
3121 
3122 //===----------------------------------------------------------------------===//
3123 // Code Completion
3124 //===----------------------------------------------------------------------===//
3125 
3126 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3127  ast::Type parentType = parentExpr->getType();
3128  if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType))
3129  codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3130  else if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(parentType))
3131  codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3132  return failure();
3133 }
3134 
3136 Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3137  if (opName)
3138  codeCompleteContext->codeCompleteOperationAttributeName(*opName);
3139  return failure();
3140 }
3141 
3143 Parser::codeCompleteConstraintName(ast::Type inferredType,
3144  bool allowInlineTypeConstraints) {
3145  codeCompleteContext->codeCompleteConstraintName(
3146  inferredType, allowInlineTypeConstraints, curDeclScope);
3147  return failure();
3148 }
3149 
3150 LogicalResult Parser::codeCompleteDialectName() {
3151  codeCompleteContext->codeCompleteDialectName();
3152  return failure();
3153 }
3154 
3155 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3156  codeCompleteContext->codeCompleteOperationName(dialectName);
3157  return failure();
3158 }
3159 
3160 LogicalResult Parser::codeCompletePatternMetadata() {
3161  codeCompleteContext->codeCompletePatternMetadata();
3162  return failure();
3163 }
3164 
3165 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3166  codeCompleteContext->codeCompleteIncludeFilename(curPath);
3167  return failure();
3168 }
3169 
3170 void Parser::codeCompleteCallSignature(ast::Node *parent,
3171  unsigned currentNumArgs) {
3172  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent);
3173  if (!callableDecl)
3174  return;
3175 
3176  codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
3177 }
3178 
3179 void Parser::codeCompleteOperationOperandsSignature(
3180  std::optional<StringRef> opName, unsigned currentNumOperands) {
3181  codeCompleteContext->codeCompleteOperationOperandsSignature(
3182  opName, currentNumOperands);
3183 }
3184 
3185 void Parser::codeCompleteOperationResultsSignature(
3186  std::optional<StringRef> opName, unsigned currentNumResults) {
3187  codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3188  currentNumResults);
3189 }
3190 
3191 //===----------------------------------------------------------------------===//
3192 // Parser
3193 //===----------------------------------------------------------------------===//
3194 
3196 mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3197  bool enableDocumentation,
3198  CodeCompleteContext *codeCompleteContext) {
3199  Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3200  return parser.parseModule();
3201 }
static std::string diag(const llvm::Value &value)
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class breaks up the current file into a token stream.
Definition: Lexer.h:23
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
This represents a token in the MLIR syntax.
Definition: Token.h:20
bool isKeyword() const
Return true if this is one of the keyword token kinds (e.g. kw_if).
Definition: Token.cpp:192
SMLoc getLoc() const
Definition: Token.cpp:24
bool isNot(Kind k) const
Definition: Token.h:50
StringRef getSpelling() const
Definition: Token.h:34
This class provides an abstract interface into the parser for hooking in code completion events.
Definition: CodeComplete.h:30
@ code_complete_string
Token signifying a code completion location within a string.
Definition: Lexer.h:43
@ directive
Tokens.
Definition: Lexer.h:95
@ eof
Markers.
Definition: Lexer.h:38
@ code_complete
Token signifying a code completion location.
Definition: Lexer.h:41
@ arrow
Punctuation.
Definition: Lexer.h:76
@ less
Paired punctuation.
Definition: Lexer.h:84
@ kw_Attr
General keywords.
Definition: Lexer.h:56
static StringRef getMemberName()
Return the member name used for the "all-results" access.
Definition: Nodes.h:487
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
Definition: Nodes.h:489
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
Definition: Nodes.h:749
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:390
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:258
static AttributeType get(Context &context)
Return an instance of the Attribute type.
Definition: Types.cpp:57
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr * > arguments, Type resultType, bool isNegated=false)
Definition: Nodes.cpp:268
This decl represents a shared interface for all callable decls.
Definition: Nodes.h:1188
Type getResultType() const
Return the result type of this decl.
Definition: Nodes.h:1206
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
Definition: Nodes.h:1199
StringRef getCallableType() const
Return the callable type of this decl.
Definition: Nodes.h:1191
This statement represents a compound statement, which contains a collection of other statements.
Definition: Nodes.h:179
ArrayRef< Stmt * >::iterator begin() const
Definition: Nodes.h:191
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
Definition: Nodes.h:185
ArrayRef< Stmt * >::iterator end() const
Definition: Nodes.h:192
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt * > children)
Definition: Nodes.cpp:192
This class represents the base of all AST Constraint decls.
Definition: Nodes.h:704
static ConstraintType get(Context &context)
Return an instance of the Constraint type.
Definition: Types.cpp:65
This class represents the main context of the PDLL AST.
Definition: Context.h:25
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
Definition: Nodes.cpp:285
This class represents a scope for named AST decls.
Definition: Nodes.h:64
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
Definition: Nodes.h:70
void add(Decl *decl)
Add a new decl to the scope.
Definition: Nodes.cpp:175
This class represents the base Decl node.
Definition: Nodes.h:669
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition: Nodes.h:672
This class provides a simple implementation of a PDLL diagnostic.
Definition: Diagnostic.h:30
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
Definition: Nodes.cpp:218
This class represents a base AST Expression node.
Definition: Nodes.h:345
Type getType() const
Return the type of this expression.
Definition: Nodes.h:348
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostic.h:84
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
Definition: Nodes.cpp:207
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
Definition: Nodes.cpp:295
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl * > children)
Definition: Nodes.cpp:576
This Decl represents a NamedAttribute, and contains a string name and attribute value.
Definition: Nodes.h:992
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
Definition: Nodes.cpp:499
This class represents a base AST node.
Definition: Nodes.h:108
SMRange getLoc() const
Return the location of this node.
Definition: Nodes.h:131
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
Definition: Nodes.cpp:400
This Decl represents an OperationName.
Definition: Nodes.h:1016
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
Definition: Nodes.h:1022
static OpNameDecl * create(Context &ctx, const Name &name)
Definition: Nodes.cpp:509
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr * > operands, ArrayRef< Expr * > resultTypes, ArrayRef< NamedAttributeDecl * > attributes)
Definition: Nodes.cpp:307
This class represents a PDLL type that corresponds to an mlir::Operation.
Definition: Types.h:163
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
Definition: Types.cpp:87
static OperationType get(Context &context, std::optional< StringRef > name=std::nullopt, const ods::Operation *odsOp=nullptr)
Return an instance of the Operation type with an optional operation name.
Definition: Types.cpp:73
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
Definition: Nodes.cpp:520
static RangeExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, RangeType type)
Definition: Nodes.cpp:338
This class represents a PDLL type that corresponds to a range of elements with a given element type.
Definition: Types.h:188
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr * > replExprs)
Definition: Nodes.cpp:225
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
Definition: Nodes.cpp:249
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
Definition: Nodes.cpp:239
static RewriteType get(Context &context)
Return an instance of the Rewrite type.
Definition: Types.cpp:134
This class represents a base AST Statement node.
Definition: Nodes.h:164
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, ArrayRef< StringRef > elementNames)
Definition: Nodes.cpp:353
This class represents a PDLL tuple type, i.e.
Definition: Types.h:249
size_t size() const
Return the number of elements within this tuple.
Definition: Types.h:266
ArrayRef< Type > getElementTypes() const
Return the element types of this tuple.
Definition: Types.cpp:153
static TupleType get(Context &context, ArrayRef< Type > elementTypes, ArrayRef< StringRef > elementNames)
Return an instance of the Tuple type.
Definition: Types.cpp:142
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:417
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:373
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:426
Type refineWith(Type other) const
Try to refine this type with the one provided.
Definition: Types.cpp:33
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
Definition: Nodes.h:886
The class represents a Value constraint, and constrains a variable to be a Value.
Definition: Nodes.h:825
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
Definition: Nodes.cpp:436
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
Definition: Nodes.h:847
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:447
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
Definition: Nodes.cpp:558
This class contains all of the registered ODS operation classes.
Definition: Context.h:32
std::pair< Operation *, bool > insertOperation(StringRef name, StringRef summary, StringRef desc, StringRef nativeClassName, bool supportsResultTypeInferrence, SMLoc loc)
Insert a new operation with the context.
Definition: Context.cpp:64
const TypeConstraint & insertTypeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new type constraint with the context.
Definition: Context.cpp:42
const AttributeConstraint & insertAttributeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new attribute constraint with the context.
Definition: Context.cpp:28
This class provides an ODS representation of a specific operation operand or result.
Definition: Operation.h:74
This class provides an ODS representation of a specific operation.
Definition: Operation.h:125
SMRange getLoc() const
Return the source location of this operation.
Definition: Operation.h:128
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
Definition: Operation.h:165
bool hasResultTypeInferrence() const
Return whether the operation is known to support result type inference.
Definition: Operation.h:171
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
Definition: Operation.h:168
This class represents a generic ODS Type constraint.
Definition: Constraint.h:84
raw_ostream subclass that simplifies indenting a sequence of code.
raw_indented_ostream & printReindented(StringRef str, StringRef extraPrefix="")
Prints a string re-indented to the current indent.
StringRef getSummary() const
Definition: Constraint.cpp:55
std::string getUniqueDefName() const
Returns a unique name for the TableGen def of this constraint.
Definition: Constraint.cpp:72
StringRef getDescription() const
Definition: Constraint.cpp:62
std::string getConditionTemplate() const
Definition: Constraint.cpp:51
Format context containing substitutions for special placeholders.
Definition: Format.h:40
FmtContext & withSelf(Twine subst)
Definition: Format.cpp:41
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
Definition: Operator.h:77
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
Definition: Parser.cpp:3196
auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals) -> FmtObject< decltype(std::make_tuple(llvm::support::detail::build_format_adapter(std::forward< Ts >(vals))...))>
Formats text by substituting placeholders in format string with replacement parameters.
Definition: Format.h:262
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This class represents a reference to a constraint, and contains a constraint and the location of the ...
Definition: Nodes.h:716
const ConstraintDecl * constraint
Definition: Nodes.h:722
This class provides a convenient API for interacting with source names.
Definition: Nodes.h:37
StringRef getName() const
Return the raw string name.
Definition: Nodes.h:41
SMRange getLoc() const
Get the location of this name.
Definition: Nodes.h:44
static const Name & create(Context &ctx, StringRef name, SMRange location)
Definition: Nodes.cpp:33