MLIR  20.0.0git
Parser.cpp
Go to the documentation of this file.
1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "Lexer.h"
12 #include "mlir/TableGen/Argument.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/Operator.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/ManagedStatic.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/ScopedPrinter.h"
31 #include "llvm/TableGen/Error.h"
32 #include "llvm/TableGen/Parser.h"
33 #include <optional>
34 #include <string>
35 
36 using namespace mlir;
37 using namespace mlir::pdll;
38 
39 //===----------------------------------------------------------------------===//
40 // Parser
41 //===----------------------------------------------------------------------===//
42 
43 namespace {
44 class Parser {
45 public:
46  Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
47  bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
48  : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
49  curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
50  typeTy(ast::TypeType::get(ctx)), valueTy(ast::ValueType::get(ctx)),
51  typeRangeTy(ast::TypeRangeType::get(ctx)),
52  valueRangeTy(ast::ValueRangeType::get(ctx)),
53  attrTy(ast::AttributeType::get(ctx)),
54  codeCompleteContext(codeCompleteContext) {}
55 
56  /// Try to parse a new module. Returns nullptr in the case of failure.
57  FailureOr<ast::Module *> parseModule();
58 
59 private:
/// The current context of the parser. It allows the parser to know a bit
/// about the construct it is nested within during parsing. This is used
/// specifically to provide additional verification during parsing, e.g. to
/// prevent using rewrites within a match context, matcher constraints within
/// a rewrite section, etc.
enum class ParserContext {
  /// The parser is in the global context.
  Global,
  /// The parser is currently within a Constraint, which disallows all types
  /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
  Constraint,
  /// The parser is currently within the matcher portion of a Pattern, which
  /// allows a terminal operation rewrite statement but no other rewrite
  /// transformations.
  PatternMatch,
  /// The parser is currently within a Rewrite, which disallows calls to
  /// constraints, requires operation expressions to have names, etc.
  Rewrite,
};

/// The current specification context of an operation's result types. This
/// indicates how the result types of an operation may be inferred.
enum class OpResultTypeContext {
  /// The result types of the operation are not known to be inferred.
  Explicit,
  /// The result types of the operation are inferred from the root input of a
  /// `replace` statement.
  Replacement,
  /// The result types of the operation are inferred by using the
  /// `InferTypeOpInterface` interface provided by the operation.
  Interface,
};
92 
93  //===--------------------------------------------------------------------===//
94  // Parsing
95  //===--------------------------------------------------------------------===//
96 
97  /// Push a new decl scope onto the lexer.
98  ast::DeclScope *pushDeclScope() {
99  ast::DeclScope *newScope =
100  new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
101  return (curDeclScope = newScope);
102  }
103  void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
104 
105  /// Pop the last decl scope from the lexer.
106  void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
107 
108  /// Parse the body of an AST module.
109  LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
110 
111  /// Try to convert the given expression to `type`. Returns failure and emits
112  /// an error if a conversion is not viable. On failure, `noteAttachFn` is
113  /// invoked to attach notes to the emitted error diagnostic. On success,
114  /// `expr` is updated to the expression used to convert to `type`.
115  LogicalResult convertExpressionTo(
116  ast::Expr *&expr, ast::Type type,
117  function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
118  LogicalResult
119  convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType,
120  ast::Type type,
121  function_ref<ast::InFlightDiagnostic()> emitErrorFn);
122  LogicalResult convertTupleExpressionTo(
123  ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
124  function_ref<ast::InFlightDiagnostic()> emitErrorFn,
125  function_ref<void(ast::Diagnostic &diag)> noteAttachFn);
126 
127  /// Given an operation expression, convert it to a Value or ValueRange
128  /// typed expression.
129  ast::Expr *convertOpToValue(const ast::Expr *opExpr);
130 
131  /// Lookup ODS information for the given operation, returns nullptr if no
132  /// information is found.
133  const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
134  return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
135  }
136 
137  /// Process the given documentation string, or return an empty string if
138  /// documentation isn't enabled.
139  StringRef processDoc(StringRef doc) {
140  return enableDocumentation ? doc : StringRef();
141  }
142 
143  /// Process the given documentation string and format it, or return an empty
144  /// string if documentation isn't enabled.
145  std::string processAndFormatDoc(const Twine &doc) {
146  if (!enableDocumentation)
147  return "";
148  std::string docStr;
149  {
150  llvm::raw_string_ostream docOS(docStr);
151  std::string tmpDocStr = doc.str();
153  StringRef(tmpDocStr).rtrim(" \t"));
154  }
155  return docStr;
156  }
157 
158  //===--------------------------------------------------------------------===//
159  // Directives
160 
161  LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
162  LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
163  LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
165 
166  /// Process the records of a parsed tablegen include file.
167  void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
169 
170  /// Create a user defined native constraint for a constraint imported from
171  /// ODS.
172  template <typename ConstraintT>
173  ast::Decl *
174  createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
175  SMRange loc, ast::Type type,
176  StringRef nativeType, StringRef docString);
177  template <typename ConstraintT>
178  ast::Decl *
179  createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
180  SMRange loc, ast::Type type,
181  StringRef nativeType);
182 
183  //===--------------------------------------------------------------------===//
184  // Decls
185 
186  /// This structure contains the set of pattern metadata that may be parsed.
187  struct ParsedPatternMetadata {
188  std::optional<uint16_t> benefit;
189  bool hasBoundedRecursion = false;
190  };
191 
192  FailureOr<ast::Decl *> parseTopLevelDecl();
193  FailureOr<ast::NamedAttributeDecl *>
194  parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
195 
196  /// Parse an argument variable as part of the signature of a
197  /// UserConstraintDecl or UserRewriteDecl.
198  FailureOr<ast::VariableDecl *> parseArgumentDecl();
199 
200  /// Parse a result variable as part of the signature of a UserConstraintDecl
201  /// or UserRewriteDecl.
202  FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
203 
204  /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
205  /// defined in a non-global context.
206  FailureOr<ast::UserConstraintDecl *>
207  parseUserConstraintDecl(bool isInline = false);
208 
209  /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
210  /// non-global context, such as within a Pattern/Constraint/etc.
211  FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
212 
213  /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
214  /// PDLL constructs.
215  FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
216  const ast::Name &name, bool isInline,
217  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
218  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
219 
220  /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being
221  /// defined in a non-global context.
222  FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
223 
224  /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
225  /// non-global context, such as within a Pattern/Rewrite/etc.
226  FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
227 
228  /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
229  /// PDLL constructs.
230  FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
231  const ast::Name &name, bool isInline,
232  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
233  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
234 
235  /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
236  /// effectively the same syntax, and only differ on slight semantics (given
237  /// the different parsing contexts).
238  template <typename T, typename ParseUserPDLLDeclFnT>
239  FailureOr<T *> parseUserConstraintOrRewriteDecl(
240  ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
241  StringRef anonymousNamePrefix, bool isInline);
242 
243  /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
244  /// These decls have effectively the same syntax.
245  template <typename T>
246  FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
247  const ast::Name &name, bool isInline,
249  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
250 
251  /// Parse the functional signature (i.e. the arguments and results) of a
252  /// UserConstraintDecl or UserRewriteDecl.
253  LogicalResult parseUserConstraintOrRewriteSignature(
256  ast::DeclScope *&argumentScope, ast::Type &resultType);
257 
258  /// Validate the return (which if present is specified by bodyIt) of a
259  /// UserConstraintDecl or UserRewriteDecl.
260  LogicalResult validateUserConstraintOrRewriteReturn(
261  StringRef declType, ast::CompoundStmt *body,
264  ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
265 
266  FailureOr<ast::CompoundStmt *>
267  parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
268  bool expectTerminalSemicolon = true);
269  FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
270  FailureOr<ast::Decl *> parsePatternDecl();
271  LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
272 
273  /// Check to see if a decl has already been defined with the given name, if
274  /// one has emit and error and return failure. Returns success otherwise.
275  LogicalResult checkDefineNamedDecl(const ast::Name &name);
276 
277  /// Try to define a variable decl with the given components, returns the
278  /// variable on success.
279  FailureOr<ast::VariableDecl *>
280  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
281  ast::Expr *initExpr,
282  ArrayRef<ast::ConstraintRef> constraints);
283  FailureOr<ast::VariableDecl *>
284  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
285  ArrayRef<ast::ConstraintRef> constraints);
286 
287  /// Parse the constraint reference list for a variable decl.
288  LogicalResult parseVariableDeclConstraintList(
290 
291  /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
292  FailureOr<ast::Expr *> parseTypeConstraintExpr();
293 
294  /// Try to parse a single reference to a constraint. `typeConstraint` is the
295  /// location of a previously parsed type constraint for the entity that will
296  /// be constrained by the parsed constraint. `existingConstraints` are any
297  /// existing constraints that have already been parsed for the same entity
298  /// that will be constrained by this constraint. `allowInlineTypeConstraints`
299  /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
300  FailureOr<ast::ConstraintRef>
301  parseConstraint(std::optional<SMRange> &typeConstraint,
302  ArrayRef<ast::ConstraintRef> existingConstraints,
303  bool allowInlineTypeConstraints);
304 
305  /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
306  /// argument or result variable. The constraints for these variables do not
307  /// allow inline type constraints, and only permit a single constraint.
308  FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
309 
310  //===--------------------------------------------------------------------===//
311  // Exprs
312 
313  FailureOr<ast::Expr *> parseExpr();
314 
315  /// Identifier expressions.
316  FailureOr<ast::Expr *> parseAttributeExpr();
317  FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr,
318  bool isNegated = false);
319  FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
320  FailureOr<ast::Expr *> parseIdentifierExpr();
321  FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
322  FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
323  FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
324  FailureOr<ast::Expr *> parseNegatedExpr();
325  FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
326  FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
327  FailureOr<ast::Expr *>
328  parseOperationExpr(OpResultTypeContext inputResultTypeContext =
329  OpResultTypeContext::Explicit);
330  FailureOr<ast::Expr *> parseTupleExpr();
331  FailureOr<ast::Expr *> parseTypeExpr();
332  FailureOr<ast::Expr *> parseUnderscoreExpr();
333 
334  //===--------------------------------------------------------------------===//
335  // Stmts
336 
337  FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
338  FailureOr<ast::CompoundStmt *> parseCompoundStmt();
339  FailureOr<ast::EraseStmt *> parseEraseStmt();
340  FailureOr<ast::LetStmt *> parseLetStmt();
341  FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
342  FailureOr<ast::ReturnStmt *> parseReturnStmt();
343  FailureOr<ast::RewriteStmt *> parseRewriteStmt();
344 
345  //===--------------------------------------------------------------------===//
346  // Creation+Analysis
347  //===--------------------------------------------------------------------===//
348 
349  //===--------------------------------------------------------------------===//
350  // Decls
351 
352  /// Try to extract a callable from the given AST node. Returns nullptr on
353  /// failure.
354  ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
355 
356  /// Try to create a pattern decl with the given components, returning the
357  /// Pattern on success.
358  FailureOr<ast::PatternDecl *>
359  createPatternDecl(SMRange loc, const ast::Name *name,
360  const ParsedPatternMetadata &metadata,
361  ast::CompoundStmt *body);
362 
363  /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
364  /// of results, defined as part of the signature.
365  ast::Type
366  createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
367 
368  /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
369  template <typename T>
370  FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
371  const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
372  ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
373  ast::CompoundStmt *body);
374 
375  /// Try to create a variable decl with the given components, returning the
376  /// Variable on success.
377  FailureOr<ast::VariableDecl *>
378  createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
379  ArrayRef<ast::ConstraintRef> constraints);
380 
381  /// Create a variable for an argument or result defined as part of the
382  /// signature of a UserConstraintDecl/UserRewriteDecl.
383  FailureOr<ast::VariableDecl *>
384  createArgOrResultVariableDecl(StringRef name, SMRange loc,
385  const ast::ConstraintRef &constraint);
386 
387  /// Validate the constraints used to constraint a variable decl.
388  /// `inferredType` is the type of the variable inferred by the constraints
389  /// within the list, and is updated to the most refined type as determined by
390  /// the constraints. Returns success if the constraint list is valid, failure
391  /// otherwise.
392  LogicalResult
393  validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
394  ast::Type &inferredType);
395  /// Validate a single reference to a constraint. `inferredType` contains the
396  /// currently inferred variabled type and is refined within the type defined
397  /// by the constraint. Returns success if the constraint is valid, failure
398  /// otherwise.
399  LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
400  ast::Type &inferredType);
401  LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
402  LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
403 
404  //===--------------------------------------------------------------------===//
405  // Exprs
406 
407  FailureOr<ast::CallExpr *>
408  createCallExpr(SMRange loc, ast::Expr *parentExpr,
410  bool isNegated = false);
411  FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
412  FailureOr<ast::DeclRefExpr *>
413  createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
414  ArrayRef<ast::ConstraintRef> constraints);
415  FailureOr<ast::MemberAccessExpr *>
416  createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
417 
418  /// Validate the member access `name` into the given parent expression. On
419  /// success, this also returns the type of the member accessed.
420  FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
421  StringRef name, SMRange loc);
422  FailureOr<ast::OperationExpr *>
423  createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
424  OpResultTypeContext resultTypeContext,
428  LogicalResult
429  validateOperationOperands(SMRange loc, std::optional<StringRef> name,
430  const ods::Operation *odsOp,
431  SmallVectorImpl<ast::Expr *> &operands);
432  LogicalResult validateOperationResults(SMRange loc,
433  std::optional<StringRef> name,
434  const ods::Operation *odsOp,
436  void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
437  const ods::Operation *odsOp);
438  LogicalResult validateOperationOperandsOrResults(
439  StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
440  std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
441  ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
442  ast::RangeType rangeTy);
443  FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
444  ArrayRef<ast::Expr *> elements,
445  ArrayRef<StringRef> elementNames);
446 
447  //===--------------------------------------------------------------------===//
448  // Stmts
449 
450  FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
451  FailureOr<ast::ReplaceStmt *>
452  createReplaceStmt(SMRange loc, ast::Expr *rootOp,
453  MutableArrayRef<ast::Expr *> replValues);
454  FailureOr<ast::RewriteStmt *>
455  createRewriteStmt(SMRange loc, ast::Expr *rootOp,
456  ast::CompoundStmt *rewriteBody);
457 
458  //===--------------------------------------------------------------------===//
459  // Code Completion
460  //===--------------------------------------------------------------------===//
461 
462  /// The set of various code completion methods. Every completion method
463  /// returns `failure` to stop the parsing process after providing completion
464  /// results.
465 
466  LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
467  LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
468  LogicalResult codeCompleteConstraintName(ast::Type inferredType,
469  bool allowInlineTypeConstraints);
470  LogicalResult codeCompleteDialectName();
471  LogicalResult codeCompleteOperationName(StringRef dialectName);
472  LogicalResult codeCompletePatternMetadata();
473  LogicalResult codeCompleteIncludeFilename(StringRef curPath);
474 
475  void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
476  void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
477  unsigned currentNumOperands);
478  void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
479  unsigned currentNumResults);
480 
481  //===--------------------------------------------------------------------===//
482  // Lexer Utilities
483  //===--------------------------------------------------------------------===//
484 
485  /// If the current token has the specified kind, consume it and return true.
486  /// If not, return false.
487  bool consumeIf(Token::Kind kind) {
488  if (curToken.isNot(kind))
489  return false;
490  consumeToken(kind);
491  return true;
492  }
493 
494  /// Advance the current lexer onto the next token.
495  void consumeToken() {
496  assert(curToken.isNot(Token::eof, Token::error) &&
497  "shouldn't advance past EOF or errors");
498  curToken = lexer.lexToken();
499  }
500 
501  /// Advance the current lexer onto the next token, asserting what the expected
502  /// current token is. This is preferred to the above method because it leads
503  /// to more self-documenting code with better checking.
504  void consumeToken(Token::Kind kind) {
505  assert(curToken.is(kind) && "consumed an unexpected token");
506  consumeToken();
507  }
508 
509  /// Reset the lexer to the location at the given position.
510  void resetToken(SMRange tokLoc) {
511  lexer.resetPointer(tokLoc.Start.getPointer());
512  curToken = lexer.lexToken();
513  }
514 
515  /// Consume the specified token if present and return success. On failure,
516  /// output a diagnostic and return failure.
517  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
518  if (curToken.getKind() != kind)
519  return emitError(curToken.getLoc(), msg);
520  consumeToken();
521  return success();
522  }
523  LogicalResult emitError(SMRange loc, const Twine &msg) {
524  lexer.emitError(loc, msg);
525  return failure();
526  }
527  LogicalResult emitError(const Twine &msg) {
528  return emitError(curToken.getLoc(), msg);
529  }
530  LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
531  const Twine &note) {
532  lexer.emitErrorAndNote(loc, msg, noteLoc, note);
533  return failure();
534  }
535 
536  //===--------------------------------------------------------------------===//
537  // Fields
538  //===--------------------------------------------------------------------===//
539 
540  /// The owning AST context.
541  ast::Context &ctx;
542 
543  /// The lexer of this parser.
544  Lexer lexer;
545 
546  /// The current token within the lexer.
547  Token curToken;
548 
549  /// A flag indicating if the parser should add documentation to AST nodes when
550  /// viable.
551  bool enableDocumentation;
552 
553  /// The most recently defined decl scope.
554  ast::DeclScope *curDeclScope = nullptr;
555  llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
556 
557  /// The current context of the parser.
558  ParserContext parserContext = ParserContext::Global;
559 
560  /// Cached types to simplify verification and expression creation.
561  ast::Type typeTy, valueTy;
562  ast::RangeType typeRangeTy, valueRangeTy;
563  ast::Type attrTy;
564 
565  /// A counter used when naming anonymous constraints and rewrites.
566  unsigned anonymousDeclNameCounter = 0;
567 
568  /// The optional code completion context.
569  CodeCompleteContext *codeCompleteContext;
570 };
571 } // namespace
572 
573 FailureOr<ast::Module *> Parser::parseModule() {
574  SMLoc moduleLoc = curToken.getStartLoc();
575  pushDeclScope();
576 
577  // Parse the top-level decls of the module.
579  if (failed(parseModuleBody(decls)))
580  return popDeclScope(), failure();
581 
582  popDeclScope();
583  return ast::Module::create(ctx, moduleLoc, decls);
584 }
585 
586 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
587  while (curToken.isNot(Token::eof)) {
588  if (curToken.is(Token::directive)) {
589  if (failed(parseDirective(decls)))
590  return failure();
591  continue;
592  }
593 
594  FailureOr<ast::Decl *> decl = parseTopLevelDecl();
595  if (failed(decl))
596  return failure();
597  decls.push_back(*decl);
598  }
599  return success();
600 }
601 
602 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
603  return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
604  valueRangeTy);
605 }
606 
607 LogicalResult Parser::convertExpressionTo(
608  ast::Expr *&expr, ast::Type type,
609  function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
610  ast::Type exprType = expr->getType();
611  if (exprType == type)
612  return success();
613 
614  auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
616  expr->getLoc(), llvm::formatv("unable to convert expression of type "
617  "`{0}` to the expected type of "
618  "`{1}`",
619  exprType, type));
620  if (noteAttachFn)
621  noteAttachFn(*diag);
622  return diag;
623  };
624 
625  if (auto exprOpType = dyn_cast<ast::OperationType>(exprType))
626  return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);
627 
628  // FIXME: Decide how to allow/support converting a single result to multiple,
629  // and multiple to a single result. For now, we just allow Single->Range,
630  // but this isn't something really supported in the PDL dialect. We should
631  // figure out some way to support both.
632  if ((exprType == valueTy || exprType == valueRangeTy) &&
633  (type == valueTy || type == valueRangeTy))
634  return success();
635  if ((exprType == typeTy || exprType == typeRangeTy) &&
636  (type == typeTy || type == typeRangeTy))
637  return success();
638 
639  // Handle tuple types.
640  if (auto exprTupleType = dyn_cast<ast::TupleType>(exprType))
641  return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
642  noteAttachFn);
643 
644  return emitConvertError();
645 }
646 
647 LogicalResult Parser::convertOpExpressionTo(
648  ast::Expr *&expr, ast::OperationType exprType, ast::Type type,
649  function_ref<ast::InFlightDiagnostic()> emitErrorFn) {
650  // Two operation types are compatible if they have the same name, or if the
651  // expected type is more general.
652  if (auto opType = dyn_cast<ast::OperationType>(type)) {
653  if (opType.getName())
654  return emitErrorFn();
655  return success();
656  }
657 
658  // An operation can always convert to a ValueRange.
659  if (type == valueRangeTy) {
660  expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
661  valueRangeTy);
662  return success();
663  }
664 
665  // Allow conversion to a single value by constraining the result range.
666  if (type == valueTy) {
667  // If the operation is registered, we can verify if it can ever have a
668  // single result.
669  if (const ods::Operation *odsOp = exprType.getODSOperation()) {
670  if (odsOp->getResults().empty()) {
671  return emitErrorFn()->attachNote(
672  llvm::formatv("see the definition of `{0}`, which was defined "
673  "with zero results",
674  odsOp->getName()),
675  odsOp->getLoc());
676  }
677 
678  unsigned numSingleResults = llvm::count_if(
679  odsOp->getResults(), [](const ods::OperandOrResult &result) {
680  return result.getVariableLengthKind() ==
681  ods::VariableLengthKind::Single;
682  });
683  if (numSingleResults > 1) {
684  return emitErrorFn()->attachNote(
685  llvm::formatv("see the definition of `{0}`, which was defined "
686  "with at least {1} results",
687  odsOp->getName(), numSingleResults),
688  odsOp->getLoc());
689  }
690  }
691 
692  expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
693  valueTy);
694  return success();
695  }
696  return emitErrorFn();
697 }
698 
699 LogicalResult Parser::convertTupleExpressionTo(
700  ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
701  function_ref<ast::InFlightDiagnostic()> emitErrorFn,
702  function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
703  // Handle conversions between tuples.
704  if (auto tupleType = dyn_cast<ast::TupleType>(type)) {
705  if (tupleType.size() != exprType.size())
706  return emitErrorFn();
707 
708  // Build a new tuple expression using each of the elements of the current
709  // tuple.
710  SmallVector<ast::Expr *> newExprs;
711  for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
712  newExprs.push_back(ast::MemberAccessExpr::create(
713  ctx, expr->getLoc(), expr, llvm::to_string(i),
714  exprType.getElementTypes()[i]));
715 
716  auto diagFn = [&](ast::Diagnostic &diag) {
717  diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
718  i, exprType));
719  if (noteAttachFn)
720  noteAttachFn(diag);
721  };
722  if (failed(convertExpressionTo(newExprs.back(),
723  tupleType.getElementTypes()[i], diagFn)))
724  return failure();
725  }
726  expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
727  tupleType.getElementNames());
728  return success();
729  }
730 
731  // Handle conversion to a range.
732  auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes,
733  ast::RangeType resultTy) -> LogicalResult {
734  // TODO: We currently only allow range conversion within a rewrite context.
735  if (parserContext != ParserContext::Rewrite) {
736  return emitErrorFn()->attachNote("Tuple to Range conversion is currently "
737  "only allowed within a rewrite context");
738  }
739 
740  // All of the tuple elements must be allowed types.
741  for (ast::Type elementType : exprType.getElementTypes())
742  if (!llvm::is_contained(allowedElementTypes, elementType))
743  return emitErrorFn();
744 
745  // Build a new tuple expression using each of the elements of the current
746  // tuple.
747  SmallVector<ast::Expr *> newExprs;
748  for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
749  newExprs.push_back(ast::MemberAccessExpr::create(
750  ctx, expr->getLoc(), expr, llvm::to_string(i),
751  exprType.getElementTypes()[i]));
752  }
753  expr = ast::RangeExpr::create(ctx, expr->getLoc(), newExprs, resultTy);
754  return success();
755  };
756  if (type == valueRangeTy)
757  return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
758  if (type == typeRangeTy)
759  return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
760 
761  return emitErrorFn();
762 }
763 
764 //===----------------------------------------------------------------------===//
765 // Directives
766 
767 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
768  StringRef directive = curToken.getSpelling();
769  if (directive == "#include")
770  return parseInclude(decls);
771 
772  return emitError("unknown directive `" + directive + "`");
773 }
774 
775 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
776  SMRange loc = curToken.getLoc();
777  consumeToken(Token::directive);
778 
779  // Handle code completion of the include file path.
780  if (curToken.is(Token::code_complete_string))
781  return codeCompleteIncludeFilename(curToken.getStringValue());
782 
783  // Parse the file being included.
784  if (!curToken.isString())
785  return emitError(loc,
786  "expected string file name after `include` directive");
787  SMRange fileLoc = curToken.getLoc();
788  std::string filenameStr = curToken.getStringValue();
789  StringRef filename = filenameStr;
790  consumeToken();
791 
792  // Check the type of include. If ending with `.pdll`, this is another pdl file
793  // to be parsed along with the current module.
794  if (filename.ends_with(".pdll")) {
795  if (failed(lexer.pushInclude(filename, fileLoc)))
796  return emitError(fileLoc,
797  "unable to open include file `" + filename + "`");
798 
799  // If we added the include successfully, parse it into the current module.
800  // Make sure to update to the next token after we finish parsing the nested
801  // file.
802  curToken = lexer.lexToken();
803  LogicalResult result = parseModuleBody(decls);
804  curToken = lexer.lexToken();
805  return result;
806  }
807 
808  // Otherwise, this must be a `.td` include.
809  if (filename.ends_with(".td"))
810  return parseTdInclude(filename, fileLoc, decls);
811 
812  return emitError(fileLoc,
813  "expected include filename to end with `.pdll` or `.td`");
814 }
815 
/// Parse a TableGen (`.td`) include.
///
/// The file is opened through the parser's source manager and parsed with
/// `llvm::TableGenParseFile` in a separate, temporary source manager. The
/// resulting records are converted by `processTdIncludeRecords`, which appends
/// to `decls`. Any TableGen diagnostic is re-emitted as a parser error at
/// `fileLoc`.
LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
  llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();

  // Use the source manager to open the file, but don't yet add it.
  std::string includedFile;
  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
      parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
  if (!includeBuffer)
    return emitError(fileLoc, "unable to open include file `" + filename + "`");

  // Setup the source manager for parsing the tablegen file. It reuses the
  // parser's include directories so nested TableGen includes resolve the same
  // way.
  llvm::SourceMgr tdSrcMgr;
  tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
  tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());

  // This class provides a context argument for the llvm::SourceMgr diagnostic
  // handler: the parser to report through, plus the include's name and
  // location.
  struct DiagHandlerContext {
    Parser &parser;
    StringRef filename;
    llvm::SMRange loc;
  } handlerContext{*this, filename, fileLoc};

  // Set the diagnostic handler for the tablegen source manager. Every TableGen
  // diagnostic is reported as an error at the `#include` location, prefixed
  // with the included file's name.
  tdSrcMgr.setDiagHandler(
      [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
        auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
        (void)ctx->parser.emitError(
            ctx->loc,
            llvm::formatv("error while processing include file `{0}`: {1}",
                          ctx->filename, diag.getMessage()));
      },
      &handlerContext);

  // Parse the tablegen file. A true result signals failure. The errors are
  // expected to have been reported via the handler above.
  llvm::RecordKeeper tdRecords;
  if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
    return failure();

  // Process the parsed records.
  processTdIncludeRecords(tdRecords, decls);

  // After we are done processing, move all of the tablegen source buffers to
  // the main parser source mgr. This allows for directly using source locations
  // from the .td files without needing to remap them.
  parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
  return success();
}
865 
/// Convert the records of a parsed TableGen file into ODS information and PDLL
/// constraint decls.
///
/// - `Op` records register an operation in the ODS context, together with its
///   attributes, operands and results. Operations already in the context are
///   skipped.
/// - `Attr`, `Type` and `OpInterface` records produce native PDLL constraint
///   decls, which are appended to `decls`. A record is skipped if it is
///   anonymous, if its name is already declared in the current scope, or if it
///   derives from `DeclareInterfaceMethods`.
void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
  // Map an ODS operand/result to its ODS variable-length kind.
  auto getLengthKind = [](const auto &value) {
    if (value.isOptional())
    return value.isVariadic() ? ods::VariableLengthKind::Variadic
  };

  // Insert a type constraint into the ODS context.
  ods::Context &odsContext = ctx.getODSContext();
  auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
      -> const ods::TypeConstraint & {
    return odsContext.insertTypeConstraint(
        cst.constraint.getUniqueDefName(),
        processDoc(cst.constraint.getSummary()),
        cst.constraint.getCPPClassName());
  };
  // Build a one-character source range starting at `loc`.
  auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
    return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
  };

  // Process the parsed tablegen records to build ODS information.
  // Operations.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
    tblgen::Operator op(def);

    // Check to see if this operation is known to support type inference, i.e.
    // whether it carries the InferTypeOpInterface trait.
    bool supportsResultTypeInferrence =
        op.getTrait("::mlir::InferTypeOpInterface::Trait");

    auto [odsOp, inserted] = odsContext.insertOperation(
        op.getOperationName(), processDoc(op.getSummary()),
        processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
        supportsResultTypeInferrence, op.getLoc().front());

    // Ignore operations that have already been added.
    if (!inserted)
      continue;

    for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
      odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
                             odsContext.insertAttributeConstraint(
                                 attr.attr.getUniqueDefName(),
                                 processDoc(attr.attr.getSummary()),
                                 attr.attr.getStorageType()));
    }
    for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
      odsOp->appendOperand(operand.name, getLengthKind(operand),
                           addTypeConstraint(operand));
    }
    for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
      odsOp->appendResult(result.name, getLengthKind(result),
                          addTypeConstraint(result));
    }
  }

  // Records that should not produce a constraint decl: anonymous defs, names
  // that are already declared in the current scope, and interface-method
  // declaration helpers.
  auto shouldBeSkipped = [this](llvm::Record *def) {
    return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
           def->isSubClassOf("DeclareInterfaceMethods");
  };

  // Attr constraints.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
    if (shouldBeSkipped(def))
      continue;

    tblgen::Attribute constraint(def);
    decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
        constraint, convertLocToRange(def->getLoc().front()), attrTy,
        constraint.getStorageType()));
  }
  // Type constraints.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
    if (shouldBeSkipped(def))
      continue;

    tblgen::TypeConstraint constraint(def);
    decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
        constraint, convertLocToRange(def->getLoc().front()), typeTy,
        constraint.getCPPClassName()));
  }
  // OpInterfaces.
  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("OpInterface")) {
    if (shouldBeSkipped(def))
      continue;

    SMRange loc = convertLocToRange(def->getLoc().front());

    // The constraint checks `isa<Namespace::InterfaceName>(self)`.
    std::string cppClassName =
        llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
                      def->getValueAsString("cppInterfaceName"))
            .str();
    std::string codeBlock =
        llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
                      cppClassName)
            .str();

    std::string desc =
        processAndFormatDoc(def->getValueAsString("description"));
    decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
        def->getName(), codeBlock, loc, opTy, cppClassName, desc));
  }
}
972 
973 template <typename ConstraintT>
974 ast::Decl *Parser::createODSNativePDLLConstraintDecl(
975  StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
976  StringRef nativeType, StringRef docString) {
977  // Build the single input parameter.
978  ast::DeclScope *argScope = pushDeclScope();
979  auto *paramVar = ast::VariableDecl::create(
980  ctx, ast::Name::create(ctx, "self", loc), type,
981  /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
982  argScope->add(paramVar);
983  popDeclScope();
984 
985  // Build the native constraint.
986  auto *constraintDecl = ast::UserConstraintDecl::createNative(
987  ctx, ast::Name::create(ctx, name, loc), paramVar,
988  /*results=*/std::nullopt, codeBlock, ast::TupleType::get(ctx),
989  nativeType);
990  constraintDecl->setDocComment(ctx, docString);
991  curDeclScope->add(constraintDecl);
992  return constraintDecl;
993 }
994 
995 template <typename ConstraintT>
996 ast::Decl *
997 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
998  SMRange loc, ast::Type type,
999  StringRef nativeType) {
1000  // Format the condition template.
1001  tblgen::FmtContext fmtContext;
1002  fmtContext.withSelf("self");
1003  std::string codeBlock = tblgen::tgfmt(
1004  "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
1005  &fmtContext);
1006 
1007  // If documentation was enabled, build the doc string for the generated
1008  // constraint. It would be nice to do this lazily, but TableGen information is
1009  // destroyed after we finish parsing the file.
1010  std::string docString;
1011  if (enableDocumentation) {
1012  StringRef desc = constraint.getDescription();
1013  docString = processAndFormatDoc(
1014  constraint.getSummary() +
1015  (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
1016  }
1017 
1018  return createODSNativePDLLConstraintDecl<ConstraintT>(
1019  constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
1020  docString);
1021 }
1022 
1023 //===----------------------------------------------------------------------===//
1024 // Decls
1025 
1026 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
1027  FailureOr<ast::Decl *> decl;
1028  switch (curToken.getKind()) {
1029  case Token::kw_Constraint:
1030  decl = parseUserConstraintDecl();
1031  break;
1032  case Token::kw_Pattern:
1033  decl = parsePatternDecl();
1034  break;
1035  case Token::kw_Rewrite:
1036  decl = parseUserRewriteDecl();
1037  break;
1038  default:
1039  return emitError("expected top-level declaration, such as a `Pattern`");
1040  }
1041  if (failed(decl))
1042  return failure();
1043 
1044  // If the decl has a name, add it to the current scope.
1045  if (const ast::Name *name = (*decl)->getName()) {
1046  if (failed(checkDefineNamedDecl(*name)))
1047  return failure();
1048  curDeclScope->add(*decl);
1049  }
1050  return decl;
1051 }
1052 
1053 FailureOr<ast::NamedAttributeDecl *>
1054 Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1055  // Check for name code completion.
1056  if (curToken.is(Token::code_complete))
1057  return codeCompleteAttributeName(parentOpName);
1058 
1059  std::string attrNameStr;
1060  if (curToken.isString())
1061  attrNameStr = curToken.getStringValue();
1062  else if (curToken.is(Token::identifier) || curToken.isKeyword())
1063  attrNameStr = curToken.getSpelling().str();
1064  else
1065  return emitError("expected identifier or string attribute name");
1066  const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
1067  consumeToken();
1068 
1069  // Check for a value of the attribute.
1070  ast::Expr *attrValue = nullptr;
1071  if (consumeIf(Token::equal)) {
1072  FailureOr<ast::Expr *> attrExpr = parseExpr();
1073  if (failed(attrExpr))
1074  return failure();
1075  attrValue = *attrExpr;
1076  } else {
1077  // If there isn't a concrete value, create an expression representing a
1078  // UnitAttr.
1079  attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
1080  }
1081 
1082  return ast::NamedAttributeDecl::create(ctx, name, attrValue);
1083 }
1084 
1085 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1086  function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1087  bool expectTerminalSemicolon) {
1088  consumeToken(Token::equal_arrow);
1089 
1090  // Parse the single statement of the lambda body.
1091  SMLoc bodyStartLoc = curToken.getStartLoc();
1092  pushDeclScope();
1093  FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1094  bool failedToParse =
1095  failed(singleStatement) || failed(processStatementFn(*singleStatement));
1096  popDeclScope();
1097  if (failedToParse)
1098  return failure();
1099 
1100  SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1101  return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
1102 }
1103 
1104 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1105  // Ensure that the argument is named.
1106  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
1107  return emitError("expected identifier argument name");
1108 
1109  // Parse the argument similarly to a normal variable.
1110  StringRef name = curToken.getSpelling();
1111  SMRange nameLoc = curToken.getLoc();
1112  consumeToken();
1113 
1114  if (failed(
1115  parseToken(Token::colon, "expected `:` before argument constraint")))
1116  return failure();
1117 
1118  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1119  if (failed(cst))
1120  return failure();
1121 
1122  return createArgOrResultVariableDecl(name, nameLoc, *cst);
1123 }
1124 
1125 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1126  // Check to see if this result is named.
1127  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1128  // Check to see if this name actually refers to a Constraint.
1129  if (!curDeclScope->lookup<ast::ConstraintDecl>(curToken.getSpelling())) {
1130  // If it wasn't a constraint, parse the result similarly to a variable. If
1131  // there is already an existing decl, we will emit an error when defining
1132  // this variable later.
1133  StringRef name = curToken.getSpelling();
1134  SMRange nameLoc = curToken.getLoc();
1135  consumeToken();
1136 
1137  if (failed(parseToken(Token::colon,
1138  "expected `:` before result constraint")))
1139  return failure();
1140 
1141  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1142  if (failed(cst))
1143  return failure();
1144 
1145  return createArgOrResultVariableDecl(name, nameLoc, *cst);
1146  }
1147  }
1148 
1149  // If it isn't named, we parse the constraint directly and create an unnamed
1150  // result variable.
1151  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1152  if (failed(cst))
1153  return failure();
1154 
1155  return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1156 }
1157 
1158 FailureOr<ast::UserConstraintDecl *>
1159 Parser::parseUserConstraintDecl(bool isInline) {
1160  // Constraints and rewrites have very similar formats, dispatch to a shared
1161  // interface for parsing.
1162  return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1163  [&](auto &&...args) {
1164  return this->parseUserPDLLConstraintDecl(args...);
1165  },
1166  ParserContext::Constraint, "constraint", isInline);
1167 }
1168 
1169 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1170  FailureOr<ast::UserConstraintDecl *> decl =
1171  parseUserConstraintDecl(/*isInline=*/true);
1172  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1173  return failure();
1174 
1175  curDeclScope->add(*decl);
1176  return decl;
1177 }
1178 
1179 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1180  const ast::Name &name, bool isInline,
1181  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1182  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1183  // Push the argument scope back onto the list, so that the body can
1184  // reference arguments.
1185  pushDeclScope(argumentScope);
1186 
1187  // Parse the body of the constraint. The body is either defined as a compound
1188  // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1189  ast::CompoundStmt *body;
1190  if (curToken.is(Token::equal_arrow)) {
1191  FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1192  [&](ast::Stmt *&stmt) -> LogicalResult {
1193  ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1194  if (!stmtExpr) {
1195  return emitError(stmt->getLoc(),
1196  "expected `Constraint` lambda body to contain a "
1197  "single expression");
1198  }
1199  stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1200  return success();
1201  },
1202  /*expectTerminalSemicolon=*/!isInline);
1203  if (failed(bodyResult))
1204  return failure();
1205  body = *bodyResult;
1206  } else {
1207  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1208  if (failed(bodyResult))
1209  return failure();
1210  body = *bodyResult;
1211 
1212  // Verify the structure of the body.
1213  auto bodyIt = body->begin(), bodyE = body->end();
1214  for (; bodyIt != bodyE; ++bodyIt)
1215  if (isa<ast::ReturnStmt>(*bodyIt))
1216  break;
1217  if (failed(validateUserConstraintOrRewriteReturn(
1218  "Constraint", body, bodyIt, bodyE, results, resultType)))
1219  return failure();
1220  }
1221  popDeclScope();
1222 
1223  return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1224  name, arguments, results, resultType, body);
1225 }
1226 
1227 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1228  // Constraints and rewrites have very similar formats, dispatch to a shared
1229  // interface for parsing.
1230  return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1231  [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1232  ParserContext::Rewrite, "rewrite", isInline);
1233 }
1234 
1235 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1236  FailureOr<ast::UserRewriteDecl *> decl =
1237  parseUserRewriteDecl(/*isInline=*/true);
1238  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1239  return failure();
1240 
1241  curDeclScope->add(*decl);
1242  return decl;
1243 }
1244 
1245 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1246  const ast::Name &name, bool isInline,
1247  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1248  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1249  // Push the argument scope back onto the list, so that the body can
1250  // reference arguments.
1251  curDeclScope = argumentScope;
1252  ast::CompoundStmt *body;
1253  if (curToken.is(Token::equal_arrow)) {
1254  FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1255  [&](ast::Stmt *&statement) -> LogicalResult {
1256  if (isa<ast::OpRewriteStmt>(statement))
1257  return success();
1258 
1259  ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1260  if (!statementExpr) {
1261  return emitError(
1262  statement->getLoc(),
1263  "expected `Rewrite` lambda body to contain a single expression "
1264  "or an operation rewrite statement; such as `erase`, "
1265  "`replace`, or `rewrite`");
1266  }
1267  statement =
1268  ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1269  return success();
1270  },
1271  /*expectTerminalSemicolon=*/!isInline);
1272  if (failed(bodyResult))
1273  return failure();
1274  body = *bodyResult;
1275  } else {
1276  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1277  if (failed(bodyResult))
1278  return failure();
1279  body = *bodyResult;
1280  }
1281  popDeclScope();
1282 
1283  // Verify the structure of the body.
1284  auto bodyIt = body->begin(), bodyE = body->end();
1285  for (; bodyIt != bodyE; ++bodyIt)
1286  if (isa<ast::ReturnStmt>(*bodyIt))
1287  break;
1288  if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1289  bodyE, results, resultType)))
1290  return failure();
1291  return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1292  name, arguments, results, resultType, body);
1293 }
1294 
1295 template <typename T, typename ParseUserPDLLDeclFnT>
1296 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1297  ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1298  StringRef anonymousNamePrefix, bool isInline) {
1299  SMRange loc = curToken.getLoc();
1300  consumeToken();
1301  llvm::SaveAndRestore saveCtx(parserContext, declContext);
1302 
1303  // Parse the name of the decl.
1304  const ast::Name *name = nullptr;
1305  if (curToken.isNot(Token::identifier)) {
1306  // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1307  // in C++, so being unnamed is fine.
1308  if (!isInline)
1309  return emitError("expected identifier name");
1310 
1311  // Create a unique anonymous name to use, as the name for this decl is not
1312  // important.
1313  std::string anonName =
1314  llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1315  anonymousDeclNameCounter++)
1316  .str();
1317  name = &ast::Name::create(ctx, anonName, loc);
1318  } else {
1319  // If a name was provided, we can use it directly.
1320  name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1321  consumeToken(Token::identifier);
1322  }
1323 
1324  // Parse the functional signature of the decl.
1325  SmallVector<ast::VariableDecl *> arguments, results;
1326  ast::DeclScope *argumentScope;
1327  ast::Type resultType;
1328  if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1329  argumentScope, resultType)))
1330  return failure();
1331 
1332  // Check to see which type of constraint this is. If the constraint contains a
1333  // compound body, this is a PDLL decl.
1334  if (curToken.isAny(Token::l_brace, Token::equal_arrow))
1335  return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1336  resultType);
1337 
1338  // Otherwise, this is a native decl.
1339  return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1340  results, resultType);
1341 }
1342 
/// Parse the remainder of a native constraint/rewrite declaration after its
/// signature.
///
/// An optional string literal holds the native code body, and a terminating
/// `;` must follow. Inline declarations must supply the code string. The decl
/// is built via `T::createNative`.
template <typename T>
FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
    const ast::Name &name, bool isInline,
    ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
  // If followed by a string, the native code body has also been specified.
  // The string value is kept alive in `codeStrStorage` while `optCodeStr`
  // refers to it.
  std::string codeStrStorage;
  std::optional<StringRef> optCodeStr;
  if (curToken.isString()) {
    codeStrStorage = curToken.getStringValue();
    optCodeStr = codeStrStorage;
    consumeToken();
  } else if (isInline) {
    // Without a code string, the decl would be an external declaration, which
    // is rejected for inline decls.
    return emitError(name.getLoc(),
                     "external declarations must be declared in global scope");
  } else if (curToken.is(Token::error)) {
    // An error token is presumably already diagnosed by the lexer. Bail out
    // without adding a second error.
    return failure();
  }
  if (failed(parseToken(Token::semicolon,
                        "expected `;` after native declaration")))
    return failure();
  return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
}
1366 
/// Parse the signature of a constraint/rewrite: `(args...)`, optionally
/// followed by `-> result` or `-> (results...)`.
///
/// Arguments are appended to `arguments` and results to `results`.
/// `argumentScope` receives the scope the arguments were declared in, so that
/// the body parser can re-enter it. `resultType` is computed from the results.
/// A single result may not be given a name.
LogicalResult Parser::parseUserConstraintOrRewriteSignature(
    ast::DeclScope *&argumentScope, ast::Type &resultType) {
  // Parse the argument list of the decl.
  if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
    return failure();

  // Arguments are declared in a fresh scope. The scope is popped afterwards
  // but stays reachable through `argumentScope`.
  argumentScope = pushDeclScope();
  if (curToken.isNot(Token::r_paren)) {
    do {
      FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
      if (failed(argument))
        return failure();
      arguments.emplace_back(*argument);
    } while (consumeIf(Token::comma));
  }
  popDeclScope();
  if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
    return failure();

  // Parse the results of the decl. They live in their own temporary scope,
  // which is discarded afterwards.
  pushDeclScope();
  if (consumeIf(Token::arrow)) {
    auto parseResultFn = [&]() -> LogicalResult {
      FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
      if (failed(result))
        return failure();
      results.emplace_back(*result);
      return success();
    };

    // Check for a list of results.
    if (consumeIf(Token::l_paren)) {
      do {
        if (failed(parseResultFn()))
          return failure();
      } while (consumeIf(Token::comma));
      if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
        return failure();

      // Otherwise, there is only one result.
    } else if (failed(parseResultFn())) {
      return failure();
    }
  }
  popDeclScope();

  // Compute the result type of the decl.
  resultType = createUserConstraintRewriteResultType(results);

  // Verify that results are only named if there are more than one.
  if (results.size() == 1 && !results.front()->getName().getName().empty()) {
    return emitError(
        results.front()->getLoc(),
        "cannot create a single-element tuple with an element label");
  }
  return success();
}
1426 
/// Validate the placement of `return` in a constraint/rewrite body.
///
/// `bodyIt` points at the first `return` statement of `body`, or equals
/// `bodyE` if there is none. A `return` must be the last statement. If there is
/// no `return`, the decl must not declare any results. `declType` ("Constraint"
/// or "Rewrite") is only used in diagnostics.
LogicalResult Parser::validateUserConstraintOrRewriteReturn(
    StringRef declType, ast::CompoundStmt *body,
    ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
  // Handle if a `return` was provided.
  if (bodyIt != bodyE) {
    // Emit an error if we have trailing statements after the return.
    if (std::next(bodyIt) != bodyE) {
      return emitError(
          (*std::next(bodyIt))->getLoc(),
          llvm::formatv("`return` terminated the `{0}` body, but found "
                        "trailing statements afterwards",
                        declType));
    }

    // Otherwise if a return wasn't provided, check that no results are
    // expected. The error is reported at the end of the body.
  } else if (!results.empty()) {
    return emitError(
        {body->getLoc().End, body->getLoc().End},
        llvm::formatv("missing return in a `{0}` expected to return `{1}`",
                      declType, resultType));
  }
  return success();
}
1453 
1454 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1455  return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1456  if (isa<ast::OpRewriteStmt>(statement))
1457  return success();
1458  return emitError(
1459  statement->getLoc(),
1460  "expected Pattern lambda body to contain a single operation "
1461  "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1462  });
1463 }
1464 
1465 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1466  SMRange loc = curToken.getLoc();
1467  consumeToken(Token::kw_Pattern);
1468  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1469 
1470  // Check for an optional identifier for the pattern name.
1471  const ast::Name *name = nullptr;
1472  if (curToken.is(Token::identifier)) {
1473  name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1474  consumeToken(Token::identifier);
1475  }
1476 
1477  // Parse any pattern metadata.
1478  ParsedPatternMetadata metadata;
1479  if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1480  return failure();
1481 
1482  // Parse the pattern body.
1483  ast::CompoundStmt *body;
1484 
1485  // Handle a lambda body.
1486  if (curToken.is(Token::equal_arrow)) {
1487  FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1488  if (failed(bodyResult))
1489  return failure();
1490  body = *bodyResult;
1491  } else {
1492  if (curToken.isNot(Token::l_brace))
1493  return emitError("expected `{` or `=>` to start pattern body");
1494  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1495  if (failed(bodyResult))
1496  return failure();
1497  body = *bodyResult;
1498 
1499  // Verify the body of the pattern.
1500  auto bodyIt = body->begin(), bodyE = body->end();
1501  for (; bodyIt != bodyE; ++bodyIt) {
1502  if (isa<ast::ReturnStmt>(*bodyIt)) {
1503  return emitError((*bodyIt)->getLoc(),
1504  "`return` statements are only permitted within a "
1505  "`Constraint` or `Rewrite` body");
1506  }
1507  // Break when we've found the rewrite statement.
1508  if (isa<ast::OpRewriteStmt>(*bodyIt))
1509  break;
1510  }
1511  if (bodyIt == bodyE) {
1512  return emitError(loc,
1513  "expected Pattern body to terminate with an operation "
1514  "rewrite statement, such as `erase`");
1515  }
1516  if (std::next(bodyIt) != bodyE) {
1517  return emitError((*std::next(bodyIt))->getLoc(),
1518  "Pattern body was terminated by an operation "
1519  "rewrite statement, but found trailing statements");
1520  }
1521  }
1522 
1523  return createPatternDecl(loc, name, metadata, body);
1524 }
1525 
1526 LogicalResult
1527 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1528  std::optional<SMRange> benefitLoc;
1529  std::optional<SMRange> hasBoundedRecursionLoc;
1530 
1531  do {
1532  // Handle metadata code completion.
1533  if (curToken.is(Token::code_complete))
1534  return codeCompletePatternMetadata();
1535 
1536  if (curToken.isNot(Token::identifier))
1537  return emitError("expected pattern metadata identifier");
1538  StringRef metadataStr = curToken.getSpelling();
1539  SMRange metadataLoc = curToken.getLoc();
1540  consumeToken(Token::identifier);
1541 
1542  // Parse the benefit metadata: benefit(<integer-value>)
1543  if (metadataStr == "benefit") {
1544  if (benefitLoc) {
1545  return emitErrorAndNote(metadataLoc,
1546  "pattern benefit has already been specified",
1547  *benefitLoc, "see previous definition here");
1548  }
1549  if (failed(parseToken(Token::l_paren,
1550  "expected `(` before pattern benefit")))
1551  return failure();
1552 
1553  uint16_t benefitValue = 0;
1554  if (curToken.isNot(Token::integer))
1555  return emitError("expected integral pattern benefit");
1556  if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1557  return emitError(
1558  "expected pattern benefit to fit within a 16-bit integer");
1559  consumeToken(Token::integer);
1560 
1561  metadata.benefit = benefitValue;
1562  benefitLoc = metadataLoc;
1563 
1564  if (failed(
1565  parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1566  return failure();
1567  continue;
1568  }
1569 
1570  // Parse the bounded recursion metadata: recursion
1571  if (metadataStr == "recursion") {
1572  if (hasBoundedRecursionLoc) {
1573  return emitErrorAndNote(
1574  metadataLoc,
1575  "pattern recursion metadata has already been specified",
1576  *hasBoundedRecursionLoc, "see previous definition here");
1577  }
1578  metadata.hasBoundedRecursion = true;
1579  hasBoundedRecursionLoc = metadataLoc;
1580  continue;
1581  }
1582 
1583  return emitError(metadataLoc, "unknown pattern metadata");
1584  } while (consumeIf(Token::comma));
1585 
1586  return success();
1587 }
1588 
1589 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1590  consumeToken(Token::less);
1591 
1592  FailureOr<ast::Expr *> typeExpr = parseExpr();
1593  if (failed(typeExpr) ||
1594  failed(parseToken(Token::greater,
1595  "expected `>` after variable type constraint")))
1596  return failure();
1597  return typeExpr;
1598 }
1599 
1600 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1601  assert(curDeclScope && "defining decl outside of a decl scope");
1602  if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1603  return emitErrorAndNote(
1604  name.getLoc(), "`" + name.getName() + "` has already been defined",
1605  lastDecl->getName()->getLoc(), "see previous definition here");
1606  }
1607  return success();
1608 }
1609 
1610 FailureOr<ast::VariableDecl *>
1611 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1612  ast::Expr *initExpr,
1613  ArrayRef<ast::ConstraintRef> constraints) {
1614  assert(curDeclScope && "defining variable outside of decl scope");
1615  const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1616 
1617  // If the name of the variable indicates a special variable, we don't add it
1618  // to the scope. This variable is local to the definition point.
1619  if (name.empty() || name == "_") {
1620  return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1621  constraints);
1622  }
1623  if (failed(checkDefineNamedDecl(nameDecl)))
1624  return failure();
1625 
1626  auto *varDecl =
1627  ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1628  curDeclScope->add(varDecl);
1629  return varDecl;
1630 }
1631 
1632 FailureOr<ast::VariableDecl *>
1633 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1634  ArrayRef<ast::ConstraintRef> constraints) {
1635  return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1636  constraints);
1637 }
1638 
1639 LogicalResult Parser::parseVariableDeclConstraintList(
1640  SmallVectorImpl<ast::ConstraintRef> &constraints) {
1641  std::optional<SMRange> typeConstraint;
1642  auto parseSingleConstraint = [&] {
1643  FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1644  typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
1645  if (failed(constraint))
1646  return failure();
1647  constraints.push_back(*constraint);
1648  return success();
1649  };
1650 
1651  // Check to see if this is a single constraint, or a list.
1652  if (!consumeIf(Token::l_square))
1653  return parseSingleConstraint();
1654 
1655  do {
1656  if (failed(parseSingleConstraint()))
1657  return failure();
1658  } while (consumeIf(Token::comma));
1659  return parseToken(Token::r_square, "expected `]` after constraint list");
1660 }
1661 
1662 FailureOr<ast::ConstraintRef>
1663 Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
1664  ArrayRef<ast::ConstraintRef> existingConstraints,
1665  bool allowInlineTypeConstraints) {
1666  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1667  if (!allowInlineTypeConstraints) {
1668  return emitError(
1669  curToken.getLoc(),
1670  "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1671  "permitted on arguments or results");
1672  }
1673  if (typeConstraint)
1674  return emitErrorAndNote(
1675  curToken.getLoc(),
1676  "the type of this variable has already been constrained",
1677  *typeConstraint, "see previous constraint location here");
1678  FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1679  if (failed(constraintExpr))
1680  return failure();
1681  typeExpr = *constraintExpr;
1682  typeConstraint = typeExpr->getLoc();
1683  return success();
1684  };
1685 
1686  SMRange loc = curToken.getLoc();
1687  switch (curToken.getKind()) {
1688  case Token::kw_Attr: {
1689  consumeToken(Token::kw_Attr);
1690 
1691  // Check for a type constraint.
1692  ast::Expr *typeExpr = nullptr;
1693  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1694  return failure();
1695  return ast::ConstraintRef(
1696  ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1697  }
1698  case Token::kw_Op: {
1699  consumeToken(Token::kw_Op);
1700 
1701  // Parse an optional operation name. If the name isn't provided, this refers
1702  // to "any" operation.
1703  FailureOr<ast::OpNameDecl *> opName =
1704  parseWrappedOperationName(/*allowEmptyName=*/true);
1705  if (failed(opName))
1706  return failure();
1707 
1708  return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
1709  loc);
1710  }
1711  case Token::kw_Type:
1712  consumeToken(Token::kw_Type);
1713  return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1714  case Token::kw_TypeRange:
1715  consumeToken(Token::kw_TypeRange);
1717  loc);
1718  case Token::kw_Value: {
1719  consumeToken(Token::kw_Value);
1720 
1721  // Check for a type constraint.
1722  ast::Expr *typeExpr = nullptr;
1723  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1724  return failure();
1725 
1726  return ast::ConstraintRef(
1727  ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1728  }
1729  case Token::kw_ValueRange: {
1730  consumeToken(Token::kw_ValueRange);
1731 
1732  // Check for a type constraint.
1733  ast::Expr *typeExpr = nullptr;
1734  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1735  return failure();
1736 
1737  return ast::ConstraintRef(
1738  ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1739  }
1740 
1741  case Token::kw_Constraint: {
1742  // Handle an inline constraint.
1743  FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1744  if (failed(decl))
1745  return failure();
1746  return ast::ConstraintRef(*decl, loc);
1747  }
1748  case Token::identifier: {
1749  StringRef constraintName = curToken.getSpelling();
1750  consumeToken(Token::identifier);
1751 
1752  // Lookup the referenced constraint.
1753  ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
1754  if (!cstDecl) {
1755  return emitError(loc, "unknown reference to constraint `" +
1756  constraintName + "`");
1757  }
1758 
1759  // Handle a reference to a proper constraint.
1760  if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1761  return ast::ConstraintRef(cst, loc);
1762 
1763  return emitErrorAndNote(
1764  loc, "invalid reference to non-constraint", cstDecl->getLoc(),
1765  "see the definition of `" + constraintName + "` here");
1766  }
1767  // Handle single entity constraint code completion.
1768  case Token::code_complete: {
1769  // Try to infer the current type for use by code completion.
1770  ast::Type inferredType;
1771  if (failed(validateVariableConstraints(existingConstraints, inferredType)))
1772  return failure();
1773 
1774  return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
1775  }
1776  default:
1777  break;
1778  }
1779  return emitError(loc, "expected identifier constraint");
1780 }
1781 
1782 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1783  std::optional<SMRange> typeConstraint;
1784  return parseConstraint(typeConstraint, /*existingConstraints=*/std::nullopt,
1785  /*allowInlineTypeConstraints=*/false);
1786 }
1787 
1788 //===----------------------------------------------------------------------===//
1789 // Exprs
1790 
1791 FailureOr<ast::Expr *> Parser::parseExpr() {
1792  if (curToken.is(Token::underscore))
1793  return parseUnderscoreExpr();
1794 
1795  // Parse the LHS expression.
1796  FailureOr<ast::Expr *> lhsExpr;
1797  switch (curToken.getKind()) {
1798  case Token::kw_attr:
1799  lhsExpr = parseAttributeExpr();
1800  break;
1801  case Token::kw_Constraint:
1802  lhsExpr = parseInlineConstraintLambdaExpr();
1803  break;
1804  case Token::kw_not:
1805  lhsExpr = parseNegatedExpr();
1806  break;
1807  case Token::identifier:
1808  lhsExpr = parseIdentifierExpr();
1809  break;
1810  case Token::kw_op:
1811  lhsExpr = parseOperationExpr();
1812  break;
1813  case Token::kw_Rewrite:
1814  lhsExpr = parseInlineRewriteLambdaExpr();
1815  break;
1816  case Token::kw_type:
1817  lhsExpr = parseTypeExpr();
1818  break;
1819  case Token::l_paren:
1820  lhsExpr = parseTupleExpr();
1821  break;
1822  default:
1823  return emitError("expected expression");
1824  }
1825  if (failed(lhsExpr))
1826  return failure();
1827 
1828  // Check for an operator expression.
1829  while (true) {
1830  switch (curToken.getKind()) {
1831  case Token::dot:
1832  lhsExpr = parseMemberAccessExpr(*lhsExpr);
1833  break;
1834  case Token::l_paren:
1835  lhsExpr = parseCallExpr(*lhsExpr);
1836  break;
1837  default:
1838  return lhsExpr;
1839  }
1840  if (failed(lhsExpr))
1841  return failure();
1842  }
1843 }
1844 
1845 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1846  SMRange loc = curToken.getLoc();
1847  consumeToken(Token::kw_attr);
1848 
1849  // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1850  // identifier.
1851  if (!consumeIf(Token::less)) {
1852  resetToken(loc);
1853  return parseIdentifierExpr();
1854  }
1855 
1856  if (!curToken.isString())
1857  return emitError("expected string literal containing MLIR attribute");
1858  std::string attrExpr = curToken.getStringValue();
1859  consumeToken();
1860 
1861  loc.End = curToken.getEndLoc();
1862  if (failed(
1863  parseToken(Token::greater, "expected `>` after attribute literal")))
1864  return failure();
1865  return ast::AttributeExpr::create(ctx, loc, attrExpr);
1866 }
1867 
1868 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr,
1869  bool isNegated) {
1870  consumeToken(Token::l_paren);
1871 
1872  // Parse the arguments of the call.
1873  SmallVector<ast::Expr *> arguments;
1874  if (curToken.isNot(Token::r_paren)) {
1875  do {
1876  // Handle code completion for the call arguments.
1877  if (curToken.is(Token::code_complete)) {
1878  codeCompleteCallSignature(parentExpr, arguments.size());
1879  return failure();
1880  }
1881 
1882  FailureOr<ast::Expr *> argument = parseExpr();
1883  if (failed(argument))
1884  return failure();
1885  arguments.push_back(*argument);
1886  } while (consumeIf(Token::comma));
1887  }
1888 
1889  SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1890  if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1891  return failure();
1892 
1893  return createCallExpr(loc, parentExpr, arguments, isNegated);
1894 }
1895 
1896 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1897  ast::Decl *decl = curDeclScope->lookup(name);
1898  if (!decl)
1899  return emitError(loc, "undefined reference to `" + name + "`");
1900 
1901  return createDeclRefExpr(loc, decl);
1902 }
1903 
1904 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1905  StringRef name = curToken.getSpelling();
1906  SMRange nameLoc = curToken.getLoc();
1907  consumeToken();
1908 
1909  // Check to see if this is a decl ref expression that defines a variable
1910  // inline.
1911  if (consumeIf(Token::colon)) {
1912  SmallVector<ast::ConstraintRef> constraints;
1913  if (failed(parseVariableDeclConstraintList(constraints)))
1914  return failure();
1915  ast::Type type;
1916  if (failed(validateVariableConstraints(constraints, type)))
1917  return failure();
1918  return createInlineVariableExpr(type, name, nameLoc, constraints);
1919  }
1920 
1921  return parseDeclRefExpr(name, nameLoc);
1922 }
1923 
1924 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1925  FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1926  if (failed(decl))
1927  return failure();
1928 
1929  return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1931 }
1932 
1933 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1934  FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1935  if (failed(decl))
1936  return failure();
1937 
1938  return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1939  ast::RewriteType::get(ctx));
1940 }
1941 
1942 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1943  SMRange dotLoc = curToken.getLoc();
1944  consumeToken(Token::dot);
1945 
1946  // Check for code completion of the member name.
1947  if (curToken.is(Token::code_complete))
1948  return codeCompleteMemberAccess(parentExpr);
1949 
1950  // Parse the member name.
1951  Token memberNameTok = curToken;
1952  if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1953  !memberNameTok.isKeyword())
1954  return emitError(dotLoc, "expected identifier or numeric member name");
1955  StringRef memberName = memberNameTok.getSpelling();
1956  SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1957  consumeToken();
1958 
1959  return createMemberAccessExpr(parentExpr, memberName, loc);
1960 }
1961 
1962 FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
1963  consumeToken(Token::kw_not);
1964  // Only native constraints are supported after negation
1965  if (!curToken.is(Token::identifier))
1966  return emitError("expected native constraint");
1967  FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
1968  if (failed(identifierExpr))
1969  return failure();
1970  if (!curToken.is(Token::l_paren))
1971  return emitError("expected `(` after function name");
1972  return parseCallExpr(*identifierExpr, /*isNegated = */ true);
1973 }
1974 
1975 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
1976  SMRange loc = curToken.getLoc();
1977 
1978  // Check for code completion for the dialect name.
1979  if (curToken.is(Token::code_complete))
1980  return codeCompleteDialectName();
1981 
1982  // Handle the case of an no operation name.
1983  if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
1984  if (allowEmptyName)
1985  return ast::OpNameDecl::create(ctx, SMRange());
1986  return emitError("expected dialect namespace");
1987  }
1988  StringRef name = curToken.getSpelling();
1989  consumeToken();
1990 
1991  // Otherwise, this is a literal operation name.
1992  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
1993  return failure();
1994 
1995  // Check for code completion for the operation name.
1996  if (curToken.is(Token::code_complete))
1997  return codeCompleteOperationName(name);
1998 
1999  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
2000  return emitError("expected operation name after dialect namespace");
2001 
2002  name = StringRef(name.data(), name.size() + 1);
2003  do {
2004  name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
2005  loc.End = curToken.getEndLoc();
2006  consumeToken();
2007  } while (curToken.isAny(Token::identifier, Token::dot) ||
2008  curToken.isKeyword());
2009  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
2010 }
2011 
2012 FailureOr<ast::OpNameDecl *>
2013 Parser::parseWrappedOperationName(bool allowEmptyName) {
2014  if (!consumeIf(Token::less))
2015  return ast::OpNameDecl::create(ctx, SMRange());
2016 
2017  FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
2018  if (failed(opNameDecl))
2019  return failure();
2020 
2021  if (failed(parseToken(Token::greater, "expected `>` after operation name")))
2022  return failure();
2023  return opNameDecl;
2024 }
2025 
2026 FailureOr<ast::Expr *>
2027 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
2028  SMRange loc = curToken.getLoc();
2029  consumeToken(Token::kw_op);
2030 
2031  // If it isn't followed by a `<`, the `op` keyword is treated as a normal
2032  // identifier.
2033  if (curToken.isNot(Token::less)) {
2034  resetToken(loc);
2035  return parseIdentifierExpr();
2036  }
2037 
2038  // Parse the operation name. The name may be elided, in which case the
2039  // operation refers to "any" operation(i.e. a difference between `MyOp` and
2040  // `Operation*`). Operation names within a rewrite context must be named.
2041  bool allowEmptyName = parserContext != ParserContext::Rewrite;
2042  FailureOr<ast::OpNameDecl *> opNameDecl =
2043  parseWrappedOperationName(allowEmptyName);
2044  if (failed(opNameDecl))
2045  return failure();
2046  std::optional<StringRef> opName = (*opNameDecl)->getName();
2047 
2048  // Functor used to create an implicit range variable, used for implicit "all"
2049  // operand or results variables.
2050  auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
2051  FailureOr<ast::VariableDecl *> rangeVar =
2052  defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
2053  assert(succeeded(rangeVar) && "expected range variable to be valid");
2054  return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
2055  };
2056 
2057  // Check for the optional list of operands.
2058  SmallVector<ast::Expr *> operands;
2059  if (!consumeIf(Token::l_paren)) {
2060  // If the operand list isn't specified and we are in a match context, define
2061  // an inplace unconstrained operand range corresponding to all of the
2062  // operands of the operation. This avoids treating zero operands the same
2063  // way as "unconstrained operands".
2064  if (parserContext != ParserContext::Rewrite) {
2065  operands.push_back(createImplicitRangeVar(
2066  ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
2067  }
2068  } else if (!consumeIf(Token::r_paren)) {
2069  // If the operand list was specified and non-empty, parse the operands.
2070  do {
2071  // Check for operand signature code completion.
2072  if (curToken.is(Token::code_complete)) {
2073  codeCompleteOperationOperandsSignature(opName, operands.size());
2074  return failure();
2075  }
2076 
2077  FailureOr<ast::Expr *> operand = parseExpr();
2078  if (failed(operand))
2079  return failure();
2080  operands.push_back(*operand);
2081  } while (consumeIf(Token::comma));
2082 
2083  if (failed(parseToken(Token::r_paren,
2084  "expected `)` after operation operand list")))
2085  return failure();
2086  }
2087 
2088  // Check for the optional list of attributes.
2090  if (consumeIf(Token::l_brace)) {
2091  do {
2092  FailureOr<ast::NamedAttributeDecl *> decl =
2093  parseNamedAttributeDecl(opName);
2094  if (failed(decl))
2095  return failure();
2096  attributes.emplace_back(*decl);
2097  } while (consumeIf(Token::comma));
2098 
2099  if (failed(parseToken(Token::r_brace,
2100  "expected `}` after operation attribute list")))
2101  return failure();
2102  }
2103 
2104  // Handle the result types of the operation.
2105  SmallVector<ast::Expr *> resultTypes;
2106  OpResultTypeContext resultTypeContext = inputResultTypeContext;
2107 
2108  // Check for an explicit list of result types.
2109  if (consumeIf(Token::arrow)) {
2110  if (failed(parseToken(Token::l_paren,
2111  "expected `(` before operation result type list")))
2112  return failure();
2113 
2114  // If result types are provided, initially assume that the operation does
2115  // not rely on type inferrence. We don't assert that it isn't, because we
2116  // may be inferring the value of some type/type range variables, but given
2117  // that these variables may be defined in calls we can't always discern when
2118  // this is the case.
2119  resultTypeContext = OpResultTypeContext::Explicit;
2120 
2121  // Handle the case of an empty result list.
2122  if (!consumeIf(Token::r_paren)) {
2123  do {
2124  // Check for result signature code completion.
2125  if (curToken.is(Token::code_complete)) {
2126  codeCompleteOperationResultsSignature(opName, resultTypes.size());
2127  return failure();
2128  }
2129 
2130  FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2131  if (failed(resultTypeExpr))
2132  return failure();
2133  resultTypes.push_back(*resultTypeExpr);
2134  } while (consumeIf(Token::comma));
2135 
2136  if (failed(parseToken(Token::r_paren,
2137  "expected `)` after operation result type list")))
2138  return failure();
2139  }
2140  } else if (parserContext != ParserContext::Rewrite) {
2141  // If the result list isn't specified and we are in a match context, define
2142  // an inplace unconstrained result range corresponding to all of the results
2143  // of the operation. This avoids treating zero results the same way as
2144  // "unconstrained results".
2145  resultTypes.push_back(createImplicitRangeVar(
2146  ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
2147  } else if (resultTypeContext == OpResultTypeContext::Explicit) {
2148  // If the result list isn't specified and we are in a rewrite, try to infer
2149  // them at runtime instead.
2150  resultTypeContext = OpResultTypeContext::Interface;
2151  }
2152 
2153  return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2154  attributes, resultTypes);
2155 }
2156 
2157 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2158  SMRange loc = curToken.getLoc();
2159  consumeToken(Token::l_paren);
2160 
2161  DenseMap<StringRef, SMRange> usedNames;
2162  SmallVector<StringRef> elementNames;
2163  SmallVector<ast::Expr *> elements;
2164  if (curToken.isNot(Token::r_paren)) {
2165  do {
2166  // Check for the optional element name assignment before the value.
2167  StringRef elementName;
2168  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
2169  Token elementNameTok = curToken;
2170  consumeToken();
2171 
2172  // The element name is only present if followed by an `=`.
2173  if (consumeIf(Token::equal)) {
2174  elementName = elementNameTok.getSpelling();
2175 
2176  // Check to see if this name is already used.
2177  auto elementNameIt =
2178  usedNames.try_emplace(elementName, elementNameTok.getLoc());
2179  if (!elementNameIt.second) {
2180  return emitErrorAndNote(
2181  elementNameTok.getLoc(),
2182  llvm::formatv("duplicate tuple element label `{0}`",
2183  elementName),
2184  elementNameIt.first->getSecond(),
2185  "see previous label use here");
2186  }
2187  } else {
2188  // Otherwise, we treat this as part of an expression so reset the
2189  // lexer.
2190  resetToken(elementNameTok.getLoc());
2191  }
2192  }
2193  elementNames.push_back(elementName);
2194 
2195  // Parse the tuple element value.
2196  FailureOr<ast::Expr *> element = parseExpr();
2197  if (failed(element))
2198  return failure();
2199  elements.push_back(*element);
2200  } while (consumeIf(Token::comma));
2201  }
2202  loc.End = curToken.getEndLoc();
2203  if (failed(
2204  parseToken(Token::r_paren, "expected `)` after tuple element list")))
2205  return failure();
2206  return createTupleExpr(loc, elements, elementNames);
2207 }
2208 
2209 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2210  SMRange loc = curToken.getLoc();
2211  consumeToken(Token::kw_type);
2212 
2213  // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2214  // identifier.
2215  if (!consumeIf(Token::less)) {
2216  resetToken(loc);
2217  return parseIdentifierExpr();
2218  }
2219 
2220  if (!curToken.isString())
2221  return emitError("expected string literal containing MLIR type");
2222  std::string attrExpr = curToken.getStringValue();
2223  consumeToken();
2224 
2225  loc.End = curToken.getEndLoc();
2226  if (failed(parseToken(Token::greater, "expected `>` after type literal")))
2227  return failure();
2228  return ast::TypeExpr::create(ctx, loc, attrExpr);
2229 }
2230 
2231 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2232  StringRef name = curToken.getSpelling();
2233  SMRange nameLoc = curToken.getLoc();
2234  consumeToken(Token::underscore);
2235 
2236  // Underscore expressions require a constraint list.
2237  if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2238  return failure();
2239 
2240  // Parse the constraints for the expression.
2241  SmallVector<ast::ConstraintRef> constraints;
2242  if (failed(parseVariableDeclConstraintList(constraints)))
2243  return failure();
2244 
2245  ast::Type type;
2246  if (failed(validateVariableConstraints(constraints, type)))
2247  return failure();
2248  return createInlineVariableExpr(type, name, nameLoc, constraints);
2249 }
2250 
2251 //===----------------------------------------------------------------------===//
2252 // Stmts
2253 
2254 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2255  FailureOr<ast::Stmt *> stmt;
2256  switch (curToken.getKind()) {
2257  case Token::kw_erase:
2258  stmt = parseEraseStmt();
2259  break;
2260  case Token::kw_let:
2261  stmt = parseLetStmt();
2262  break;
2263  case Token::kw_replace:
2264  stmt = parseReplaceStmt();
2265  break;
2266  case Token::kw_return:
2267  stmt = parseReturnStmt();
2268  break;
2269  case Token::kw_rewrite:
2270  stmt = parseRewriteStmt();
2271  break;
2272  default:
2273  stmt = parseExpr();
2274  break;
2275  }
2276  if (failed(stmt) ||
2277  (expectTerminalSemicolon &&
2278  failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2279  return failure();
2280  return stmt;
2281 }
2282 
2283 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2284  SMLoc startLoc = curToken.getStartLoc();
2285  consumeToken(Token::l_brace);
2286 
2287  // Push a new block scope and parse any nested statements.
2288  pushDeclScope();
2289  SmallVector<ast::Stmt *> statements;
2290  while (curToken.isNot(Token::r_brace)) {
2291  FailureOr<ast::Stmt *> statement = parseStmt();
2292  if (failed(statement))
2293  return popDeclScope(), failure();
2294  statements.push_back(*statement);
2295  }
2296  popDeclScope();
2297 
2298  // Consume the end brace.
2299  SMRange location(startLoc, curToken.getEndLoc());
2300  consumeToken(Token::r_brace);
2301 
2302  return ast::CompoundStmt::create(ctx, location, statements);
2303 }
2304 
2305 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2306  if (parserContext == ParserContext::Constraint)
2307  return emitError("`erase` cannot be used within a Constraint");
2308  SMRange loc = curToken.getLoc();
2309  consumeToken(Token::kw_erase);
2310 
2311  // Parse the root operation expression.
2312  FailureOr<ast::Expr *> rootOp = parseExpr();
2313  if (failed(rootOp))
2314  return failure();
2315 
2316  return createEraseStmt(loc, *rootOp);
2317 }
2318 
2319 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2320  SMRange loc = curToken.getLoc();
2321  consumeToken(Token::kw_let);
2322 
2323  // Parse the name of the new variable.
2324  SMRange varLoc = curToken.getLoc();
2325  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
2326  // `_` is a reserved variable name.
2327  if (curToken.is(Token::underscore)) {
2328  return emitError(varLoc,
2329  "`_` may only be used to define \"inline\" variables");
2330  }
2331  return emitError(varLoc,
2332  "expected identifier after `let` to name a new variable");
2333  }
2334  StringRef varName = curToken.getSpelling();
2335  consumeToken();
2336 
2337  // Parse the optional set of constraints.
2338  SmallVector<ast::ConstraintRef> constraints;
2339  if (consumeIf(Token::colon) &&
2340  failed(parseVariableDeclConstraintList(constraints)))
2341  return failure();
2342 
2343  // Parse the optional initializer expression.
2344  ast::Expr *initializer = nullptr;
2345  if (consumeIf(Token::equal)) {
2346  FailureOr<ast::Expr *> initOrFailure = parseExpr();
2347  if (failed(initOrFailure))
2348  return failure();
2349  initializer = *initOrFailure;
2350 
2351  // Check that the constraints are compatible with having an initializer,
2352  // e.g. type constraints cannot be used with initializers.
2353  for (ast::ConstraintRef constraint : constraints) {
2354  LogicalResult result =
2355  TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
2357  ast::ValueRangeConstraintDecl>([&](const auto *cst) {
2358  if (cst->getTypeExpr()) {
2359  return this->emitError(
2360  constraint.referenceLoc,
2361  "type constraints are not permitted on variables with "
2362  "initializers");
2363  }
2364  return success();
2365  })
2366  .Default(success());
2367  if (failed(result))
2368  return failure();
2369  }
2370  }
2371 
2372  FailureOr<ast::VariableDecl *> varDecl =
2373  createVariableDecl(varName, varLoc, initializer, constraints);
2374  if (failed(varDecl))
2375  return failure();
2376  return ast::LetStmt::create(ctx, loc, *varDecl);
2377 }
2378 
2379 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2380  if (parserContext == ParserContext::Constraint)
2381  return emitError("`replace` cannot be used within a Constraint");
2382  SMRange loc = curToken.getLoc();
2383  consumeToken(Token::kw_replace);
2384 
2385  // Parse the root operation expression.
2386  FailureOr<ast::Expr *> rootOp = parseExpr();
2387  if (failed(rootOp))
2388  return failure();
2389 
2390  if (failed(
2391  parseToken(Token::kw_with, "expected `with` after root operation")))
2392  return failure();
2393 
2394  // The replacement portion of this statement is within a rewrite context.
2395  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2396 
2397  // Parse the replacement values.
2398  SmallVector<ast::Expr *> replValues;
2399  if (consumeIf(Token::l_paren)) {
2400  if (consumeIf(Token::r_paren)) {
2401  return emitError(
2402  loc, "expected at least one replacement value, consider using "
2403  "`erase` if no replacement values are desired");
2404  }
2405 
2406  do {
2407  FailureOr<ast::Expr *> replExpr = parseExpr();
2408  if (failed(replExpr))
2409  return failure();
2410  replValues.emplace_back(*replExpr);
2411  } while (consumeIf(Token::comma));
2412 
2413  if (failed(parseToken(Token::r_paren,
2414  "expected `)` after replacement values")))
2415  return failure();
2416  } else {
2417  // Handle replacement with an operation uniquely, as the replacement
2418  // operation supports type inferrence from the root operation.
2419  FailureOr<ast::Expr *> replExpr;
2420  if (curToken.is(Token::kw_op))
2421  replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2422  else
2423  replExpr = parseExpr();
2424  if (failed(replExpr))
2425  return failure();
2426  replValues.emplace_back(*replExpr);
2427  }
2428 
2429  return createReplaceStmt(loc, *rootOp, replValues);
2430 }
2431 
2432 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2433  SMRange loc = curToken.getLoc();
2434  consumeToken(Token::kw_return);
2435 
2436  // Parse the result value.
2437  FailureOr<ast::Expr *> resultExpr = parseExpr();
2438  if (failed(resultExpr))
2439  return failure();
2440 
2441  return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2442 }
2443 
2444 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2445  if (parserContext == ParserContext::Constraint)
2446  return emitError("`rewrite` cannot be used within a Constraint");
2447  SMRange loc = curToken.getLoc();
2448  consumeToken(Token::kw_rewrite);
2449 
2450  // Parse the root operation.
2451  FailureOr<ast::Expr *> rootOp = parseExpr();
2452  if (failed(rootOp))
2453  return failure();
2454 
2455  if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2456  return failure();
2457 
2458  if (curToken.isNot(Token::l_brace))
2459  return emitError("expected `{` to start rewrite body");
2460 
2461  // The rewrite body of this statement is within a rewrite context.
2462  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2463 
2464  FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2465  if (failed(rewriteBody))
2466  return failure();
2467 
2468  // Verify the rewrite body.
2469  for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2470  if (isa<ast::ReturnStmt>(stmt)) {
2471  return emitError(stmt->getLoc(),
2472  "`return` statements are only permitted within a "
2473  "`Constraint` or `Rewrite` body");
2474  }
2475  }
2476 
2477  return createRewriteStmt(loc, *rootOp, *rewriteBody);
2478 }
2479 
2480 //===----------------------------------------------------------------------===//
2481 // Creation+Analysis
2482 //===----------------------------------------------------------------------===//
2483 
2484 //===----------------------------------------------------------------------===//
2485 // Decls
2486 
2487 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2488  // Unwrap reference expressions.
2489  if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2490  node = init->getDecl();
2491  return dyn_cast<ast::CallableDecl>(node);
2492 }
2493 
2494 FailureOr<ast::PatternDecl *>
2495 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2496  const ParsedPatternMetadata &metadata,
2497  ast::CompoundStmt *body) {
2498  return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2499  metadata.hasBoundedRecursion, body);
2500 }
2501 
2502 ast::Type Parser::createUserConstraintRewriteResultType(
2504  // Single result decls use the type of the single result.
2505  if (results.size() == 1)
2506  return results[0]->getType();
2507 
2508  // Multiple results use a tuple type, with the types and names grabbed from
2509  // the result variable decls.
2510  auto resultTypes = llvm::map_range(
2511  results, [&](const auto *result) { return result->getType(); });
2512  auto resultNames = llvm::map_range(
2513  results, [&](const auto *result) { return result->getName().getName(); });
2514  return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2515  llvm::to_vector(resultNames));
2516 }
2517 
2518 template <typename T>
2519 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2520  const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2521  ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2522  ast::CompoundStmt *body) {
2523  if (!body->getChildren().empty()) {
2524  if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2525  ast::Expr *resultExpr = retStmt->getResultExpr();
2526 
2527  // Process the result of the decl. If no explicit signature results
2528  // were provided, check for return type inference. Otherwise, check that
2529  // the return expression can be converted to the expected type.
2530  if (results.empty())
2531  resultType = resultExpr->getType();
2532  else if (failed(convertExpressionTo(resultExpr, resultType)))
2533  return failure();
2534  else
2535  retStmt->setResultExpr(resultExpr);
2536  }
2537  }
2538  return T::createPDLL(ctx, name, arguments, results, body, resultType);
2539 }
2540 
2541 FailureOr<ast::VariableDecl *>
2542 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2543  ArrayRef<ast::ConstraintRef> constraints) {
2544  // The type of the variable, which is expected to be inferred by either a
2545  // constraint or an initializer expression.
2546  ast::Type type;
2547  if (failed(validateVariableConstraints(constraints, type)))
2548  return failure();
2549 
2550  if (initializer) {
2551  // Update the variable type based on the initializer, or try to convert the
2552  // initializer to the existing type.
2553  if (!type)
2554  type = initializer->getType();
2555  else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2556  type = mergedType;
2557  else if (failed(convertExpressionTo(initializer, type)))
2558  return failure();
2559 
2560  // Otherwise, if there is no initializer check that the type has already
2561  // been resolved from the constraint list.
2562  } else if (!type) {
2563  return emitErrorAndNote(
2564  loc, "unable to infer type for variable `" + name + "`", loc,
2565  "the type of a variable must be inferable from the constraint "
2566  "list or the initializer");
2567  }
2568 
2569  // Constraint types cannot be used when defining variables.
2570  if (isa<ast::ConstraintType, ast::RewriteType>(type)) {
2571  return emitError(
2572  loc, llvm::formatv("unable to define variable of `{0}` type", type));
2573  }
2574 
2575  // Try to define a variable with the given name.
2576  FailureOr<ast::VariableDecl *> varDecl =
2577  defineVariableDecl(name, loc, type, initializer, constraints);
2578  if (failed(varDecl))
2579  return failure();
2580 
2581  return *varDecl;
2582 }
2583 
2584 FailureOr<ast::VariableDecl *>
2585 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2586  const ast::ConstraintRef &constraint) {
2587  ast::Type argType;
2588  if (failed(validateVariableConstraint(constraint, argType)))
2589  return failure();
2590  return defineVariableDecl(name, loc, argType, constraint);
2591 }
2592 
2593 LogicalResult
2594 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2595  ast::Type &inferredType) {
2596  for (const ast::ConstraintRef &ref : constraints)
2597  if (failed(validateVariableConstraint(ref, inferredType)))
2598  return failure();
2599  return success();
2600 }
2601 
2602 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2603  ast::Type &inferredType) {
2604  ast::Type constraintType;
2605  if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
2606  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2607  if (failed(validateTypeConstraintExpr(typeExpr)))
2608  return failure();
2609  }
2610  constraintType = ast::AttributeType::get(ctx);
2611  } else if (const auto *cst =
2612  dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
2613  constraintType = ast::OperationType::get(
2614  ctx, cst->getName(), lookupODSOperation(cst->getName()));
2615  } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
2616  constraintType = typeTy;
2617  } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
2618  constraintType = typeRangeTy;
2619  } else if (const auto *cst =
2620  dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
2621  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2622  if (failed(validateTypeConstraintExpr(typeExpr)))
2623  return failure();
2624  }
2625  constraintType = valueTy;
2626  } else if (const auto *cst =
2627  dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
2628  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2629  if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2630  return failure();
2631  }
2632  constraintType = valueRangeTy;
2633  } else if (const auto *cst =
2634  dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
2635  ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2636  if (inputs.size() != 1) {
2637  return emitErrorAndNote(ref.referenceLoc,
2638  "`Constraint`s applied via a variable constraint "
2639  "list must take a single input, but got " +
2640  Twine(inputs.size()),
2641  cst->getLoc(),
2642  "see definition of constraint here");
2643  }
2644  constraintType = inputs.front()->getType();
2645  } else {
2646  llvm_unreachable("unknown constraint type");
2647  }
2648 
2649  // Check that the constraint type is compatible with the current inferred
2650  // type.
2651  if (!inferredType) {
2652  inferredType = constraintType;
2653  } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
2654  inferredType = mergedTy;
2655  } else {
2656  return emitError(ref.referenceLoc,
2657  llvm::formatv("constraint type `{0}` is incompatible "
2658  "with the previously inferred type `{1}`",
2659  constraintType, inferredType));
2660  }
2661  return success();
2662 }
2663 
2664 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2665  ast::Type typeExprType = typeExpr->getType();
2666  if (typeExprType != typeTy) {
2667  return emitError(typeExpr->getLoc(),
2668  "expected expression of `Type` in type constraint");
2669  }
2670  return success();
2671 }
2672 
2673 LogicalResult
2674 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2675  ast::Type typeExprType = typeExpr->getType();
2676  if (typeExprType != typeRangeTy) {
2677  return emitError(typeExpr->getLoc(),
2678  "expected expression of `TypeRange` in type constraint");
2679  }
2680  return success();
2681 }
2682 
2683 //===----------------------------------------------------------------------===//
2684 // Exprs
2685 
2686 FailureOr<ast::CallExpr *>
2687 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2688  MutableArrayRef<ast::Expr *> arguments, bool isNegated) {
2689  ast::Type parentType = parentExpr->getType();
2690 
2691  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2692  if (!callableDecl) {
2693  return emitError(loc,
2694  llvm::formatv("expected a reference to a callable "
2695  "`Constraint` or `Rewrite`, but got: `{0}`",
2696  parentType));
2697  }
2698  if (parserContext == ParserContext::Rewrite) {
2699  if (isa<ast::UserConstraintDecl>(callableDecl))
2700  return emitError(
2701  loc, "unable to invoke `Constraint` within a rewrite section");
2702  if (isNegated)
2703  return emitError(loc, "unable to negate a Rewrite");
2704  } else {
2705  if (isa<ast::UserRewriteDecl>(callableDecl))
2706  return emitError(loc,
2707  "unable to invoke `Rewrite` within a match section");
2708  if (isNegated && cast<ast::UserConstraintDecl>(callableDecl)->getBody())
2709  return emitError(loc, "unable to negate non native constraints");
2710  }
2711 
2712  // Verify the arguments of the call.
2713  /// Handle size mismatch.
2714  ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2715  if (callArgs.size() != arguments.size()) {
2716  return emitErrorAndNote(
2717  loc,
2718  llvm::formatv("invalid number of arguments for {0} call; expected "
2719  "{1}, but got {2}",
2720  callableDecl->getCallableType(), callArgs.size(),
2721  arguments.size()),
2722  callableDecl->getLoc(),
2723  llvm::formatv("see the definition of {0} here",
2724  callableDecl->getName()->getName()));
2725  }
2726 
2727  /// Handle argument type mismatch.
2728  auto attachDiagFn = [&](ast::Diagnostic &diag) {
2729  diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2730  callableDecl->getName()->getName()),
2731  callableDecl->getLoc());
2732  };
2733  for (auto it : llvm::zip(callArgs, arguments)) {
2734  if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2735  attachDiagFn)))
2736  return failure();
2737  }
2738 
2739  return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2740  callableDecl->getResultType(), isNegated);
2741 }
2742 
2743 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2744  ast::Decl *decl) {
2745  // Check the type of decl being referenced.
2746  ast::Type declType;
2747  if (isa<ast::ConstraintDecl>(decl))
2748  declType = ast::ConstraintType::get(ctx);
2749  else if (isa<ast::UserRewriteDecl>(decl))
2750  declType = ast::RewriteType::get(ctx);
2751  else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2752  declType = varDecl->getType();
2753  else
2754  return emitError(loc, "invalid reference to `" +
2755  decl->getName()->getName() + "`");
2756 
2757  return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2758 }
2759 
2760 FailureOr<ast::DeclRefExpr *>
2761 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2762  ArrayRef<ast::ConstraintRef> constraints) {
2763  FailureOr<ast::VariableDecl *> decl =
2764  defineVariableDecl(name, loc, type, constraints);
2765  if (failed(decl))
2766  return failure();
2767  return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2768 }
2769 
2770 FailureOr<ast::MemberAccessExpr *>
2771 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2772  SMRange loc) {
2773  // Validate the member name for the given parent expression.
2774  FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2775  if (failed(memberType))
2776  return failure();
2777 
2778  return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2779 }
2780 
2781 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2782  StringRef name, SMRange loc) {
2783  ast::Type parentType = parentExpr->getType();
2784  if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) {
2786  return valueRangeTy;
2787 
2788  // Verify member access based on the operation type.
2789  if (const ods::Operation *odsOp = opType.getODSOperation()) {
2790  auto results = odsOp->getResults();
2791 
2792  // Handle indexed results.
2793  unsigned index = 0;
2794  if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2795  index < results.size()) {
2796  return results[index].isVariadic() ? valueRangeTy : valueTy;
2797  }
2798 
2799  // Handle named results.
2800  const auto *it = llvm::find_if(results, [&](const auto &result) {
2801  return result.getName() == name;
2802  });
2803  if (it != results.end())
2804  return it->isVariadic() ? valueRangeTy : valueTy;
2805  } else if (llvm::isDigit(name[0])) {
2806  // Allow unchecked numeric indexing of the results of unregistered
2807  // operations. It returns a single value.
2808  return valueTy;
2809  }
2810  } else if (auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
2811  // Handle indexed results.
2812  unsigned index = 0;
2813  if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2814  index < tupleType.size()) {
2815  return tupleType.getElementTypes()[index];
2816  }
2817 
2818  // Handle named results.
2819  auto elementNames = tupleType.getElementNames();
2820  const auto *it = llvm::find(elementNames, name);
2821  if (it != elementNames.end())
2822  return tupleType.getElementTypes()[it - elementNames.begin()];
2823  }
2824  return emitError(
2825  loc,
2826  llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2827  name, parentType));
2828 }
2829 
2830 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2831  SMRange loc, const ast::OpNameDecl *name,
2832  OpResultTypeContext resultTypeContext,
2833  SmallVectorImpl<ast::Expr *> &operands,
2835  SmallVectorImpl<ast::Expr *> &results) {
2836  std::optional<StringRef> opNameRef = name->getName();
2837  const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2838 
2839  // Verify the inputs operands.
2840  if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2841  return failure();
2842 
2843  // Verify the attribute list.
2844  for (ast::NamedAttributeDecl *attr : attributes) {
2845  // Check for an attribute type, or a type awaiting resolution.
2846  ast::Type attrType = attr->getValue()->getType();
2847  if (!isa<ast::AttributeType>(attrType)) {
2848  return emitError(
2849  attr->getValue()->getLoc(),
2850  llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2851  }
2852  }
2853 
2854  assert(
2855  (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2856  "unexpected inferrence when results were explicitly specified");
2857 
2858  // If we aren't relying on type inferrence, or explicit results were provided,
2859  // validate them.
2860  if (resultTypeContext == OpResultTypeContext::Explicit) {
2861  if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2862  return failure();
2863 
2864  // Validate the use of interface based type inferrence for this operation.
2865  } else if (resultTypeContext == OpResultTypeContext::Interface) {
2866  assert(opNameRef &&
2867  "expected valid operation name when inferring operation results");
2868  checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2869  }
2870 
2871  return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
2872  attributes);
2873 }
2874 
2875 LogicalResult
2876 Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2877  const ods::Operation *odsOp,
2878  SmallVectorImpl<ast::Expr *> &operands) {
2879  return validateOperationOperandsOrResults(
2880  "operand", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2881  operands, odsOp ? odsOp->getOperands() : std::nullopt, valueTy,
2882  valueRangeTy);
2883 }
2884 
2885 LogicalResult
2886 Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2887  const ods::Operation *odsOp,
2888  SmallVectorImpl<ast::Expr *> &results) {
2889  return validateOperationOperandsOrResults(
2890  "result", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2891  results, odsOp ? odsOp->getResults() : std::nullopt, typeTy, typeRangeTy);
2892 }
2893 
2894 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2895  const ods::Operation *odsOp) {
2896  // If the operation might not have inferrence support, emit a warning to the
2897  // user. We don't emit an error because the interface might be added to the
2898  // operation at runtime. It's rare, but it could still happen. We emit a
2899  // warning here instead.
2900 
2901  // Handle inferrence warnings for unknown operations.
2902  if (!odsOp) {
2903  ctx.getDiagEngine().emitWarning(
2904  loc, llvm::formatv(
2905  "operation result types are marked to be inferred, but "
2906  "`{0}` is unknown. Ensure that `{0}` supports zero "
2907  "results or implements `InferTypeOpInterface`. Include "
2908  "the ODS definition of this operation to remove this warning.",
2909  opName));
2910  return;
2911  }
2912 
2913  // Handle inferrence warnings for known operations that expected at least one
2914  // result, but don't have inference support. An elided results list can mean
2915  // "zero-results", and we don't want to warn when that is the expected
2916  // behavior.
2917  bool requiresInferrence =
2918  llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
2919  return !result.isVariableLength();
2920  });
2921  if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2923  loc,
2924  llvm::formatv("operation result types are marked to be inferred, but "
2925  "`{0}` does not provide an implementation of "
2926  "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2927  "`InferTypeOpInterface` at runtime, or add support to "
2928  "the ODS definition to remove this warning.",
2929  opName));
2930  diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
2931  odsOp->getLoc());
2932  return;
2933  }
2934 }
2935 
2936 LogicalResult Parser::validateOperationOperandsOrResults(
2937  StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2938  std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
2939  ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2940  ast::RangeType rangeTy) {
2941  // All operation types accept a single range parameter.
2942  if (values.size() == 1) {
2943  if (failed(convertExpressionTo(values[0], rangeTy)))
2944  return failure();
2945  return success();
2946  }
2947 
2948  /// If the operation has ODS information, we can more accurately verify the
2949  /// values.
2950  if (odsOpLoc) {
2951  auto emitSizeMismatchError = [&] {
2952  return emitErrorAndNote(
2953  loc,
2954  llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2955  "{2}, but got {3}",
2956  groupName, *name, odsValues.size(), values.size()),
2957  *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
2958  };
2959 
2960  // Handle the case where no values were provided.
2961  if (values.empty()) {
2962  // If we don't expect any on the ODS side, we are done.
2963  if (odsValues.empty())
2964  return success();
2965 
2966  // If we do, check if we actually need to provide values (i.e. if any of
2967  // the values are actually required).
2968  unsigned numVariadic = 0;
2969  for (const auto &odsValue : odsValues) {
2970  if (!odsValue.isVariableLength())
2971  return emitSizeMismatchError();
2972  ++numVariadic;
2973  }
2974 
2975  // If we are in a non-rewrite context, we don't need to do anything more.
2976  // Zero-values is a valid constraint on the operation.
2977  if (parserContext != ParserContext::Rewrite)
2978  return success();
2979 
2980  // Otherwise, when in a rewrite we may need to provide values to match the
2981  // ODS signature of the operation to create.
2982 
2983  // If we only have one variadic value, just use an empty list.
2984  if (numVariadic == 1)
2985  return success();
2986 
2987  // Otherwise, create dummy values for each of the entries so that we
2988  // adhere to the ODS signature.
2989  for (unsigned i = 0, e = odsValues.size(); i < e; ++i) {
2990  values.push_back(ast::RangeExpr::create(
2991  ctx, loc, /*elements=*/std::nullopt, rangeTy));
2992  }
2993  return success();
2994  }
2995 
2996  // Verify that the number of values provided matches the number of value
2997  // groups ODS expects.
2998  if (odsValues.size() != values.size())
2999  return emitSizeMismatchError();
3000 
3001  auto diagFn = [&](ast::Diagnostic &diag) {
3002  diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
3003  *odsOpLoc);
3004  };
3005  for (unsigned i = 0, e = values.size(); i < e; ++i) {
3006  ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
3007  if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
3008  return failure();
3009  }
3010  return success();
3011  }
3012 
3013  // Otherwise, accept the value groups as they have been defined and just
3014  // ensure they are one of the expected types.
3015  for (ast::Expr *&valueExpr : values) {
3016  ast::Type valueExprType = valueExpr->getType();
3017 
3018  // Check if this is one of the expected types.
3019  if (valueExprType == rangeTy || valueExprType == singleTy)
3020  continue;
3021 
3022  // If the operand is an Operation, allow converting to a Value or
3023  // ValueRange. This situations arises quite often with nested operation
3024  // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
3025  if (singleTy == valueTy) {
3026  if (isa<ast::OperationType>(valueExprType)) {
3027  valueExpr = convertOpToValue(valueExpr);
3028  continue;
3029  }
3030  }
3031 
3032  // Otherwise, try to convert the expression to a range.
3033  if (succeeded(convertExpressionTo(valueExpr, rangeTy)))
3034  continue;
3035 
3036  return emitError(
3037  valueExpr->getLoc(),
3038  llvm::formatv(
3039  "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3040  singleTy, rangeTy, valueExprType));
3041  }
3042  return success();
3043 }
3044 
3045 FailureOr<ast::TupleExpr *>
3046 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
3047  ArrayRef<StringRef> elementNames) {
3048  for (const ast::Expr *element : elements) {
3049  ast::Type eleTy = element->getType();
3050  if (isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>(eleTy)) {
3051  return emitError(
3052  element->getLoc(),
3053  llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
3054  }
3055  }
3056  return ast::TupleExpr::create(ctx, loc, elements, elementNames);
3057 }
3058 
3059 //===----------------------------------------------------------------------===//
3060 // Stmts
3061 
3062 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
3063  ast::Expr *rootOp) {
3064  // Check that root is an Operation.
3065  ast::Type rootType = rootOp->getType();
3066  if (!isa<ast::OperationType>(rootType))
3067  return emitError(rootOp->getLoc(), "expected `Op` expression");
3068 
3069  return ast::EraseStmt::create(ctx, loc, rootOp);
3070 }
3071 
3072 FailureOr<ast::ReplaceStmt *>
3073 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
3074  MutableArrayRef<ast::Expr *> replValues) {
3075  // Check that root is an Operation.
3076  ast::Type rootType = rootOp->getType();
3077  if (!isa<ast::OperationType>(rootType)) {
3078  return emitError(
3079  rootOp->getLoc(),
3080  llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3081  }
3082 
3083  // If there are multiple replacement values, we implicitly convert any Op
3084  // expressions to the value form.
3085  bool shouldConvertOpToValues = replValues.size() > 1;
3086  for (ast::Expr *&replExpr : replValues) {
3087  ast::Type replType = replExpr->getType();
3088 
3089  // Check that replExpr is an Operation, Value, or ValueRange.
3090  if (isa<ast::OperationType>(replType)) {
3091  if (shouldConvertOpToValues)
3092  replExpr = convertOpToValue(replExpr);
3093  continue;
3094  }
3095 
3096  if (replType != valueTy && replType != valueRangeTy) {
3097  return emitError(replExpr->getLoc(),
3098  llvm::formatv("expected `Op`, `Value` or `ValueRange` "
3099  "expression, but got `{0}`",
3100  replType));
3101  }
3102  }
3103 
3104  return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
3105 }
3106 
3107 FailureOr<ast::RewriteStmt *>
3108 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
3109  ast::CompoundStmt *rewriteBody) {
3110  // Check that root is an Operation.
3111  ast::Type rootType = rootOp->getType();
3112  if (!isa<ast::OperationType>(rootType)) {
3113  return emitError(
3114  rootOp->getLoc(),
3115  llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3116  }
3117 
3118  return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
3119 }
3120 
3121 //===----------------------------------------------------------------------===//
3122 // Code Completion
3123 //===----------------------------------------------------------------------===//
3124 
3125 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3126  ast::Type parentType = parentExpr->getType();
3127  if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType))
3128  codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3129  else if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(parentType))
3130  codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3131  return failure();
3132 }
3133 
3134 LogicalResult
3135 Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3136  if (opName)
3137  codeCompleteContext->codeCompleteOperationAttributeName(*opName);
3138  return failure();
3139 }
3140 
3141 LogicalResult
3142 Parser::codeCompleteConstraintName(ast::Type inferredType,
3143  bool allowInlineTypeConstraints) {
3144  codeCompleteContext->codeCompleteConstraintName(
3145  inferredType, allowInlineTypeConstraints, curDeclScope);
3146  return failure();
3147 }
3148 
3149 LogicalResult Parser::codeCompleteDialectName() {
3150  codeCompleteContext->codeCompleteDialectName();
3151  return failure();
3152 }
3153 
3154 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3155  codeCompleteContext->codeCompleteOperationName(dialectName);
3156  return failure();
3157 }
3158 
3159 LogicalResult Parser::codeCompletePatternMetadata() {
3160  codeCompleteContext->codeCompletePatternMetadata();
3161  return failure();
3162 }
3163 
3164 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3165  codeCompleteContext->codeCompleteIncludeFilename(curPath);
3166  return failure();
3167 }
3168 
3169 void Parser::codeCompleteCallSignature(ast::Node *parent,
3170  unsigned currentNumArgs) {
3171  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent);
3172  if (!callableDecl)
3173  return;
3174 
3175  codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
3176 }
3177 
3178 void Parser::codeCompleteOperationOperandsSignature(
3179  std::optional<StringRef> opName, unsigned currentNumOperands) {
3180  codeCompleteContext->codeCompleteOperationOperandsSignature(
3181  opName, currentNumOperands);
3182 }
3183 
3184 void Parser::codeCompleteOperationResultsSignature(
3185  std::optional<StringRef> opName, unsigned currentNumResults) {
3186  codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3187  currentNumResults);
3188 }
3189 
3190 //===----------------------------------------------------------------------===//
3191 // Parser
3192 //===----------------------------------------------------------------------===//
3193 
3194 FailureOr<ast::Module *>
3195 mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3196  bool enableDocumentation,
3197  CodeCompleteContext *codeCompleteContext) {
3198  Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3199  return parser.parseModule();
3200 }
static std::string diag(const llvm::Value &value)
This class breaks up the current file into a token stream.
Definition: Lexer.h:23
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
This represents a token in the MLIR syntax.
Definition: Token.h:20
bool isKeyword() const
Return true if this is one of the keyword token kinds (e.g. kw_if).
Definition: Token.cpp:192
SMLoc getLoc() const
Definition: Token.cpp:24
bool isNot(Kind k) const
Definition: Token.h:50
StringRef getSpelling() const
Definition: Token.h:34
This class provides an abstract interface into the parser for hooking in code completion events.
Definition: CodeComplete.h:30
virtual void codeCompleteConstraintName(ast::Type currentType, bool allowInlineTypeConstraints, const ast::DeclScope *scope)
Signal code completion for a constraint name with an optional decl scope.
virtual void codeCompleteOperationAttributeName(StringRef opName)
Signal code completion for an attribute name of the given operation.
Definition: CodeComplete.h:48
virtual void codeCompleteOperationOperandsSignature(std::optional< StringRef > opName, unsigned currentNumOperands)
Signal code completion for the signature of an operation's operands.
Definition: CodeComplete.h:80
virtual void codeCompleteOperationName(StringRef dialectName)
Signal code completion for an operation name in the given dialect.
Definition: CodeComplete.h:62
virtual void codeCompleteOperationResultsSignature(std::optional< StringRef > opName, unsigned currentNumResults)
Signal code completion for the signature of an operation's results.
Definition: CodeComplete.h:85
virtual void codeCompleteDialectName()
Signal code completion for a dialect name.
Definition: CodeComplete.h:59
virtual void codeCompleteOperationMemberAccess(ast::OperationType opType)
Signal code completion for a member access into the given operation type.
virtual void codeCompleteTupleMemberAccess(ast::TupleType tupleType)
Signal code completion for a member access into the given tuple type.
virtual void codeCompletePatternMetadata()
Signal code completion for Pattern metadata.
Definition: CodeComplete.h:65
virtual void codeCompleteCallSignature(const ast::CallableDecl *callable, unsigned currentNumArgs)
Signal code completion for the signature of a callable.
Definition: CodeComplete.h:75
virtual void codeCompleteIncludeFilename(StringRef curPath)
Signal code completion for an include filename.
Definition: CodeComplete.h:68
@ code_complete_string
Token signifying a code completion location within a string.
Definition: Lexer.h:41
@ directive
Tokens.
Definition: Lexer.h:93
@ eof
Markers.
Definition: Lexer.h:36
@ code_complete
Token signifying a code completion location.
Definition: Lexer.h:39
@ arrow
Punctuation.
Definition: Lexer.h:74
@ less
Paired punctuation.
Definition: Lexer.h:82
@ kw_Attr
General keywords.
Definition: Lexer.h:54
static StringRef getMemberName()
Return the member name used for the "all-results" access.
Definition: Nodes.h:487
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
Definition: Nodes.h:489
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
Definition: Nodes.h:749
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:390
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:258
static AttributeType get(Context &context)
Return an instance of the Attribute type.
Definition: Types.cpp:57
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr * > arguments, Type resultType, bool isNegated=false)
Definition: Nodes.cpp:268
This decl represents a shared interface for all callable decls.
Definition: Nodes.h:1188
Type getResultType() const
Return the result type of this decl.
Definition: Nodes.h:1206
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
Definition: Nodes.h:1199
StringRef getCallableType() const
Return the callable type of this decl.
Definition: Nodes.h:1191
This statement represents a compound statement, which contains a collection of other statements.
Definition: Nodes.h:179
ArrayRef< Stmt * >::iterator begin() const
Definition: Nodes.h:191
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
Definition: Nodes.h:185
ArrayRef< Stmt * >::iterator end() const
Definition: Nodes.h:192
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt * > children)
Definition: Nodes.cpp:192
This class represents the base of all AST Constraint decls.
Definition: Nodes.h:704
static ConstraintType get(Context &context)
Return an instance of the Constraint type.
Definition: Types.cpp:65
This class represents the main context of the PDLL AST.
Definition: Context.h:25
DiagnosticEngine & getDiagEngine()
Return the diagnostic engine of this context.
Definition: Context.h:42
ods::Context & getODSContext()
Return the ODS context used by the AST.
Definition: Context.h:38
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
Definition: Nodes.cpp:285
This class represents a scope for named AST decls.
Definition: Nodes.h:64
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
Definition: Nodes.h:70
void add(Decl *decl)
Add a new decl to the scope.
Definition: Nodes.cpp:175
This class represents the base Decl node.
Definition: Nodes.h:669
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition: Nodes.h:672
InFlightDiagnostic emitWarning(SMRange loc, const Twine &msg)
Definition: Diagnostic.h:149
InFlightDiagnostic emitError(SMRange loc, const Twine &msg)
Emit an error to the diagnostic engine.
Definition: Diagnostic.h:145
This class provides a simple implementation of a PDLL diagnostic.
Definition: Diagnostic.h:29
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
Definition: Nodes.cpp:218
This class represents a base AST Expression node.
Definition: Nodes.h:345
Type getType() const
Return the type of this expression.
Definition: Nodes.h:348
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostic.h:83
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
Definition: Nodes.cpp:207
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
Definition: Nodes.cpp:295
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl * > children)
Definition: Nodes.cpp:576
This Decl represents a NamedAttribute, and contains a string name and attribute value.
Definition: Nodes.h:992
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
Definition: Nodes.cpp:499
This class represents a base AST node.
Definition: Nodes.h:108
SMRange getLoc() const
Return the location of this node.
Definition: Nodes.h:131
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
Definition: Nodes.cpp:400
This Decl represents an OperationName.
Definition: Nodes.h:1016
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
Definition: Nodes.h:1022
static OpNameDecl * create(Context &ctx, const Name &name)
Definition: Nodes.cpp:509
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr * > operands, ArrayRef< Expr * > resultTypes, ArrayRef< NamedAttributeDecl * > attributes)
Definition: Nodes.cpp:307
This class represents a PDLL type that corresponds to an mlir::Operation.
Definition: Types.h:163
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
Definition: Types.cpp:87
static OperationType get(Context &context, std::optional< StringRef > name=std::nullopt, const ods::Operation *odsOp=nullptr)
Return an instance of the Operation type with an optional operation name.
Definition: Types.cpp:73
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
Definition: Nodes.cpp:520
static RangeExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, RangeType type)
Definition: Nodes.cpp:338
This class represents a PDLL type that corresponds to a range of elements with a given element type.
Definition: Types.h:188
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr * > replExprs)
Definition: Nodes.cpp:225
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
Definition: Nodes.cpp:249
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
Definition: Nodes.cpp:239
static RewriteType get(Context &context)
Return an instance of the Rewrite type.
Definition: Types.cpp:134
This class represents a base AST Statement node.
Definition: Nodes.h:164
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, ArrayRef< StringRef > elementNames)
Definition: Nodes.cpp:353
This class represents a PDLL tuple type, i.e.
Definition: Types.h:249
size_t size() const
Return the number of elements within this tuple.
Definition: Types.h:266
ArrayRef< Type > getElementTypes() const
Return the element types of this tuple.
Definition: Types.cpp:153
static TupleType get(Context &context, ArrayRef< Type > elementTypes, ArrayRef< StringRef > elementNames)
Return an instance of the Tuple type.
Definition: Types.cpp:142
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:417
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:373
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:426
static TypeRangeType get(Context &context)
Return an instance of the TypeRange type.
Definition: Types.cpp:112
static TypeType get(Context &context)
Return an instance of the Type type.
Definition: Types.cpp:165
Type refineWith(Type other) const
Try to refine this type with the one provided.
Definition: Types.cpp:33
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
Definition: Nodes.h:886
The class represents a Value constraint, and constrains a variable to be a Value.
Definition: Nodes.h:825
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
Definition: Nodes.cpp:436
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
Definition: Nodes.h:847
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:447
static ValueRangeType get(Context &context)
Return an instance of the ValueRange type.
Definition: Types.cpp:125
static ValueType get(Context &context)
Return an instance of the Value type.
Definition: Types.cpp:173
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
Definition: Nodes.cpp:558
This class contains all of the registered ODS operation classes.
Definition: Context.h:32
std::pair< Operation *, bool > insertOperation(StringRef name, StringRef summary, StringRef desc, StringRef nativeClassName, bool supportsResultTypeInferrence, SMLoc loc)
Insert a new operation with the context.
Definition: Context.cpp:64
const TypeConstraint & insertTypeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new type constraint with the context.
Definition: Context.cpp:42
const AttributeConstraint & insertAttributeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new attribute constraint with the context.
Definition: Context.cpp:28
const Operation * lookupOperation(StringRef name) const
Lookup an operation registered with the given name, or null if no operation with that name is registe...
Definition: Context.cpp:73
This class provides an ODS representation of a specific operation operand or result.
Definition: Operation.h:74
This class provides an ODS representation of a specific operation.
Definition: Operation.h:125
SMRange getLoc() const
Return the source location of this operation.
Definition: Operation.h:128
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
Definition: Operation.h:165
bool hasResultTypeInferrence() const
Return if the operation is known to support result type inferrence.
Definition: Operation.h:171
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
Definition: Operation.h:168
This class represents a generic ODS Type constraint.
Definition: Constraint.h:84
raw_ostream subclass that simplifies indention a sequence of code.
raw_indented_ostream & printReindented(StringRef str, StringRef extraPrefix="")
Prints a string re-indented to the current indent.
StringRef getSummary() const
Definition: Constraint.cpp:55
std::string getUniqueDefName() const
Returns a unique name for the TablGen def of this constraint.
Definition: Constraint.cpp:72
StringRef getDescription() const
Definition: Constraint.cpp:62
std::string getConditionTemplate() const
Definition: Constraint.cpp:51
Format context containing substitutions for special placeholders.
Definition: Format.h:40
FmtContext & withSelf(Twine subst)
Definition: Format.cpp:41
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
Definition: Operator.h:77
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
Definition: Parser.cpp:3195
auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals) -> FmtObject< decltype(std::make_tuple(llvm::support::detail::build_format_adapter(std::forward< Ts >(vals))...))>
Formats text by substituting placeholders in format string with replacement parameters.
Definition: Format.h:262
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This class represents a reference to a constraint, and contains a constraint and the location of the ...
Definition: Nodes.h:716
const ConstraintDecl * constraint
Definition: Nodes.h:722
This class provides a convenient API for interacting with source names.
Definition: Nodes.h:37
StringRef getName() const
Return the raw string name.
Definition: Nodes.h:41
SMRange getLoc() const
Get the location of this name.
Definition: Nodes.h:44
static const Name & create(Context &ctx, StringRef name, SMRange location)
Definition: Nodes.cpp:33