MLIR  22.0.0git
Parser.cpp
Go to the documentation of this file.
1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "Lexer.h"
12 #include "mlir/TableGen/Argument.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/Operator.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/SaveAndRestore.h"
29 #include "llvm/Support/ScopedPrinter.h"
30 #include "llvm/TableGen/Error.h"
31 #include "llvm/TableGen/Parser.h"
32 #include <optional>
33 #include <string>
34 
35 using namespace mlir;
36 using namespace mlir::pdll;
37 
38 //===----------------------------------------------------------------------===//
39 // Parser
40 //===----------------------------------------------------------------------===//
41 
42 namespace {
/// Parses PDLL source, tokenized by `Lexer`, into an `ast::Module` (see
/// `parseModule`). Beyond syntax, it tracks a `ParserContext` so that
/// constructs can be validated against where they appear.
43 class Parser {
44 public:
  /// Construct a parser over `sourceMgr`. The first token is lexed eagerly,
  /// and commonly used AST types are cached up front.
45  Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
46  bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
47  : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
48  curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
49  typeTy(ast::TypeType::get(ctx)), valueTy(ast::ValueType::get(ctx)),
50  typeRangeTy(ast::TypeRangeType::get(ctx)),
51  valueRangeTy(ast::ValueRangeType::get(ctx)),
53  codeCompleteContext(codeCompleteContext) {}
54
55  /// Try to parse a new module. Returns failure if the module could not be
  /// parsed.
56  FailureOr<ast::Module *> parseModule();
57
58 private:
59  /// The current context of the parser. It allows for the parser to know a bit
60  /// about the construct it is nested within during parsing. This is used
61  /// specifically to provide additional verification during parsing, e.g. to
62  /// prevent using rewrites within a match context, matcher constraints within
63  /// a rewrite section, etc.
64  enum class ParserContext {
65  /// The parser is in the global context.
66  Global,
67  /// The parser is currently within a Constraint, which disallows all types
68  /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
69  Constraint,
70  /// The parser is currently within the matcher portion of a Pattern, which
71  /// allows a terminal operation rewrite statement but no other rewrite
72  /// transformations.
73  PatternMatch,
74  /// The parser is currently within a Rewrite, which disallows calls to
75  /// constraints, requires operation expressions to have names, etc.
76  Rewrite,
77  };
78
79  /// The current specification context of an operation's result types. This
80  /// indicates how the result types of an operation may be inferred.
81  enum class OpResultTypeContext {
82  /// The result types of the operation are not known to be inferred.
83  Explicit,
84  /// The result types of the operation are inferred from the root input of a
85  /// `replace` statement.
86  Replacement,
87  /// The result types of the operation are inferred by using the
88  /// `InferTypeOpInterface` interface provided by the operation.
89  Interface,
90  };
91
92  //===--------------------------------------------------------------------===//
93  // Parsing
94  //===--------------------------------------------------------------------===//
95
96  /// Allocate a new decl scope whose parent is the current scope, make it the
  /// current scope, and return it.
97  ast::DeclScope *pushDeclScope() {
98  ast::DeclScope *newScope =
99  new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
100  return (curDeclScope = newScope);
101  }
  /// Make the given, already constructed, scope the current decl scope.
102  void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
103
104  /// Pop the current decl scope, making its parent scope current again.
105  void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
106
107  /// Parse the body of an AST module.
108  LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
109
110  /// Try to convert the given expression to `type`. Returns failure and emits
111  /// an error if a conversion is not viable. On failure, `noteAttachFn` is
112  /// invoked to attach notes to the emitted error diagnostic. On success,
113  /// `expr` is updated to the expression used to convert to `type`.
114  LogicalResult convertExpressionTo(
115  ast::Expr *&expr, ast::Type type,
116  function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
117  LogicalResult
118  convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType,
119  ast::Type type,
120  function_ref<ast::InFlightDiagnostic()> emitErrorFn);
121  LogicalResult convertTupleExpressionTo(
122  ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
123  function_ref<ast::InFlightDiagnostic()> emitErrorFn,
124  function_ref<void(ast::Diagnostic &diag)> noteAttachFn);
125
126  /// Given an operation expression, convert it to a Value or ValueRange
127  /// typed expression.
128  ast::Expr *convertOpToValue(const ast::Expr *opExpr);
129
130  /// Lookup ODS information for the given operation, returns nullptr if no
131  /// information is found.
132  const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
133  return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
134  }
135
136  /// Process the given documentation string, or return an empty string if
137  /// documentation isn't enabled.
138  StringRef processDoc(StringRef doc) {
139  return enableDocumentation ? doc : StringRef();
140  }
141
142  /// Process the given documentation string and format it, or return an empty
143  /// string if documentation isn't enabled.
144  std::string processAndFormatDoc(const Twine &doc) {
145  if (!enableDocumentation)
146  return "";
147  std::string docStr;
148  {
149  llvm::raw_string_ostream docOS(docStr);
150  std::string tmpDocStr = doc.str();
152  StringRef(tmpDocStr).rtrim(" \t"));
153  }
154  return docStr;
155  }
156
157  //===--------------------------------------------------------------------===//
158  // Directives
159
160  LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
161  LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
162  LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
164
165  /// Process the records of a parsed tablegen include file.
166  void processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
168
169  /// Create a user defined native constraint for a constraint imported from
170  /// ODS.
171  template <typename ConstraintT>
172  ast::Decl *
173  createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
174  SMRange loc, ast::Type type,
175  StringRef nativeType, StringRef docString);
176  template <typename ConstraintT>
177  ast::Decl *
178  createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
179  SMRange loc, ast::Type type,
180  StringRef nativeType);
181
182  //===--------------------------------------------------------------------===//
183  // Decls
184
185  /// This structure contains the set of pattern metadata that may be parsed.
186  struct ParsedPatternMetadata {
187  std::optional<uint16_t> benefit;
188  bool hasBoundedRecursion = false;
189  };
190
191  FailureOr<ast::Decl *> parseTopLevelDecl();
192  FailureOr<ast::NamedAttributeDecl *>
193  parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
194
195  /// Parse an argument variable as part of the signature of a
196  /// UserConstraintDecl or UserRewriteDecl.
197  FailureOr<ast::VariableDecl *> parseArgumentDecl();
198
199  /// Parse a result variable as part of the signature of a UserConstraintDecl
200  /// or UserRewriteDecl.
201  FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
202
203  /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
204  /// defined in a non-global context.
205  FailureOr<ast::UserConstraintDecl *>
206  parseUserConstraintDecl(bool isInline = false);
207
208  /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
209  /// non-global context, such as within a Pattern/Constraint/etc.
210  FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
211
212  /// Parse a PDLL (i.e. non-native) UserConstraintDecl whose body is defined
213  /// using PDLL constructs.
214  FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
215  const ast::Name &name, bool isInline,
216  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
217  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
218
219  /// Parse a UserRewriteDecl. `isInline` signals if the rewrite is being
220  /// defined in a non-global context.
221  FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
222
223  /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
224  /// non-global context, such as within a Pattern/Rewrite/etc.
225  FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
226
227  /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
228  /// PDLL constructs.
229  FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
230  const ast::Name &name, bool isInline,
231  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
232  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
233
234  /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
235  /// effectively the same syntax, and only differ on slight semantics (given
236  /// the different parsing contexts).
237  template <typename T, typename ParseUserPDLLDeclFnT>
238  FailureOr<T *> parseUserConstraintOrRewriteDecl(
239  ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
240  StringRef anonymousNamePrefix, bool isInline);
241
242  /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
243  /// These decls have effectively the same syntax.
244  template <typename T>
245  FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
246  const ast::Name &name, bool isInline,
248  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
249
250  /// Parse the functional signature (i.e. the arguments and results) of a
251  /// UserConstraintDecl or UserRewriteDecl.
252  LogicalResult parseUserConstraintOrRewriteSignature(
255  ast::DeclScope *&argumentScope, ast::Type &resultType);
256
257  /// Validate the return (which if present is specified by bodyIt) of a
258  /// UserConstraintDecl or UserRewriteDecl.
259  LogicalResult validateUserConstraintOrRewriteReturn(
260  StringRef declType, ast::CompoundStmt *body,
263  ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
264
265  FailureOr<ast::CompoundStmt *>
266  parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
267  bool expectTerminalSemicolon = true);
268  FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
269  FailureOr<ast::Decl *> parsePatternDecl();
270  LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
271
272  /// Check to see if a decl has already been defined with the given name. If
273  /// one has, emit an error and return failure. Returns success otherwise.
274  LogicalResult checkDefineNamedDecl(const ast::Name &name);
275
276  /// Try to define a variable decl with the given components, returns the
277  /// variable on success.
278  FailureOr<ast::VariableDecl *>
279  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
280  ast::Expr *initExpr,
281  ArrayRef<ast::ConstraintRef> constraints);
282  FailureOr<ast::VariableDecl *>
283  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
284  ArrayRef<ast::ConstraintRef> constraints);
285
286  /// Parse the constraint reference list for a variable decl.
287  LogicalResult parseVariableDeclConstraintList(
289
290  /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
291  FailureOr<ast::Expr *> parseTypeConstraintExpr();
292
293  /// Try to parse a single reference to a constraint. `typeConstraint` is the
294  /// location of a previously parsed type constraint for the entity that will
295  /// be constrained by the parsed constraint. `existingConstraints` are any
296  /// existing constraints that have already been parsed for the same entity
297  /// that will be constrained by this constraint. `allowInlineTypeConstraints`
298  /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
299  FailureOr<ast::ConstraintRef>
300  parseConstraint(std::optional<SMRange> &typeConstraint,
301  ArrayRef<ast::ConstraintRef> existingConstraints,
302  bool allowInlineTypeConstraints);
303
304  /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
305  /// argument or result variable. The constraints for these variables do not
306  /// allow inline type constraints, and only permit a single constraint.
307  FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
308
309  //===--------------------------------------------------------------------===//
310  // Exprs
311
312  FailureOr<ast::Expr *> parseExpr();
313
314  /// Identifier expressions.
315  FailureOr<ast::Expr *> parseAttributeExpr();
316  FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr,
317  bool isNegated = false);
318  FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
319  FailureOr<ast::Expr *> parseIdentifierExpr();
320  FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
321  FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
322  FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
323  FailureOr<ast::Expr *> parseNegatedExpr();
324  FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
325  FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
326  FailureOr<ast::Expr *>
327  parseOperationExpr(OpResultTypeContext inputResultTypeContext =
328  OpResultTypeContext::Explicit);
329  FailureOr<ast::Expr *> parseTupleExpr();
330  FailureOr<ast::Expr *> parseTypeExpr();
331  FailureOr<ast::Expr *> parseUnderscoreExpr();
332
333  //===--------------------------------------------------------------------===//
334  // Stmts
335
336  FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
337  FailureOr<ast::CompoundStmt *> parseCompoundStmt();
338  FailureOr<ast::EraseStmt *> parseEraseStmt();
339  FailureOr<ast::LetStmt *> parseLetStmt();
340  FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
341  FailureOr<ast::ReturnStmt *> parseReturnStmt();
342  FailureOr<ast::RewriteStmt *> parseRewriteStmt();
343
344  //===--------------------------------------------------------------------===//
345  // Creation+Analysis
346  //===--------------------------------------------------------------------===//
347
348  //===--------------------------------------------------------------------===//
349  // Decls
350
351  /// Try to extract a callable from the given AST node. Returns nullptr on
352  /// failure.
353  ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
354
355  /// Try to create a pattern decl with the given components, returning the
356  /// Pattern on success.
357  FailureOr<ast::PatternDecl *>
358  createPatternDecl(SMRange loc, const ast::Name *name,
359  const ParsedPatternMetadata &metadata,
360  ast::CompoundStmt *body);
361
362  /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
363  /// of results, defined as part of the signature.
364  ast::Type
365  createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
366
367  /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
368  template <typename T>
369  FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
370  const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
371  ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
372  ast::CompoundStmt *body);
373
374  /// Try to create a variable decl with the given components, returning the
375  /// Variable on success.
376  FailureOr<ast::VariableDecl *>
377  createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
378  ArrayRef<ast::ConstraintRef> constraints);
379
380  /// Create a variable for an argument or result defined as part of the
381  /// signature of a UserConstraintDecl/UserRewriteDecl.
382  FailureOr<ast::VariableDecl *>
383  createArgOrResultVariableDecl(StringRef name, SMRange loc,
384  const ast::ConstraintRef &constraint);
385
386  /// Validate the constraints used to constrain a variable decl.
387  /// `inferredType` is the type of the variable inferred by the constraints
388  /// within the list, and is updated to the most refined type as determined by
389  /// the constraints. Returns success if the constraint list is valid, failure
390  /// otherwise.
391  LogicalResult
392  validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
393  ast::Type &inferredType);
394  /// Validate a single reference to a constraint. `inferredType` contains the
395  /// currently inferred variable type and is refined within the type defined
396  /// by the constraint. Returns success if the constraint is valid, failure
397  /// otherwise.
398  LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
399  ast::Type &inferredType);
400  LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
401  LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
402
403  //===--------------------------------------------------------------------===//
404  // Exprs
405
406  FailureOr<ast::CallExpr *>
407  createCallExpr(SMRange loc, ast::Expr *parentExpr,
409  bool isNegated = false);
410  FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
411  FailureOr<ast::DeclRefExpr *>
412  createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
413  ArrayRef<ast::ConstraintRef> constraints);
414  FailureOr<ast::MemberAccessExpr *>
415  createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
416
417  /// Validate the member access `name` into the given parent expression. On
418  /// success, this also returns the type of the member accessed.
419  FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
420  StringRef name, SMRange loc);
421  FailureOr<ast::OperationExpr *>
422  createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
423  OpResultTypeContext resultTypeContext,
427  LogicalResult
428  validateOperationOperands(SMRange loc, std::optional<StringRef> name,
429  const ods::Operation *odsOp,
430  SmallVectorImpl<ast::Expr *> &operands);
431  LogicalResult validateOperationResults(SMRange loc,
432  std::optional<StringRef> name,
433  const ods::Operation *odsOp,
435  void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
436  const ods::Operation *odsOp);
437  LogicalResult validateOperationOperandsOrResults(
438  StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
439  std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
440  ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
441  ast::RangeType rangeTy);
442  FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
443  ArrayRef<ast::Expr *> elements,
444  ArrayRef<StringRef> elementNames);
445
446  //===--------------------------------------------------------------------===//
447  // Stmts
448
449  FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
450  FailureOr<ast::ReplaceStmt *>
451  createReplaceStmt(SMRange loc, ast::Expr *rootOp,
452  MutableArrayRef<ast::Expr *> replValues);
453  FailureOr<ast::RewriteStmt *>
454  createRewriteStmt(SMRange loc, ast::Expr *rootOp,
455  ast::CompoundStmt *rewriteBody);
456
457  //===--------------------------------------------------------------------===//
458  // Code Completion
459  //===--------------------------------------------------------------------===//
460
461  /// The set of various code completion methods. Every completion method
462  /// returns `failure` to stop the parsing process after providing completion
463  /// results.
464
465  LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
466  LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
467  LogicalResult codeCompleteConstraintName(ast::Type inferredType,
468  bool allowInlineTypeConstraints);
469  LogicalResult codeCompleteDialectName();
470  LogicalResult codeCompleteOperationName(StringRef dialectName);
471  LogicalResult codeCompletePatternMetadata();
472  LogicalResult codeCompleteIncludeFilename(StringRef curPath);
473
474  void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
475  void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
476  unsigned currentNumOperands);
477  void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
478  unsigned currentNumResults);
479
480  //===--------------------------------------------------------------------===//
481  // Lexer Utilities
482  //===--------------------------------------------------------------------===//
483
484  /// If the current token has the specified kind, consume it and return true.
485  /// If not, return false.
486  bool consumeIf(Token::Kind kind) {
487  if (curToken.isNot(kind))
488  return false;
489  consumeToken(kind);
490  return true;
491  }
492
493  /// Advance the current lexer onto the next token.
494  void consumeToken() {
495  assert(curToken.isNot(Token::eof, Token::error) &&
496  "shouldn't advance past EOF or errors");
497  curToken = lexer.lexToken();
498  }
499
500  /// Advance the current lexer onto the next token, asserting what the expected
501  /// current token is. This is preferred to the above method because it leads
502  /// to more self-documenting code with better checking.
503  void consumeToken(Token::Kind kind) {
504  assert(curToken.is(kind) && "consumed an unexpected token");
505  consumeToken();
506  }
507
508  /// Reset the lexer to the start of the given range and re-lex the current
  /// token from there.
509  void resetToken(SMRange tokLoc) {
510  lexer.resetPointer(tokLoc.Start.getPointer());
511  curToken = lexer.lexToken();
512  }
513
514  /// Consume the specified token if present and return success. On failure,
515  /// output a diagnostic and return failure.
516  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
517  if (curToken.getKind() != kind)
518  return emitError(curToken.getLoc(), msg);
519  consumeToken();
520  return success();
521  }
  /// Emit an error at `loc` through the lexer and return failure.
522  LogicalResult emitError(SMRange loc, const Twine &msg) {
523  lexer.emitError(loc, msg);
524  return failure();
525  }
  /// Emit an error at the current token's location and return failure.
526  LogicalResult emitError(const Twine &msg) {
527  return emitError(curToken.getLoc(), msg);
528  }
  /// Emit an error at `loc` with an attached note at `noteLoc`, then return
  /// failure.
529  LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
530  const Twine &note) {
531  lexer.emitErrorAndNote(loc, msg, noteLoc, note);
532  return failure();
533  }
534
535  //===--------------------------------------------------------------------===//
536  // Fields
537  //===--------------------------------------------------------------------===//
538
539  /// The owning AST context.
540  ast::Context &ctx;
541
542  /// The lexer of this parser.
543  Lexer lexer;
544
545  /// The current token within the lexer.
546  Token curToken;
547
548  /// A flag indicating if the parser should add documentation to AST nodes when
549  /// viable.
550  bool enableDocumentation;
551
552  /// The most recently defined decl scope.
553  ast::DeclScope *curDeclScope = nullptr;
  /// Backing storage for scopes created by `pushDeclScope()`.
554  llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
555
556  /// The current context of the parser.
557  ParserContext parserContext = ParserContext::Global;
558
559  /// Cached types to simplify verification and expression creation.
560  ast::Type typeTy, valueTy;
561  ast::RangeType typeRangeTy, valueRangeTy;
562  ast::Type attrTy;
563
564  /// A counter used when naming anonymous constraints and rewrites.
565  unsigned anonymousDeclNameCounter = 0;
566
567  /// The optional code completion context.
568  CodeCompleteContext *codeCompleteContext;
569 };
570 } // namespace
571 
572 FailureOr<ast::Module *> Parser::parseModule() {
573  SMLoc moduleLoc = curToken.getStartLoc();
574  pushDeclScope();
575 
576  // Parse the top-level decls of the module.
578  if (failed(parseModuleBody(decls)))
579  return popDeclScope(), failure();
580 
581  popDeclScope();
582  return ast::Module::create(ctx, moduleLoc, decls);
583 }
584 
585 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
586  while (curToken.isNot(Token::eof)) {
587  if (curToken.is(Token::directive)) {
588  if (failed(parseDirective(decls)))
589  return failure();
590  continue;
591  }
592 
593  FailureOr<ast::Decl *> decl = parseTopLevelDecl();
594  if (failed(decl))
595  return failure();
596  decls.push_back(*decl);
597  }
598  return success();
599 }
600 
601 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
602  return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
603  valueRangeTy);
604 }
605 
606 LogicalResult Parser::convertExpressionTo(
607  ast::Expr *&expr, ast::Type type,
608  function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
609  ast::Type exprType = expr->getType();
610  if (exprType == type)
611  return success();
612 
613  auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
615  expr->getLoc(), llvm::formatv("unable to convert expression of type "
616  "`{0}` to the expected type of "
617  "`{1}`",
618  exprType, type));
619  if (noteAttachFn)
620  noteAttachFn(*diag);
621  return diag;
622  };
623 
624  if (auto exprOpType = dyn_cast<ast::OperationType>(exprType))
625  return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);
626 
627  // FIXME: Decide how to allow/support converting a single result to multiple,
628  // and multiple to a single result. For now, we just allow Single->Range,
629  // but this isn't something really supported in the PDL dialect. We should
630  // figure out some way to support both.
631  if ((exprType == valueTy || exprType == valueRangeTy) &&
632  (type == valueTy || type == valueRangeTy))
633  return success();
634  if ((exprType == typeTy || exprType == typeRangeTy) &&
635  (type == typeTy || type == typeRangeTy))
636  return success();
637 
638  // Handle tuple types.
639  if (auto exprTupleType = dyn_cast<ast::TupleType>(exprType))
640  return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
641  noteAttachFn);
642 
643  return emitConvertError();
644 }
645 
646 LogicalResult Parser::convertOpExpressionTo(
647  ast::Expr *&expr, ast::OperationType exprType, ast::Type type,
648  function_ref<ast::InFlightDiagnostic()> emitErrorFn) {
649  // Two operation types are compatible if they have the same name, or if the
650  // expected type is more general.
651  if (auto opType = dyn_cast<ast::OperationType>(type)) {
652  if (opType.getName())
653  return emitErrorFn();
654  return success();
655  }
656 
657  // An operation can always convert to a ValueRange.
658  if (type == valueRangeTy) {
659  expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
660  valueRangeTy);
661  return success();
662  }
663 
664  // Allow conversion to a single value by constraining the result range.
665  if (type == valueTy) {
666  // If the operation is registered, we can verify if it can ever have a
667  // single result.
668  if (const ods::Operation *odsOp = exprType.getODSOperation()) {
669  if (odsOp->getResults().empty()) {
670  return emitErrorFn()->attachNote(
671  llvm::formatv("see the definition of `{0}`, which was defined "
672  "with zero results",
673  odsOp->getName()),
674  odsOp->getLoc());
675  }
676 
677  unsigned numSingleResults = llvm::count_if(
678  odsOp->getResults(), [](const ods::OperandOrResult &result) {
679  return result.getVariableLengthKind() ==
680  ods::VariableLengthKind::Single;
681  });
682  if (numSingleResults > 1) {
683  return emitErrorFn()->attachNote(
684  llvm::formatv("see the definition of `{0}`, which was defined "
685  "with at least {1} results",
686  odsOp->getName(), numSingleResults),
687  odsOp->getLoc());
688  }
689  }
690 
691  expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
692  valueTy);
693  return success();
694  }
695  return emitErrorFn();
696 }
697 
698 LogicalResult Parser::convertTupleExpressionTo(
699  ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
700  function_ref<ast::InFlightDiagnostic()> emitErrorFn,
701  function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
702  // Handle conversions between tuples.
703  if (auto tupleType = dyn_cast<ast::TupleType>(type)) {
704  if (tupleType.size() != exprType.size())
705  return emitErrorFn();
706 
707  // Build a new tuple expression using each of the elements of the current
708  // tuple.
709  SmallVector<ast::Expr *> newExprs;
710  for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
711  newExprs.push_back(ast::MemberAccessExpr::create(
712  ctx, expr->getLoc(), expr, llvm::to_string(i),
713  exprType.getElementTypes()[i]));
714 
715  auto diagFn = [&](ast::Diagnostic &diag) {
716  diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
717  i, exprType));
718  if (noteAttachFn)
719  noteAttachFn(diag);
720  };
721  if (failed(convertExpressionTo(newExprs.back(),
722  tupleType.getElementTypes()[i], diagFn)))
723  return failure();
724  }
725  expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
726  tupleType.getElementNames());
727  return success();
728  }
729 
730  // Handle conversion to a range.
731  auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes,
732  ast::RangeType resultTy) -> LogicalResult {
733  // TODO: We currently only allow range conversion within a rewrite context.
734  if (parserContext != ParserContext::Rewrite) {
735  return emitErrorFn()->attachNote("Tuple to Range conversion is currently "
736  "only allowed within a rewrite context");
737  }
738 
739  // All of the tuple elements must be allowed types.
740  for (ast::Type elementType : exprType.getElementTypes())
741  if (!llvm::is_contained(allowedElementTypes, elementType))
742  return emitErrorFn();
743 
744  // Build a new tuple expression using each of the elements of the current
745  // tuple.
746  SmallVector<ast::Expr *> newExprs;
747  for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
748  newExprs.push_back(ast::MemberAccessExpr::create(
749  ctx, expr->getLoc(), expr, llvm::to_string(i),
750  exprType.getElementTypes()[i]));
751  }
752  expr = ast::RangeExpr::create(ctx, expr->getLoc(), newExprs, resultTy);
753  return success();
754  };
755  if (type == valueRangeTy)
756  return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
757  if (type == typeRangeTy)
758  return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
759 
760  return emitErrorFn();
761 }
762 
763 //===----------------------------------------------------------------------===//
764 // Directives
765 //===----------------------------------------------------------------------===//
766 
767 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
768  StringRef directive = curToken.getSpelling();
769  if (directive == "#include")
770  return parseInclude(decls);
771 
772  return emitError("unknown directive `" + directive + "`");
773 }
774 
775 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
776  SMRange loc = curToken.getLoc();
777  consumeToken(Token::directive);
778 
779  // Handle code completion of the include file path.
780  if (curToken.is(Token::code_complete_string))
781  return codeCompleteIncludeFilename(curToken.getStringValue());
782 
783  // Parse the file being included.
784  if (!curToken.isString())
785  return emitError(loc,
786  "expected string file name after `include` directive");
787  SMRange fileLoc = curToken.getLoc();
788  std::string filenameStr = curToken.getStringValue();
789  StringRef filename = filenameStr;
790  consumeToken();
791 
792  // Check the type of include. If ending with `.pdll`, this is another pdl file
793  // to be parsed along with the current module.
794  if (filename.ends_with(".pdll")) {
795  if (failed(lexer.pushInclude(filename, fileLoc)))
796  return emitError(fileLoc,
797  "unable to open include file `" + filename + "`");
798 
799  // If we added the include successfully, parse it into the current module.
800  // Make sure to update to the next token after we finish parsing the nested
801  // file.
802  curToken = lexer.lexToken();
803  LogicalResult result = parseModuleBody(decls);
804  curToken = lexer.lexToken();
805  return result;
806  }
807 
808  // Otherwise, this must be a `.td` include.
809  if (filename.ends_with(".td"))
810  return parseTdInclude(filename, fileLoc, decls);
811 
812  return emitError(fileLoc,
813  "expected include filename to end with `.pdll` or `.td`");
814 }
815 
816 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
818  llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
819 
820  // Use the source manager to open the file, but don't yet add it.
821  std::string includedFile;
822  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
823  parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
824  if (!includeBuffer)
825  return emitError(fileLoc, "unable to open include file `" + filename + "`");
826 
827  // Setup the source manager for parsing the tablegen file.
828  llvm::SourceMgr tdSrcMgr;
829  tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
830  tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
831 
832  // This class provides a context argument for the llvm::SourceMgr diagnostic
833  // handler.
834  struct DiagHandlerContext {
835  Parser &parser;
836  StringRef filename;
837  llvm::SMRange loc;
838  } handlerContext{*this, filename, fileLoc};
839 
840  // Set the diagnostic handler for the tablegen source manager.
841  tdSrcMgr.setDiagHandler(
842  [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
843  auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
844  (void)ctx->parser.emitError(
845  ctx->loc,
846  llvm::formatv("error while processing include file `{0}`: {1}",
847  ctx->filename, diag.getMessage()));
848  },
849  &handlerContext);
850 
851  // Parse the tablegen file.
852  llvm::RecordKeeper tdRecords;
853  if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
854  return failure();
855 
856  // Process the parsed records.
857  processTdIncludeRecords(tdRecords, decls);
858 
859  // After we are done processing, move all of the tablegen source buffers to
860  // the main parser source mgr. This allows for directly using source locations
861  // from the .td files without needing to remap them.
862  parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
863  return success();
864 }
865 
866 void Parser::processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
868  // Return the length kind of the given value.
869  auto getLengthKind = [](const auto &value) {
870  if (value.isOptional())
872  return value.isVariadic() ? ods::VariableLengthKind::Variadic
874  };
875 
876  // Insert a type constraint into the ODS context.
877  ods::Context &odsContext = ctx.getODSContext();
878  auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
879  -> const ods::TypeConstraint & {
880  return odsContext.insertTypeConstraint(
881  cst.constraint.getUniqueDefName(),
882  processDoc(cst.constraint.getSummary()), cst.constraint.getCppType());
883  };
884  auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
885  return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
886  };
887 
888  // Process the parsed tablegen records to build ODS information.
889  /// Operations.
890  for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
891  tblgen::Operator op(def);
892 
893  // Check to see if this operation is known to support type inferrence.
894  bool supportsResultTypeInferrence =
895  op.getTrait("::mlir::InferTypeOpInterface::Trait");
896 
897  auto [odsOp, inserted] = odsContext.insertOperation(
898  op.getOperationName(), processDoc(op.getSummary()),
899  processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
900  supportsResultTypeInferrence, op.getLoc().front());
901 
902  // Ignore operations that have already been added.
903  if (!inserted)
904  continue;
905 
906  for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
907  odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
908  odsContext.insertAttributeConstraint(
909  attr.attr.getUniqueDefName(),
910  processDoc(attr.attr.getSummary()),
911  attr.attr.getStorageType()));
912  }
913  for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
914  odsOp->appendOperand(operand.name, getLengthKind(operand),
915  addTypeConstraint(operand));
916  }
917  for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
918  odsOp->appendResult(result.name, getLengthKind(result),
919  addTypeConstraint(result));
920  }
921  }
922 
923  auto shouldBeSkipped = [this](const llvm::Record *def) {
924  return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
925  def->isSubClassOf("DeclareInterfaceMethods");
926  };
927 
928  /// Attr constraints.
929  for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
930  if (shouldBeSkipped(def))
931  continue;
932 
933  tblgen::Attribute constraint(def);
934  decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
935  constraint, convertLocToRange(def->getLoc().front()), attrTy,
936  constraint.getStorageType()));
937  }
938  /// Type constraints.
939  for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
940  if (shouldBeSkipped(def))
941  continue;
942 
943  tblgen::TypeConstraint constraint(def);
944  decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
945  constraint, convertLocToRange(def->getLoc().front()), typeTy,
946  constraint.getCppType()));
947  }
948  /// OpInterfaces.
950  for (const llvm::Record *def :
951  tdRecords.getAllDerivedDefinitions("OpInterface")) {
952  if (shouldBeSkipped(def))
953  continue;
954 
955  SMRange loc = convertLocToRange(def->getLoc().front());
956 
957  std::string cppClassName =
958  llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
959  def->getValueAsString("cppInterfaceName"))
960  .str();
961  std::string codeBlock =
962  llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
963  cppClassName)
964  .str();
965 
966  std::string desc =
967  processAndFormatDoc(def->getValueAsString("description"));
968  decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
969  def->getName(), codeBlock, loc, opTy, cppClassName, desc));
970  }
971 }
972 
973 template <typename ConstraintT>
974 ast::Decl *Parser::createODSNativePDLLConstraintDecl(
975  StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
976  StringRef nativeType, StringRef docString) {
977  // Build the single input parameter.
978  ast::DeclScope *argScope = pushDeclScope();
979  auto *paramVar = ast::VariableDecl::create(
980  ctx, ast::Name::create(ctx, "self", loc), type,
981  /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
982  argScope->add(paramVar);
983  popDeclScope();
984 
985  // Build the native constraint.
986  auto *constraintDecl = ast::UserConstraintDecl::createNative(
987  ctx, ast::Name::create(ctx, name, loc), paramVar,
988  /*results=*/{}, codeBlock, ast::TupleType::get(ctx), nativeType);
989  constraintDecl->setDocComment(ctx, docString);
990  curDeclScope->add(constraintDecl);
991  return constraintDecl;
992 }
993 
994 template <typename ConstraintT>
995 ast::Decl *
996 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
997  SMRange loc, ast::Type type,
998  StringRef nativeType) {
999  // Format the condition template.
1000  tblgen::FmtContext fmtContext;
1001  fmtContext.withSelf("self");
1002  std::string codeBlock = tblgen::tgfmt(
1003  "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
1004  &fmtContext);
1005 
1006  // If documentation was enabled, build the doc string for the generated
1007  // constraint. It would be nice to do this lazily, but TableGen information is
1008  // destroyed after we finish parsing the file.
1009  std::string docString;
1010  if (enableDocumentation) {
1011  StringRef desc = constraint.getDescription();
1012  docString = processAndFormatDoc(
1013  constraint.getSummary() +
1014  (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
1015  }
1016 
1017  return createODSNativePDLLConstraintDecl<ConstraintT>(
1018  constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
1019  docString);
1020 }
1021 
1022 //===----------------------------------------------------------------------===//
1023 // Decls
1024 //===----------------------------------------------------------------------===//
1025 
1026 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
1027  FailureOr<ast::Decl *> decl;
1028  switch (curToken.getKind()) {
1029  case Token::kw_Constraint:
1030  decl = parseUserConstraintDecl();
1031  break;
1032  case Token::kw_Pattern:
1033  decl = parsePatternDecl();
1034  break;
1035  case Token::kw_Rewrite:
1036  decl = parseUserRewriteDecl();
1037  break;
1038  default:
1039  return emitError("expected top-level declaration, such as a `Pattern`");
1040  }
1041  if (failed(decl))
1042  return failure();
1043 
1044  // If the decl has a name, add it to the current scope.
1045  if (const ast::Name *name = (*decl)->getName()) {
1046  if (failed(checkDefineNamedDecl(*name)))
1047  return failure();
1048  curDeclScope->add(*decl);
1049  }
1050  return decl;
1051 }
1052 
1053 FailureOr<ast::NamedAttributeDecl *>
1054 Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1055  // Check for name code completion.
1056  if (curToken.is(Token::code_complete))
1057  return codeCompleteAttributeName(parentOpName);
1058 
1059  std::string attrNameStr;
1060  if (curToken.isString())
1061  attrNameStr = curToken.getStringValue();
1062  else if (curToken.is(Token::identifier) || curToken.isKeyword())
1063  attrNameStr = curToken.getSpelling().str();
1064  else
1065  return emitError("expected identifier or string attribute name");
1066  const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
1067  consumeToken();
1068 
1069  // Check for a value of the attribute.
1070  ast::Expr *attrValue = nullptr;
1071  if (consumeIf(Token::equal)) {
1072  FailureOr<ast::Expr *> attrExpr = parseExpr();
1073  if (failed(attrExpr))
1074  return failure();
1075  attrValue = *attrExpr;
1076  } else {
1077  // If there isn't a concrete value, create an expression representing a
1078  // UnitAttr.
1079  attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
1080  }
1081 
1082  return ast::NamedAttributeDecl::create(ctx, name, attrValue);
1083 }
1084 
1085 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1086  function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1087  bool expectTerminalSemicolon) {
1088  consumeToken(Token::equal_arrow);
1089 
1090  // Parse the single statement of the lambda body.
1091  SMLoc bodyStartLoc = curToken.getStartLoc();
1092  pushDeclScope();
1093  FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1094  bool failedToParse =
1095  failed(singleStatement) || failed(processStatementFn(*singleStatement));
1096  popDeclScope();
1097  if (failedToParse)
1098  return failure();
1099 
1100  SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1101  return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
1102 }
1103 
1104 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1105  // Ensure that the argument is named.
1106  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
1107  return emitError("expected identifier argument name");
1108 
1109  // Parse the argument similarly to a normal variable.
1110  StringRef name = curToken.getSpelling();
1111  SMRange nameLoc = curToken.getLoc();
1112  consumeToken();
1113 
1114  if (failed(
1115  parseToken(Token::colon, "expected `:` before argument constraint")))
1116  return failure();
1117 
1118  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1119  if (failed(cst))
1120  return failure();
1121 
1122  return createArgOrResultVariableDecl(name, nameLoc, *cst);
1123 }
1124 
1125 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1126  // Check to see if this result is named.
1127  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1128  // Check to see if this name actually refers to a Constraint.
1129  if (!curDeclScope->lookup<ast::ConstraintDecl>(curToken.getSpelling())) {
1130  // If it wasn't a constraint, parse the result similarly to a variable. If
1131  // there is already an existing decl, we will emit an error when defining
1132  // this variable later.
1133  StringRef name = curToken.getSpelling();
1134  SMRange nameLoc = curToken.getLoc();
1135  consumeToken();
1136 
1137  if (failed(parseToken(Token::colon,
1138  "expected `:` before result constraint")))
1139  return failure();
1140 
1141  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1142  if (failed(cst))
1143  return failure();
1144 
1145  return createArgOrResultVariableDecl(name, nameLoc, *cst);
1146  }
1147  }
1148 
1149  // If it isn't named, we parse the constraint directly and create an unnamed
1150  // result variable.
1151  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1152  if (failed(cst))
1153  return failure();
1154 
1155  return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1156 }
1157 
1158 FailureOr<ast::UserConstraintDecl *>
1159 Parser::parseUserConstraintDecl(bool isInline) {
1160  // Constraints and rewrites have very similar formats, dispatch to a shared
1161  // interface for parsing.
1162  return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1163  [&](auto &&...args) {
1164  return this->parseUserPDLLConstraintDecl(args...);
1165  },
1166  ParserContext::Constraint, "constraint", isInline);
1167 }
1168 
1169 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1170  FailureOr<ast::UserConstraintDecl *> decl =
1171  parseUserConstraintDecl(/*isInline=*/true);
1172  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1173  return failure();
1174 
1175  curDeclScope->add(*decl);
1176  return decl;
1177 }
1178 
1179 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1180  const ast::Name &name, bool isInline,
1181  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1182  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1183  // Push the argument scope back onto the list, so that the body can
1184  // reference arguments.
1185  pushDeclScope(argumentScope);
1186 
1187  // Parse the body of the constraint. The body is either defined as a compound
1188  // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1189  ast::CompoundStmt *body;
1190  if (curToken.is(Token::equal_arrow)) {
1191  FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1192  [&](ast::Stmt *&stmt) -> LogicalResult {
1193  ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1194  if (!stmtExpr) {
1195  return emitError(stmt->getLoc(),
1196  "expected `Constraint` lambda body to contain a "
1197  "single expression");
1198  }
1199  stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1200  return success();
1201  },
1202  /*expectTerminalSemicolon=*/!isInline);
1203  if (failed(bodyResult))
1204  return failure();
1205  body = *bodyResult;
1206  } else {
1207  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1208  if (failed(bodyResult))
1209  return failure();
1210  body = *bodyResult;
1211 
1212  // Verify the structure of the body.
1213  auto bodyIt = body->begin(), bodyE = body->end();
1214  for (; bodyIt != bodyE; ++bodyIt)
1215  if (isa<ast::ReturnStmt>(*bodyIt))
1216  break;
1217  if (failed(validateUserConstraintOrRewriteReturn(
1218  "Constraint", body, bodyIt, bodyE, results, resultType)))
1219  return failure();
1220  }
1221  popDeclScope();
1222 
1223  return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1224  name, arguments, results, resultType, body);
1225 }
1226 
1227 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1228  // Constraints and rewrites have very similar formats, dispatch to a shared
1229  // interface for parsing.
1230  return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1231  [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1232  ParserContext::Rewrite, "rewrite", isInline);
1233 }
1234 
1235 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1236  FailureOr<ast::UserRewriteDecl *> decl =
1237  parseUserRewriteDecl(/*isInline=*/true);
1238  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1239  return failure();
1240 
1241  curDeclScope->add(*decl);
1242  return decl;
1243 }
1244 
1245 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1246  const ast::Name &name, bool isInline,
1247  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1248  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1249  // Push the argument scope back onto the list, so that the body can
1250  // reference arguments.
1251  curDeclScope = argumentScope;
1252  ast::CompoundStmt *body;
1253  if (curToken.is(Token::equal_arrow)) {
1254  FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1255  [&](ast::Stmt *&statement) -> LogicalResult {
1256  if (isa<ast::OpRewriteStmt>(statement))
1257  return success();
1258 
1259  ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1260  if (!statementExpr) {
1261  return emitError(
1262  statement->getLoc(),
1263  "expected `Rewrite` lambda body to contain a single expression "
1264  "or an operation rewrite statement; such as `erase`, "
1265  "`replace`, or `rewrite`");
1266  }
1267  statement =
1268  ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1269  return success();
1270  },
1271  /*expectTerminalSemicolon=*/!isInline);
1272  if (failed(bodyResult))
1273  return failure();
1274  body = *bodyResult;
1275  } else {
1276  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1277  if (failed(bodyResult))
1278  return failure();
1279  body = *bodyResult;
1280  }
1281  popDeclScope();
1282 
1283  // Verify the structure of the body.
1284  auto bodyIt = body->begin(), bodyE = body->end();
1285  for (; bodyIt != bodyE; ++bodyIt)
1286  if (isa<ast::ReturnStmt>(*bodyIt))
1287  break;
1288  if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1289  bodyE, results, resultType)))
1290  return failure();
1291  return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1292  name, arguments, results, resultType, body);
1293 }
1294 
1295 template <typename T, typename ParseUserPDLLDeclFnT>
1296 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1297  ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1298  StringRef anonymousNamePrefix, bool isInline) {
1299  SMRange loc = curToken.getLoc();
1300  consumeToken();
1301  llvm::SaveAndRestore saveCtx(parserContext, declContext);
1302 
1303  // Parse the name of the decl.
1304  const ast::Name *name = nullptr;
1305  if (curToken.isNot(Token::identifier)) {
1306  // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1307  // in C++, so being unnamed is fine.
1308  if (!isInline)
1309  return emitError("expected identifier name");
1310 
1311  // Create a unique anonymous name to use, as the name for this decl is not
1312  // important.
1313  std::string anonName =
1314  llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1315  anonymousDeclNameCounter++)
1316  .str();
1317  name = &ast::Name::create(ctx, anonName, loc);
1318  } else {
1319  // If a name was provided, we can use it directly.
1320  name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1321  consumeToken(Token::identifier);
1322  }
1323 
1324  // Parse the functional signature of the decl.
1325  SmallVector<ast::VariableDecl *> arguments, results;
1326  ast::DeclScope *argumentScope;
1327  ast::Type resultType;
1328  if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1329  argumentScope, resultType)))
1330  return failure();
1331 
1332  // Check to see which type of constraint this is. If the constraint contains a
1333  // compound body, this is a PDLL decl.
1334  if (curToken.isAny(Token::l_brace, Token::equal_arrow))
1335  return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1336  resultType);
1337 
1338  // Otherwise, this is a native decl.
1339  return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1340  results, resultType);
1341 }
1342 
1343 template <typename T>
1344 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1345  const ast::Name &name, bool isInline,
1347  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1348  // If followed by a string, the native code body has also been specified.
1349  std::string codeStrStorage;
1350  std::optional<StringRef> optCodeStr;
1351  if (curToken.isString()) {
1352  codeStrStorage = curToken.getStringValue();
1353  optCodeStr = codeStrStorage;
1354  consumeToken();
1355  } else if (isInline) {
1356  return emitError(name.getLoc(),
1357  "external declarations must be declared in global scope");
1358  } else if (curToken.is(Token::error)) {
1359  return failure();
1360  }
1361  if (failed(parseToken(Token::semicolon,
1362  "expected `;` after native declaration")))
1363  return failure();
1364  return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1365 }
1366 
1367 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1370  ast::DeclScope *&argumentScope, ast::Type &resultType) {
1371  // Parse the argument list of the decl.
1372  if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
1373  return failure();
1374 
1375  argumentScope = pushDeclScope();
1376  if (curToken.isNot(Token::r_paren)) {
1377  do {
1378  FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1379  if (failed(argument))
1380  return failure();
1381  arguments.emplace_back(*argument);
1382  } while (consumeIf(Token::comma));
1383  }
1384  popDeclScope();
1385  if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
1386  return failure();
1387 
1388  // Parse the results of the decl.
1389  pushDeclScope();
1390  if (consumeIf(Token::arrow)) {
1391  auto parseResultFn = [&]() -> LogicalResult {
1392  FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1393  if (failed(result))
1394  return failure();
1395  results.emplace_back(*result);
1396  return success();
1397  };
1398 
1399  // Check for a list of results.
1400  if (consumeIf(Token::l_paren)) {
1401  do {
1402  if (failed(parseResultFn()))
1403  return failure();
1404  } while (consumeIf(Token::comma));
1405  if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1406  return failure();
1407 
1408  // Otherwise, there is only one result.
1409  } else if (failed(parseResultFn())) {
1410  return failure();
1411  }
1412  }
1413  popDeclScope();
1414 
1415  // Compute the result type of the decl.
1416  resultType = createUserConstraintRewriteResultType(results);
1417 
1418  // Verify that results are only named if there are more than one.
1419  if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1420  return emitError(
1421  results.front()->getLoc(),
1422  "cannot create a single-element tuple with an element label");
1423  }
1424  return success();
1425 }
1426 
1427 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1428  StringRef declType, ast::CompoundStmt *body,
1431  ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1432  // Handle if a `return` was provided.
1433  if (bodyIt != bodyE) {
1434  // Emit an error if we have trailing statements after the return.
1435  if (std::next(bodyIt) != bodyE) {
1436  return emitError(
1437  (*std::next(bodyIt))->getLoc(),
1438  llvm::formatv("`return` terminated the `{0}` body, but found "
1439  "trailing statements afterwards",
1440  declType));
1441  }
1442 
1443  // Otherwise if a return wasn't provided, check that no results are
1444  // expected.
1445  } else if (!results.empty()) {
1446  return emitError(
1447  {body->getLoc().End, body->getLoc().End},
1448  llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1449  declType, resultType));
1450  }
1451  return success();
1452 }
1453 
1454 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1455  return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1456  if (isa<ast::OpRewriteStmt>(statement))
1457  return success();
1458  return emitError(
1459  statement->getLoc(),
1460  "expected Pattern lambda body to contain a single operation "
1461  "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1462  });
1463 }
1464 
1465 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1466  SMRange loc = curToken.getLoc();
1467  consumeToken(Token::kw_Pattern);
1468  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1469 
1470  // Check for an optional identifier for the pattern name.
1471  const ast::Name *name = nullptr;
1472  if (curToken.is(Token::identifier)) {
1473  name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1474  consumeToken(Token::identifier);
1475  }
1476 
1477  // Parse any pattern metadata.
1478  ParsedPatternMetadata metadata;
1479  if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1480  return failure();
1481 
1482  // Parse the pattern body.
1483  ast::CompoundStmt *body;
1484 
1485  // Handle a lambda body.
1486  if (curToken.is(Token::equal_arrow)) {
1487  FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1488  if (failed(bodyResult))
1489  return failure();
1490  body = *bodyResult;
1491  } else {
1492  if (curToken.isNot(Token::l_brace))
1493  return emitError("expected `{` or `=>` to start pattern body");
1494  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1495  if (failed(bodyResult))
1496  return failure();
1497  body = *bodyResult;
1498 
1499  // Verify the body of the pattern.
1500  auto bodyIt = body->begin(), bodyE = body->end();
1501  for (; bodyIt != bodyE; ++bodyIt) {
1502  if (isa<ast::ReturnStmt>(*bodyIt)) {
1503  return emitError((*bodyIt)->getLoc(),
1504  "`return` statements are only permitted within a "
1505  "`Constraint` or `Rewrite` body");
1506  }
1507  // Break when we've found the rewrite statement.
1508  if (isa<ast::OpRewriteStmt>(*bodyIt))
1509  break;
1510  }
1511  if (bodyIt == bodyE) {
1512  return emitError(loc,
1513  "expected Pattern body to terminate with an operation "
1514  "rewrite statement, such as `erase`");
1515  }
1516  if (std::next(bodyIt) != bodyE) {
1517  return emitError((*std::next(bodyIt))->getLoc(),
1518  "Pattern body was terminated by an operation "
1519  "rewrite statement, but found trailing statements");
1520  }
1521  }
1522 
1523  return createPatternDecl(loc, name, metadata, body);
1524 }
1525 
1526 LogicalResult
1527 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1528  std::optional<SMRange> benefitLoc;
1529  std::optional<SMRange> hasBoundedRecursionLoc;
1530 
1531  do {
1532  // Handle metadata code completion.
1533  if (curToken.is(Token::code_complete))
1534  return codeCompletePatternMetadata();
1535 
1536  if (curToken.isNot(Token::identifier))
1537  return emitError("expected pattern metadata identifier");
1538  StringRef metadataStr = curToken.getSpelling();
1539  SMRange metadataLoc = curToken.getLoc();
1540  consumeToken(Token::identifier);
1541 
1542  // Parse the benefit metadata: benefit(<integer-value>)
1543  if (metadataStr == "benefit") {
1544  if (benefitLoc) {
1545  return emitErrorAndNote(metadataLoc,
1546  "pattern benefit has already been specified",
1547  *benefitLoc, "see previous definition here");
1548  }
1549  if (failed(parseToken(Token::l_paren,
1550  "expected `(` before pattern benefit")))
1551  return failure();
1552 
1553  uint16_t benefitValue = 0;
1554  if (curToken.isNot(Token::integer))
1555  return emitError("expected integral pattern benefit");
1556  if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1557  return emitError(
1558  "expected pattern benefit to fit within a 16-bit integer");
1559  consumeToken(Token::integer);
1560 
1561  metadata.benefit = benefitValue;
1562  benefitLoc = metadataLoc;
1563 
1564  if (failed(
1565  parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1566  return failure();
1567  continue;
1568  }
1569 
1570  // Parse the bounded recursion metadata: recursion
1571  if (metadataStr == "recursion") {
1572  if (hasBoundedRecursionLoc) {
1573  return emitErrorAndNote(
1574  metadataLoc,
1575  "pattern recursion metadata has already been specified",
1576  *hasBoundedRecursionLoc, "see previous definition here");
1577  }
1578  metadata.hasBoundedRecursion = true;
1579  hasBoundedRecursionLoc = metadataLoc;
1580  continue;
1581  }
1582 
1583  return emitError(metadataLoc, "unknown pattern metadata");
1584  } while (consumeIf(Token::comma));
1585 
1586  return success();
1587 }
1588 
1589 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1590  consumeToken(Token::less);
1591 
1592  FailureOr<ast::Expr *> typeExpr = parseExpr();
1593  if (failed(typeExpr) ||
1594  failed(parseToken(Token::greater,
1595  "expected `>` after variable type constraint")))
1596  return failure();
1597  return typeExpr;
1598 }
1599 
1600 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1601  assert(curDeclScope && "defining decl outside of a decl scope");
1602  if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1603  return emitErrorAndNote(
1604  name.getLoc(), "`" + name.getName() + "` has already been defined",
1605  lastDecl->getName()->getLoc(), "see previous definition here");
1606  }
1607  return success();
1608 }
1609 
1610 FailureOr<ast::VariableDecl *>
1611 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1612  ast::Expr *initExpr,
1613  ArrayRef<ast::ConstraintRef> constraints) {
1614  assert(curDeclScope && "defining variable outside of decl scope");
1615  const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1616 
1617  // If the name of the variable indicates a special variable, we don't add it
1618  // to the scope. This variable is local to the definition point.
1619  if (name.empty() || name == "_") {
1620  return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1621  constraints);
1622  }
1623  if (failed(checkDefineNamedDecl(nameDecl)))
1624  return failure();
1625 
1626  auto *varDecl =
1627  ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1628  curDeclScope->add(varDecl);
1629  return varDecl;
1630 }
1631 
1632 FailureOr<ast::VariableDecl *>
1633 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1634  ArrayRef<ast::ConstraintRef> constraints) {
1635  return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1636  constraints);
1637 }
1638 
1639 LogicalResult Parser::parseVariableDeclConstraintList(
1640  SmallVectorImpl<ast::ConstraintRef> &constraints) {
1641  std::optional<SMRange> typeConstraint;
1642  auto parseSingleConstraint = [&] {
1643  FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1644  typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
1645  if (failed(constraint))
1646  return failure();
1647  constraints.push_back(*constraint);
1648  return success();
1649  };
1650 
1651  // Check to see if this is a single constraint, or a list.
1652  if (!consumeIf(Token::l_square))
1653  return parseSingleConstraint();
1654 
1655  do {
1656  if (failed(parseSingleConstraint()))
1657  return failure();
1658  } while (consumeIf(Token::comma));
1659  return parseToken(Token::r_square, "expected `]` after constraint list");
1660 }
1661 
1662 FailureOr<ast::ConstraintRef>
1663 Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
1664  ArrayRef<ast::ConstraintRef> existingConstraints,
1665  bool allowInlineTypeConstraints) {
1666  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1667  if (!allowInlineTypeConstraints) {
1668  return emitError(
1669  curToken.getLoc(),
1670  "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1671  "permitted on arguments or results");
1672  }
1673  if (typeConstraint)
1674  return emitErrorAndNote(
1675  curToken.getLoc(),
1676  "the type of this variable has already been constrained",
1677  *typeConstraint, "see previous constraint location here");
1678  FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1679  if (failed(constraintExpr))
1680  return failure();
1681  typeExpr = *constraintExpr;
1682  typeConstraint = typeExpr->getLoc();
1683  return success();
1684  };
1685 
1686  SMRange loc = curToken.getLoc();
1687  switch (curToken.getKind()) {
1688  case Token::kw_Attr: {
1689  consumeToken(Token::kw_Attr);
1690 
1691  // Check for a type constraint.
1692  ast::Expr *typeExpr = nullptr;
1693  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1694  return failure();
1695  return ast::ConstraintRef(
1696  ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1697  }
1698  case Token::kw_Op: {
1699  consumeToken(Token::kw_Op);
1700 
1701  // Parse an optional operation name. If the name isn't provided, this refers
1702  // to "any" operation.
1703  FailureOr<ast::OpNameDecl *> opName =
1704  parseWrappedOperationName(/*allowEmptyName=*/true);
1705  if (failed(opName))
1706  return failure();
1707 
1708  return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
1709  loc);
1710  }
1711  case Token::kw_Type:
1712  consumeToken(Token::kw_Type);
1713  return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1714  case Token::kw_TypeRange:
1715  consumeToken(Token::kw_TypeRange);
1717  loc);
1718  case Token::kw_Value: {
1719  consumeToken(Token::kw_Value);
1720 
1721  // Check for a type constraint.
1722  ast::Expr *typeExpr = nullptr;
1723  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1724  return failure();
1725 
1726  return ast::ConstraintRef(
1727  ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1728  }
1729  case Token::kw_ValueRange: {
1730  consumeToken(Token::kw_ValueRange);
1731 
1732  // Check for a type constraint.
1733  ast::Expr *typeExpr = nullptr;
1734  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1735  return failure();
1736 
1737  return ast::ConstraintRef(
1738  ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1739  }
1740 
1741  case Token::kw_Constraint: {
1742  // Handle an inline constraint.
1743  FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1744  if (failed(decl))
1745  return failure();
1746  return ast::ConstraintRef(*decl, loc);
1747  }
1748  case Token::identifier: {
1749  StringRef constraintName = curToken.getSpelling();
1750  consumeToken(Token::identifier);
1751 
1752  // Lookup the referenced constraint.
1753  ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
1754  if (!cstDecl) {
1755  return emitError(loc, "unknown reference to constraint `" +
1756  constraintName + "`");
1757  }
1758 
1759  // Handle a reference to a proper constraint.
1760  if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1761  return ast::ConstraintRef(cst, loc);
1762 
1763  return emitErrorAndNote(
1764  loc, "invalid reference to non-constraint", cstDecl->getLoc(),
1765  "see the definition of `" + constraintName + "` here");
1766  }
1767  // Handle single entity constraint code completion.
1768  case Token::code_complete: {
1769  // Try to infer the current type for use by code completion.
1770  ast::Type inferredType;
1771  if (failed(validateVariableConstraints(existingConstraints, inferredType)))
1772  return failure();
1773 
1774  return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
1775  }
1776  default:
1777  break;
1778  }
1779  return emitError(loc, "expected identifier constraint");
1780 }
1781 
1782 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1783  std::optional<SMRange> typeConstraint;
1784  return parseConstraint(typeConstraint, /*existingConstraints=*/{},
1785  /*allowInlineTypeConstraints=*/false);
1786 }
1787 
1788 //===----------------------------------------------------------------------===//
1789 // Exprs
1790 //===----------------------------------------------------------------------===//
1791 
1792 FailureOr<ast::Expr *> Parser::parseExpr() {
1793  if (curToken.is(Token::underscore))
1794  return parseUnderscoreExpr();
1795 
1796  // Parse the LHS expression.
1797  FailureOr<ast::Expr *> lhsExpr;
1798  switch (curToken.getKind()) {
1799  case Token::kw_attr:
1800  lhsExpr = parseAttributeExpr();
1801  break;
1802  case Token::kw_Constraint:
1803  lhsExpr = parseInlineConstraintLambdaExpr();
1804  break;
1805  case Token::kw_not:
1806  lhsExpr = parseNegatedExpr();
1807  break;
1808  case Token::identifier:
1809  lhsExpr = parseIdentifierExpr();
1810  break;
1811  case Token::kw_op:
1812  lhsExpr = parseOperationExpr();
1813  break;
1814  case Token::kw_Rewrite:
1815  lhsExpr = parseInlineRewriteLambdaExpr();
1816  break;
1817  case Token::kw_type:
1818  lhsExpr = parseTypeExpr();
1819  break;
1820  case Token::l_paren:
1821  lhsExpr = parseTupleExpr();
1822  break;
1823  default:
1824  return emitError("expected expression");
1825  }
1826  if (failed(lhsExpr))
1827  return failure();
1828 
1829  // Check for an operator expression.
1830  while (true) {
1831  switch (curToken.getKind()) {
1832  case Token::dot:
1833  lhsExpr = parseMemberAccessExpr(*lhsExpr);
1834  break;
1835  case Token::l_paren:
1836  lhsExpr = parseCallExpr(*lhsExpr);
1837  break;
1838  default:
1839  return lhsExpr;
1840  }
1841  if (failed(lhsExpr))
1842  return failure();
1843  }
1844 }
1845 
1846 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1847  SMRange loc = curToken.getLoc();
1848  consumeToken(Token::kw_attr);
1849 
1850  // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1851  // identifier.
1852  if (!consumeIf(Token::less)) {
1853  resetToken(loc);
1854  return parseIdentifierExpr();
1855  }
1856 
1857  if (!curToken.isString())
1858  return emitError("expected string literal containing MLIR attribute");
1859  std::string attrExpr = curToken.getStringValue();
1860  consumeToken();
1861 
1862  loc.End = curToken.getEndLoc();
1863  if (failed(
1864  parseToken(Token::greater, "expected `>` after attribute literal")))
1865  return failure();
1866  return ast::AttributeExpr::create(ctx, loc, attrExpr);
1867 }
1868 
1869 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr,
1870  bool isNegated) {
1871  consumeToken(Token::l_paren);
1872 
1873  // Parse the arguments of the call.
1874  SmallVector<ast::Expr *> arguments;
1875  if (curToken.isNot(Token::r_paren)) {
1876  do {
1877  // Handle code completion for the call arguments.
1878  if (curToken.is(Token::code_complete)) {
1879  codeCompleteCallSignature(parentExpr, arguments.size());
1880  return failure();
1881  }
1882 
1883  FailureOr<ast::Expr *> argument = parseExpr();
1884  if (failed(argument))
1885  return failure();
1886  arguments.push_back(*argument);
1887  } while (consumeIf(Token::comma));
1888  }
1889 
1890  SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1891  if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1892  return failure();
1893 
1894  return createCallExpr(loc, parentExpr, arguments, isNegated);
1895 }
1896 
1897 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1898  ast::Decl *decl = curDeclScope->lookup(name);
1899  if (!decl)
1900  return emitError(loc, "undefined reference to `" + name + "`");
1901 
1902  return createDeclRefExpr(loc, decl);
1903 }
1904 
1905 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1906  StringRef name = curToken.getSpelling();
1907  SMRange nameLoc = curToken.getLoc();
1908  consumeToken();
1909 
1910  // Check to see if this is a decl ref expression that defines a variable
1911  // inline.
1912  if (consumeIf(Token::colon)) {
1913  SmallVector<ast::ConstraintRef> constraints;
1914  if (failed(parseVariableDeclConstraintList(constraints)))
1915  return failure();
1916  ast::Type type;
1917  if (failed(validateVariableConstraints(constraints, type)))
1918  return failure();
1919  return createInlineVariableExpr(type, name, nameLoc, constraints);
1920  }
1921 
1922  return parseDeclRefExpr(name, nameLoc);
1923 }
1924 
1925 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1926  FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1927  if (failed(decl))
1928  return failure();
1929 
1930  return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1932 }
1933 
1934 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1935  FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1936  if (failed(decl))
1937  return failure();
1938 
1939  return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1940  ast::RewriteType::get(ctx));
1941 }
1942 
1943 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1944  SMRange dotLoc = curToken.getLoc();
1945  consumeToken(Token::dot);
1946 
1947  // Check for code completion of the member name.
1948  if (curToken.is(Token::code_complete))
1949  return codeCompleteMemberAccess(parentExpr);
1950 
1951  // Parse the member name.
1952  Token memberNameTok = curToken;
1953  if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1954  !memberNameTok.isKeyword())
1955  return emitError(dotLoc, "expected identifier or numeric member name");
1956  StringRef memberName = memberNameTok.getSpelling();
1957  SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1958  consumeToken();
1959 
1960  return createMemberAccessExpr(parentExpr, memberName, loc);
1961 }
1962 
1963 FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
1964  consumeToken(Token::kw_not);
1965  // Only native constraints are supported after negation
1966  if (!curToken.is(Token::identifier))
1967  return emitError("expected native constraint");
1968  FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
1969  if (failed(identifierExpr))
1970  return failure();
1971  if (!curToken.is(Token::l_paren))
1972  return emitError("expected `(` after function name");
1973  return parseCallExpr(*identifierExpr, /*isNegated = */ true);
1974 }
1975 
1976 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
1977  SMRange loc = curToken.getLoc();
1978 
1979  // Check for code completion for the dialect name.
1980  if (curToken.is(Token::code_complete))
1981  return codeCompleteDialectName();
1982 
1983  // Handle the case of an no operation name.
1984  if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
1985  if (allowEmptyName)
1986  return ast::OpNameDecl::create(ctx, SMRange());
1987  return emitError("expected dialect namespace");
1988  }
1989  StringRef name = curToken.getSpelling();
1990  consumeToken();
1991 
1992  // Otherwise, this is a literal operation name.
1993  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
1994  return failure();
1995 
1996  // Check for code completion for the operation name.
1997  if (curToken.is(Token::code_complete))
1998  return codeCompleteOperationName(name);
1999 
2000  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
2001  return emitError("expected operation name after dialect namespace");
2002 
2003  name = StringRef(name.data(), name.size() + 1);
2004  do {
2005  name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
2006  loc.End = curToken.getEndLoc();
2007  consumeToken();
2008  } while (curToken.isAny(Token::identifier, Token::dot) ||
2009  curToken.isKeyword());
2010  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
2011 }
2012 
2013 FailureOr<ast::OpNameDecl *>
2014 Parser::parseWrappedOperationName(bool allowEmptyName) {
2015  if (!consumeIf(Token::less))
2016  return ast::OpNameDecl::create(ctx, SMRange());
2017 
2018  FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
2019  if (failed(opNameDecl))
2020  return failure();
2021 
2022  if (failed(parseToken(Token::greater, "expected `>` after operation name")))
2023  return failure();
2024  return opNameDecl;
2025 }
2026 
2027 FailureOr<ast::Expr *>
2028 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
2029  SMRange loc = curToken.getLoc();
2030  consumeToken(Token::kw_op);
2031 
2032  // If it isn't followed by a `<`, the `op` keyword is treated as a normal
2033  // identifier.
2034  if (curToken.isNot(Token::less)) {
2035  resetToken(loc);
2036  return parseIdentifierExpr();
2037  }
2038 
2039  // Parse the operation name. The name may be elided, in which case the
2040  // operation refers to "any" operation(i.e. a difference between `MyOp` and
2041  // `Operation*`). Operation names within a rewrite context must be named.
2042  bool allowEmptyName = parserContext != ParserContext::Rewrite;
2043  FailureOr<ast::OpNameDecl *> opNameDecl =
2044  parseWrappedOperationName(allowEmptyName);
2045  if (failed(opNameDecl))
2046  return failure();
2047  std::optional<StringRef> opName = (*opNameDecl)->getName();
2048 
2049  // Functor used to create an implicit range variable, used for implicit "all"
2050  // operand or results variables.
2051  auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
2052  FailureOr<ast::VariableDecl *> rangeVar =
2053  defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
2054  assert(succeeded(rangeVar) && "expected range variable to be valid");
2055  return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
2056  };
2057 
2058  // Check for the optional list of operands.
2059  SmallVector<ast::Expr *> operands;
2060  if (!consumeIf(Token::l_paren)) {
2061  // If the operand list isn't specified and we are in a match context, define
2062  // an inplace unconstrained operand range corresponding to all of the
2063  // operands of the operation. This avoids treating zero operands the same
2064  // way as "unconstrained operands".
2065  if (parserContext != ParserContext::Rewrite) {
2066  operands.push_back(createImplicitRangeVar(
2067  ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
2068  }
2069  } else if (!consumeIf(Token::r_paren)) {
2070  // If the operand list was specified and non-empty, parse the operands.
2071  do {
2072  // Check for operand signature code completion.
2073  if (curToken.is(Token::code_complete)) {
2074  codeCompleteOperationOperandsSignature(opName, operands.size());
2075  return failure();
2076  }
2077 
2078  FailureOr<ast::Expr *> operand = parseExpr();
2079  if (failed(operand))
2080  return failure();
2081  operands.push_back(*operand);
2082  } while (consumeIf(Token::comma));
2083 
2084  if (failed(parseToken(Token::r_paren,
2085  "expected `)` after operation operand list")))
2086  return failure();
2087  }
2088 
2089  // Check for the optional list of attributes.
2091  if (consumeIf(Token::l_brace)) {
2092  do {
2093  FailureOr<ast::NamedAttributeDecl *> decl =
2094  parseNamedAttributeDecl(opName);
2095  if (failed(decl))
2096  return failure();
2097  attributes.emplace_back(*decl);
2098  } while (consumeIf(Token::comma));
2099 
2100  if (failed(parseToken(Token::r_brace,
2101  "expected `}` after operation attribute list")))
2102  return failure();
2103  }
2104 
2105  // Handle the result types of the operation.
2106  SmallVector<ast::Expr *> resultTypes;
2107  OpResultTypeContext resultTypeContext = inputResultTypeContext;
2108 
2109  // Check for an explicit list of result types.
2110  if (consumeIf(Token::arrow)) {
2111  if (failed(parseToken(Token::l_paren,
2112  "expected `(` before operation result type list")))
2113  return failure();
2114 
2115  // If result types are provided, initially assume that the operation does
2116  // not rely on type inferrence. We don't assert that it isn't, because we
2117  // may be inferring the value of some type/type range variables, but given
2118  // that these variables may be defined in calls we can't always discern when
2119  // this is the case.
2120  resultTypeContext = OpResultTypeContext::Explicit;
2121 
2122  // Handle the case of an empty result list.
2123  if (!consumeIf(Token::r_paren)) {
2124  do {
2125  // Check for result signature code completion.
2126  if (curToken.is(Token::code_complete)) {
2127  codeCompleteOperationResultsSignature(opName, resultTypes.size());
2128  return failure();
2129  }
2130 
2131  FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2132  if (failed(resultTypeExpr))
2133  return failure();
2134  resultTypes.push_back(*resultTypeExpr);
2135  } while (consumeIf(Token::comma));
2136 
2137  if (failed(parseToken(Token::r_paren,
2138  "expected `)` after operation result type list")))
2139  return failure();
2140  }
2141  } else if (parserContext != ParserContext::Rewrite) {
2142  // If the result list isn't specified and we are in a match context, define
2143  // an inplace unconstrained result range corresponding to all of the results
2144  // of the operation. This avoids treating zero results the same way as
2145  // "unconstrained results".
2146  resultTypes.push_back(createImplicitRangeVar(
2147  ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
2148  } else if (resultTypeContext == OpResultTypeContext::Explicit) {
2149  // If the result list isn't specified and we are in a rewrite, try to infer
2150  // them at runtime instead.
2151  resultTypeContext = OpResultTypeContext::Interface;
2152  }
2153 
2154  return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2155  attributes, resultTypes);
2156 }
2157 
2158 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2159  SMRange loc = curToken.getLoc();
2160  consumeToken(Token::l_paren);
2161 
2162  DenseMap<StringRef, SMRange> usedNames;
2163  SmallVector<StringRef> elementNames;
2164  SmallVector<ast::Expr *> elements;
2165  if (curToken.isNot(Token::r_paren)) {
2166  do {
2167  // Check for the optional element name assignment before the value.
2168  StringRef elementName;
2169  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
2170  Token elementNameTok = curToken;
2171  consumeToken();
2172 
2173  // The element name is only present if followed by an `=`.
2174  if (consumeIf(Token::equal)) {
2175  elementName = elementNameTok.getSpelling();
2176 
2177  // Check to see if this name is already used.
2178  auto elementNameIt =
2179  usedNames.try_emplace(elementName, elementNameTok.getLoc());
2180  if (!elementNameIt.second) {
2181  return emitErrorAndNote(
2182  elementNameTok.getLoc(),
2183  llvm::formatv("duplicate tuple element label `{0}`",
2184  elementName),
2185  elementNameIt.first->getSecond(),
2186  "see previous label use here");
2187  }
2188  } else {
2189  // Otherwise, we treat this as part of an expression so reset the
2190  // lexer.
2191  resetToken(elementNameTok.getLoc());
2192  }
2193  }
2194  elementNames.push_back(elementName);
2195 
2196  // Parse the tuple element value.
2197  FailureOr<ast::Expr *> element = parseExpr();
2198  if (failed(element))
2199  return failure();
2200  elements.push_back(*element);
2201  } while (consumeIf(Token::comma));
2202  }
2203  loc.End = curToken.getEndLoc();
2204  if (failed(
2205  parseToken(Token::r_paren, "expected `)` after tuple element list")))
2206  return failure();
2207  return createTupleExpr(loc, elements, elementNames);
2208 }
2209 
2210 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2211  SMRange loc = curToken.getLoc();
2212  consumeToken(Token::kw_type);
2213 
2214  // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2215  // identifier.
2216  if (!consumeIf(Token::less)) {
2217  resetToken(loc);
2218  return parseIdentifierExpr();
2219  }
2220 
2221  if (!curToken.isString())
2222  return emitError("expected string literal containing MLIR type");
2223  std::string attrExpr = curToken.getStringValue();
2224  consumeToken();
2225 
2226  loc.End = curToken.getEndLoc();
2227  if (failed(parseToken(Token::greater, "expected `>` after type literal")))
2228  return failure();
2229  return ast::TypeExpr::create(ctx, loc, attrExpr);
2230 }
2231 
2232 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2233  StringRef name = curToken.getSpelling();
2234  SMRange nameLoc = curToken.getLoc();
2235  consumeToken(Token::underscore);
2236 
2237  // Underscore expressions require a constraint list.
2238  if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2239  return failure();
2240 
2241  // Parse the constraints for the expression.
2242  SmallVector<ast::ConstraintRef> constraints;
2243  if (failed(parseVariableDeclConstraintList(constraints)))
2244  return failure();
2245 
2246  ast::Type type;
2247  if (failed(validateVariableConstraints(constraints, type)))
2248  return failure();
2249  return createInlineVariableExpr(type, name, nameLoc, constraints);
2250 }
2251 
2252 //===----------------------------------------------------------------------===//
2253 // Stmts
2254 //===----------------------------------------------------------------------===//
2255 
2256 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2257  FailureOr<ast::Stmt *> stmt;
2258  switch (curToken.getKind()) {
2259  case Token::kw_erase:
2260  stmt = parseEraseStmt();
2261  break;
2262  case Token::kw_let:
2263  stmt = parseLetStmt();
2264  break;
2265  case Token::kw_replace:
2266  stmt = parseReplaceStmt();
2267  break;
2268  case Token::kw_return:
2269  stmt = parseReturnStmt();
2270  break;
2271  case Token::kw_rewrite:
2272  stmt = parseRewriteStmt();
2273  break;
2274  default:
2275  stmt = parseExpr();
2276  break;
2277  }
2278  if (failed(stmt) ||
2279  (expectTerminalSemicolon &&
2280  failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2281  return failure();
2282  return stmt;
2283 }
2284 
2285 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2286  SMLoc startLoc = curToken.getStartLoc();
2287  consumeToken(Token::l_brace);
2288 
2289  // Push a new block scope and parse any nested statements.
2290  pushDeclScope();
2291  SmallVector<ast::Stmt *> statements;
2292  while (curToken.isNot(Token::r_brace)) {
2293  FailureOr<ast::Stmt *> statement = parseStmt();
2294  if (failed(statement))
2295  return popDeclScope(), failure();
2296  statements.push_back(*statement);
2297  }
2298  popDeclScope();
2299 
2300  // Consume the end brace.
2301  SMRange location(startLoc, curToken.getEndLoc());
2302  consumeToken(Token::r_brace);
2303 
2304  return ast::CompoundStmt::create(ctx, location, statements);
2305 }
2306 
2307 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2308  if (parserContext == ParserContext::Constraint)
2309  return emitError("`erase` cannot be used within a Constraint");
2310  SMRange loc = curToken.getLoc();
2311  consumeToken(Token::kw_erase);
2312 
2313  // Parse the root operation expression.
2314  FailureOr<ast::Expr *> rootOp = parseExpr();
2315  if (failed(rootOp))
2316  return failure();
2317 
2318  return createEraseStmt(loc, *rootOp);
2319 }
2320 
2321 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2322  SMRange loc = curToken.getLoc();
2323  consumeToken(Token::kw_let);
2324 
2325  // Parse the name of the new variable.
2326  SMRange varLoc = curToken.getLoc();
2327  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
2328  // `_` is a reserved variable name.
2329  if (curToken.is(Token::underscore)) {
2330  return emitError(varLoc,
2331  "`_` may only be used to define \"inline\" variables");
2332  }
2333  return emitError(varLoc,
2334  "expected identifier after `let` to name a new variable");
2335  }
2336  StringRef varName = curToken.getSpelling();
2337  consumeToken();
2338 
2339  // Parse the optional set of constraints.
2340  SmallVector<ast::ConstraintRef> constraints;
2341  if (consumeIf(Token::colon) &&
2342  failed(parseVariableDeclConstraintList(constraints)))
2343  return failure();
2344 
2345  // Parse the optional initializer expression.
2346  ast::Expr *initializer = nullptr;
2347  if (consumeIf(Token::equal)) {
2348  FailureOr<ast::Expr *> initOrFailure = parseExpr();
2349  if (failed(initOrFailure))
2350  return failure();
2351  initializer = *initOrFailure;
2352 
2353  // Check that the constraints are compatible with having an initializer,
2354  // e.g. type constraints cannot be used with initializers.
2355  for (ast::ConstraintRef constraint : constraints) {
2356  LogicalResult result =
2357  TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
2359  ast::ValueRangeConstraintDecl>([&](const auto *cst) {
2360  if (cst->getTypeExpr()) {
2361  return this->emitError(
2362  constraint.referenceLoc,
2363  "type constraints are not permitted on variables with "
2364  "initializers");
2365  }
2366  return success();
2367  })
2368  .Default(success());
2369  if (failed(result))
2370  return failure();
2371  }
2372  }
2373 
2374  FailureOr<ast::VariableDecl *> varDecl =
2375  createVariableDecl(varName, varLoc, initializer, constraints);
2376  if (failed(varDecl))
2377  return failure();
2378  return ast::LetStmt::create(ctx, loc, *varDecl);
2379 }
2380 
2381 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2382  if (parserContext == ParserContext::Constraint)
2383  return emitError("`replace` cannot be used within a Constraint");
2384  SMRange loc = curToken.getLoc();
2385  consumeToken(Token::kw_replace);
2386 
2387  // Parse the root operation expression.
2388  FailureOr<ast::Expr *> rootOp = parseExpr();
2389  if (failed(rootOp))
2390  return failure();
2391 
2392  if (failed(
2393  parseToken(Token::kw_with, "expected `with` after root operation")))
2394  return failure();
2395 
2396  // The replacement portion of this statement is within a rewrite context.
2397  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2398 
2399  // Parse the replacement values.
2400  SmallVector<ast::Expr *> replValues;
2401  if (consumeIf(Token::l_paren)) {
2402  if (consumeIf(Token::r_paren)) {
2403  return emitError(
2404  loc, "expected at least one replacement value, consider using "
2405  "`erase` if no replacement values are desired");
2406  }
2407 
2408  do {
2409  FailureOr<ast::Expr *> replExpr = parseExpr();
2410  if (failed(replExpr))
2411  return failure();
2412  replValues.emplace_back(*replExpr);
2413  } while (consumeIf(Token::comma));
2414 
2415  if (failed(parseToken(Token::r_paren,
2416  "expected `)` after replacement values")))
2417  return failure();
2418  } else {
2419  // Handle replacement with an operation uniquely, as the replacement
2420  // operation supports type inferrence from the root operation.
2421  FailureOr<ast::Expr *> replExpr;
2422  if (curToken.is(Token::kw_op))
2423  replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2424  else
2425  replExpr = parseExpr();
2426  if (failed(replExpr))
2427  return failure();
2428  replValues.emplace_back(*replExpr);
2429  }
2430 
2431  return createReplaceStmt(loc, *rootOp, replValues);
2432 }
2433 
2434 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2435  SMRange loc = curToken.getLoc();
2436  consumeToken(Token::kw_return);
2437 
2438  // Parse the result value.
2439  FailureOr<ast::Expr *> resultExpr = parseExpr();
2440  if (failed(resultExpr))
2441  return failure();
2442 
2443  return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2444 }
2445 
2446 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2447  if (parserContext == ParserContext::Constraint)
2448  return emitError("`rewrite` cannot be used within a Constraint");
2449  SMRange loc = curToken.getLoc();
2450  consumeToken(Token::kw_rewrite);
2451 
2452  // Parse the root operation.
2453  FailureOr<ast::Expr *> rootOp = parseExpr();
2454  if (failed(rootOp))
2455  return failure();
2456 
2457  if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2458  return failure();
2459 
2460  if (curToken.isNot(Token::l_brace))
2461  return emitError("expected `{` to start rewrite body");
2462 
2463  // The rewrite body of this statement is within a rewrite context.
2464  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2465 
2466  FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2467  if (failed(rewriteBody))
2468  return failure();
2469 
2470  // Verify the rewrite body.
2471  for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2472  if (isa<ast::ReturnStmt>(stmt)) {
2473  return emitError(stmt->getLoc(),
2474  "`return` statements are only permitted within a "
2475  "`Constraint` or `Rewrite` body");
2476  }
2477  }
2478 
2479  return createRewriteStmt(loc, *rootOp, *rewriteBody);
2480 }
2481 
2482 //===----------------------------------------------------------------------===//
2483 // Creation+Analysis
2484 //===----------------------------------------------------------------------===//
2485 
2486 //===----------------------------------------------------------------------===//
2487 // Decls
2488 //===----------------------------------------------------------------------===//
2489 
2490 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2491  // Unwrap reference expressions.
2492  if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2493  node = init->getDecl();
2494  return dyn_cast<ast::CallableDecl>(node);
2495 }
2496 
2497 FailureOr<ast::PatternDecl *>
2498 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2499  const ParsedPatternMetadata &metadata,
2500  ast::CompoundStmt *body) {
2501  return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2502  metadata.hasBoundedRecursion, body);
2503 }
2504 
2505 ast::Type Parser::createUserConstraintRewriteResultType(
2507  // Single result decls use the type of the single result.
2508  if (results.size() == 1)
2509  return results[0]->getType();
2510 
2511  // Multiple results use a tuple type, with the types and names grabbed from
2512  // the result variable decls.
2513  auto resultTypes = llvm::map_range(
2514  results, [&](const auto *result) { return result->getType(); });
2515  auto resultNames = llvm::map_range(
2516  results, [&](const auto *result) { return result->getName().getName(); });
2517  return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2518  llvm::to_vector(resultNames));
2519 }
2520 
2521 template <typename T>
2522 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2523  const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2524  ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2525  ast::CompoundStmt *body) {
2526  if (!body->getChildren().empty()) {
2527  if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2528  ast::Expr *resultExpr = retStmt->getResultExpr();
2529 
2530  // Process the result of the decl. If no explicit signature results
2531  // were provided, check for return type inference. Otherwise, check that
2532  // the return expression can be converted to the expected type.
2533  if (results.empty())
2534  resultType = resultExpr->getType();
2535  else if (failed(convertExpressionTo(resultExpr, resultType)))
2536  return failure();
2537  else
2538  retStmt->setResultExpr(resultExpr);
2539  }
2540  }
2541  return T::createPDLL(ctx, name, arguments, results, body, resultType);
2542 }
2543 
2544 FailureOr<ast::VariableDecl *>
2545 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2546  ArrayRef<ast::ConstraintRef> constraints) {
2547  // The type of the variable, which is expected to be inferred by either a
2548  // constraint or an initializer expression.
2549  ast::Type type;
2550  if (failed(validateVariableConstraints(constraints, type)))
2551  return failure();
2552 
2553  if (initializer) {
2554  // Update the variable type based on the initializer, or try to convert the
2555  // initializer to the existing type.
2556  if (!type)
2557  type = initializer->getType();
2558  else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2559  type = mergedType;
2560  else if (failed(convertExpressionTo(initializer, type)))
2561  return failure();
2562 
2563  // Otherwise, if there is no initializer check that the type has already
2564  // been resolved from the constraint list.
2565  } else if (!type) {
2566  return emitErrorAndNote(
2567  loc, "unable to infer type for variable `" + name + "`", loc,
2568  "the type of a variable must be inferable from the constraint "
2569  "list or the initializer");
2570  }
2571 
2572  // Constraint types cannot be used when defining variables.
2573  if (isa<ast::ConstraintType, ast::RewriteType>(type)) {
2574  return emitError(
2575  loc, llvm::formatv("unable to define variable of `{0}` type", type));
2576  }
2577 
2578  // Try to define a variable with the given name.
2579  FailureOr<ast::VariableDecl *> varDecl =
2580  defineVariableDecl(name, loc, type, initializer, constraints);
2581  if (failed(varDecl))
2582  return failure();
2583 
2584  return *varDecl;
2585 }
2586 
2587 FailureOr<ast::VariableDecl *>
2588 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2589  const ast::ConstraintRef &constraint) {
2590  ast::Type argType;
2591  if (failed(validateVariableConstraint(constraint, argType)))
2592  return failure();
2593  return defineVariableDecl(name, loc, argType, constraint);
2594 }
2595 
2596 LogicalResult
2597 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2598  ast::Type &inferredType) {
2599  for (const ast::ConstraintRef &ref : constraints)
2600  if (failed(validateVariableConstraint(ref, inferredType)))
2601  return failure();
2602  return success();
2603 }
2604 
2605 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2606  ast::Type &inferredType) {
2607  ast::Type constraintType;
2608  if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
2609  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2610  if (failed(validateTypeConstraintExpr(typeExpr)))
2611  return failure();
2612  }
2613  constraintType = ast::AttributeType::get(ctx);
2614  } else if (const auto *cst =
2615  dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
2616  constraintType = ast::OperationType::get(
2617  ctx, cst->getName(), lookupODSOperation(cst->getName()));
2618  } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
2619  constraintType = typeTy;
2620  } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
2621  constraintType = typeRangeTy;
2622  } else if (const auto *cst =
2623  dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
2624  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2625  if (failed(validateTypeConstraintExpr(typeExpr)))
2626  return failure();
2627  }
2628  constraintType = valueTy;
2629  } else if (const auto *cst =
2630  dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
2631  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2632  if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2633  return failure();
2634  }
2635  constraintType = valueRangeTy;
2636  } else if (const auto *cst =
2637  dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
2638  ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2639  if (inputs.size() != 1) {
2640  return emitErrorAndNote(ref.referenceLoc,
2641  "`Constraint`s applied via a variable constraint "
2642  "list must take a single input, but got " +
2643  Twine(inputs.size()),
2644  cst->getLoc(),
2645  "see definition of constraint here");
2646  }
2647  constraintType = inputs.front()->getType();
2648  } else {
2649  llvm_unreachable("unknown constraint type");
2650  }
2651 
2652  // Check that the constraint type is compatible with the current inferred
2653  // type.
2654  if (!inferredType) {
2655  inferredType = constraintType;
2656  } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
2657  inferredType = mergedTy;
2658  } else {
2659  return emitError(ref.referenceLoc,
2660  llvm::formatv("constraint type `{0}` is incompatible "
2661  "with the previously inferred type `{1}`",
2662  constraintType, inferredType));
2663  }
2664  return success();
2665 }
2666 
2667 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2668  ast::Type typeExprType = typeExpr->getType();
2669  if (typeExprType != typeTy) {
2670  return emitError(typeExpr->getLoc(),
2671  "expected expression of `Type` in type constraint");
2672  }
2673  return success();
2674 }
2675 
2676 LogicalResult
2677 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2678  ast::Type typeExprType = typeExpr->getType();
2679  if (typeExprType != typeRangeTy) {
2680  return emitError(typeExpr->getLoc(),
2681  "expected expression of `TypeRange` in type constraint");
2682  }
2683  return success();
2684 }
2685 
2686 //===----------------------------------------------------------------------===//
2687 // Exprs
2688 //===----------------------------------------------------------------------===//
2689 
2690 FailureOr<ast::CallExpr *>
2691 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2692  MutableArrayRef<ast::Expr *> arguments, bool isNegated) {
2693  ast::Type parentType = parentExpr->getType();
2694 
2695  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2696  if (!callableDecl) {
2697  return emitError(loc,
2698  llvm::formatv("expected a reference to a callable "
2699  "`Constraint` or `Rewrite`, but got: `{0}`",
2700  parentType));
2701  }
2702  if (parserContext == ParserContext::Rewrite) {
2703  if (isa<ast::UserConstraintDecl>(callableDecl))
2704  return emitError(
2705  loc, "unable to invoke `Constraint` within a rewrite section");
2706  if (isNegated)
2707  return emitError(loc, "unable to negate a Rewrite");
2708  } else {
2709  if (isa<ast::UserRewriteDecl>(callableDecl))
2710  return emitError(loc,
2711  "unable to invoke `Rewrite` within a match section");
2712  if (isNegated && cast<ast::UserConstraintDecl>(callableDecl)->getBody())
2713  return emitError(loc, "unable to negate non native constraints");
2714  }
2715 
2716  // Verify the arguments of the call.
2717  /// Handle size mismatch.
2718  ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2719  if (callArgs.size() != arguments.size()) {
2720  return emitErrorAndNote(
2721  loc,
2722  llvm::formatv("invalid number of arguments for {0} call; expected "
2723  "{1}, but got {2}",
2724  callableDecl->getCallableType(), callArgs.size(),
2725  arguments.size()),
2726  callableDecl->getLoc(),
2727  llvm::formatv("see the definition of {0} here",
2728  callableDecl->getName()->getName()));
2729  }
2730 
2731  /// Handle argument type mismatch.
2732  auto attachDiagFn = [&](ast::Diagnostic &diag) {
2733  diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2734  callableDecl->getName()->getName()),
2735  callableDecl->getLoc());
2736  };
2737  for (auto it : llvm::zip(callArgs, arguments)) {
2738  if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2739  attachDiagFn)))
2740  return failure();
2741  }
2742 
2743  return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2744  callableDecl->getResultType(), isNegated);
2745 }
2746 
2747 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2748  ast::Decl *decl) {
2749  // Check the type of decl being referenced.
2750  ast::Type declType;
2751  if (isa<ast::ConstraintDecl>(decl))
2752  declType = ast::ConstraintType::get(ctx);
2753  else if (isa<ast::UserRewriteDecl>(decl))
2754  declType = ast::RewriteType::get(ctx);
2755  else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2756  declType = varDecl->getType();
2757  else
2758  return emitError(loc, "invalid reference to `" +
2759  decl->getName()->getName() + "`");
2760 
2761  return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2762 }
2763 
2764 FailureOr<ast::DeclRefExpr *>
2765 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2766  ArrayRef<ast::ConstraintRef> constraints) {
2767  FailureOr<ast::VariableDecl *> decl =
2768  defineVariableDecl(name, loc, type, constraints);
2769  if (failed(decl))
2770  return failure();
2771  return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2772 }
2773 
2774 FailureOr<ast::MemberAccessExpr *>
2775 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2776  SMRange loc) {
2777  // Validate the member name for the given parent expression.
2778  FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2779  if (failed(memberType))
2780  return failure();
2781 
2782  return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2783 }
2784 
2785 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2786  StringRef name, SMRange loc) {
2787  ast::Type parentType = parentExpr->getType();
2788  if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) {
2790  return valueRangeTy;
2791 
2792  // Verify member access based on the operation type.
2793  if (const ods::Operation *odsOp = opType.getODSOperation()) {
2794  auto results = odsOp->getResults();
2795 
2796  // Handle indexed results.
2797  unsigned index = 0;
2798  if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2799  index < results.size()) {
2800  return results[index].isVariadic() ? valueRangeTy : valueTy;
2801  }
2802 
2803  // Handle named results.
2804  const auto *it = llvm::find_if(results, [&](const auto &result) {
2805  return result.getName() == name;
2806  });
2807  if (it != results.end())
2808  return it->isVariadic() ? valueRangeTy : valueTy;
2809  } else if (llvm::isDigit(name[0])) {
2810  // Allow unchecked numeric indexing of the results of unregistered
2811  // operations. It returns a single value.
2812  return valueTy;
2813  }
2814  } else if (auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
2815  // Handle indexed results.
2816  unsigned index = 0;
2817  if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2818  index < tupleType.size()) {
2819  return tupleType.getElementTypes()[index];
2820  }
2821 
2822  // Handle named results.
2823  auto elementNames = tupleType.getElementNames();
2824  const auto *it = llvm::find(elementNames, name);
2825  if (it != elementNames.end())
2826  return tupleType.getElementTypes()[it - elementNames.begin()];
2827  }
2828  return emitError(
2829  loc,
2830  llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2831  name, parentType));
2832 }
2833 
2834 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2835  SMRange loc, const ast::OpNameDecl *name,
2836  OpResultTypeContext resultTypeContext,
2837  SmallVectorImpl<ast::Expr *> &operands,
2839  SmallVectorImpl<ast::Expr *> &results) {
2840  std::optional<StringRef> opNameRef = name->getName();
2841  const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2842 
2843  // Verify the inputs operands.
2844  if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2845  return failure();
2846 
2847  // Verify the attribute list.
2848  for (ast::NamedAttributeDecl *attr : attributes) {
2849  // Check for an attribute type, or a type awaiting resolution.
2850  ast::Type attrType = attr->getValue()->getType();
2851  if (!isa<ast::AttributeType>(attrType)) {
2852  return emitError(
2853  attr->getValue()->getLoc(),
2854  llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2855  }
2856  }
2857 
2858  assert(
2859  (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2860  "unexpected inferrence when results were explicitly specified");
2861 
2862  // If we aren't relying on type inferrence, or explicit results were provided,
2863  // validate them.
2864  if (resultTypeContext == OpResultTypeContext::Explicit) {
2865  if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2866  return failure();
2867 
2868  // Validate the use of interface based type inferrence for this operation.
2869  } else if (resultTypeContext == OpResultTypeContext::Interface) {
2870  assert(opNameRef &&
2871  "expected valid operation name when inferring operation results");
2872  checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2873  }
2874 
2875  return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
2876  attributes);
2877 }
2878 
2879 LogicalResult
2880 Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2881  const ods::Operation *odsOp,
2882  SmallVectorImpl<ast::Expr *> &operands) {
2883  return validateOperationOperandsOrResults(
2884  "operand", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2885  operands,
2887  valueTy, valueRangeTy);
2888 }
2889 
2890 LogicalResult
2891 Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2892  const ods::Operation *odsOp,
2893  SmallVectorImpl<ast::Expr *> &results) {
2894  return validateOperationOperandsOrResults(
2895  "result", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2896  results,
2897  odsOp ? odsOp->getResults() : ArrayRef<pdll::ods::OperandOrResult>(),
2898  typeTy, typeRangeTy);
2899 }
2900 
2901 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2902  const ods::Operation *odsOp) {
2903  // If the operation might not have inferrence support, emit a warning to the
2904  // user. We don't emit an error because the interface might be added to the
2905  // operation at runtime. It's rare, but it could still happen. We emit a
2906  // warning here instead.
2907 
2908  // Handle inferrence warnings for unknown operations.
2909  if (!odsOp) {
2910  ctx.getDiagEngine().emitWarning(
2911  loc, llvm::formatv(
2912  "operation result types are marked to be inferred, but "
2913  "`{0}` is unknown. Ensure that `{0}` supports zero "
2914  "results or implements `InferTypeOpInterface`. Include "
2915  "the ODS definition of this operation to remove this warning.",
2916  opName));
2917  return;
2918  }
2919 
2920  // Handle inferrence warnings for known operations that expected at least one
2921  // result, but don't have inference support. An elided results list can mean
2922  // "zero-results", and we don't want to warn when that is the expected
2923  // behavior.
2924  bool requiresInferrence =
2925  llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
2926  return !result.isVariableLength();
2927  });
2928  if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2930  loc,
2931  llvm::formatv("operation result types are marked to be inferred, but "
2932  "`{0}` does not provide an implementation of "
2933  "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2934  "`InferTypeOpInterface` at runtime, or add support to "
2935  "the ODS definition to remove this warning.",
2936  opName));
2937  diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
2938  odsOp->getLoc());
2939  return;
2940  }
2941 }
2942 
2943 LogicalResult Parser::validateOperationOperandsOrResults(
2944  StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2945  std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
2946  ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2947  ast::RangeType rangeTy) {
2948  // All operation types accept a single range parameter.
2949  if (values.size() == 1) {
2950  if (failed(convertExpressionTo(values[0], rangeTy)))
2951  return failure();
2952  return success();
2953  }
2954 
2955  /// If the operation has ODS information, we can more accurately verify the
2956  /// values.
2957  if (odsOpLoc) {
2958  auto emitSizeMismatchError = [&] {
2959  return emitErrorAndNote(
2960  loc,
2961  llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2962  "{2}, but got {3}",
2963  groupName, *name, odsValues.size(), values.size()),
2964  *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
2965  };
2966 
2967  // Handle the case where no values were provided.
2968  if (values.empty()) {
2969  // If we don't expect any on the ODS side, we are done.
2970  if (odsValues.empty())
2971  return success();
2972 
2973  // If we do, check if we actually need to provide values (i.e. if any of
2974  // the values are actually required).
2975  unsigned numVariadic = 0;
2976  for (const auto &odsValue : odsValues) {
2977  if (!odsValue.isVariableLength())
2978  return emitSizeMismatchError();
2979  ++numVariadic;
2980  }
2981 
2982  // If we are in a non-rewrite context, we don't need to do anything more.
2983  // Zero-values is a valid constraint on the operation.
2984  if (parserContext != ParserContext::Rewrite)
2985  return success();
2986 
2987  // Otherwise, when in a rewrite we may need to provide values to match the
2988  // ODS signature of the operation to create.
2989 
2990  // If we only have one variadic value, just use an empty list.
2991  if (numVariadic == 1)
2992  return success();
2993 
2994  // Otherwise, create dummy values for each of the entries so that we
2995  // adhere to the ODS signature.
2996  for (unsigned i = 0, e = odsValues.size(); i < e; ++i) {
2997  values.push_back(
2998  ast::RangeExpr::create(ctx, loc, /*elements=*/{}, rangeTy));
2999  }
3000  return success();
3001  }
3002 
3003  // Verify that the number of values provided matches the number of value
3004  // groups ODS expects.
3005  if (odsValues.size() != values.size())
3006  return emitSizeMismatchError();
3007 
3008  auto diagFn = [&](ast::Diagnostic &diag) {
3009  diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
3010  *odsOpLoc);
3011  };
3012  for (unsigned i = 0, e = values.size(); i < e; ++i) {
3013  ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
3014  if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
3015  return failure();
3016  }
3017  return success();
3018  }
3019 
3020  // Otherwise, accept the value groups as they have been defined and just
3021  // ensure they are one of the expected types.
3022  for (ast::Expr *&valueExpr : values) {
3023  ast::Type valueExprType = valueExpr->getType();
3024 
3025  // Check if this is one of the expected types.
3026  if (valueExprType == rangeTy || valueExprType == singleTy)
3027  continue;
3028 
3029  // If the operand is an Operation, allow converting to a Value or
3030  // ValueRange. This situations arises quite often with nested operation
3031  // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
3032  if (singleTy == valueTy) {
3033  if (isa<ast::OperationType>(valueExprType)) {
3034  valueExpr = convertOpToValue(valueExpr);
3035  continue;
3036  }
3037  }
3038 
3039  // Otherwise, try to convert the expression to a range.
3040  if (succeeded(convertExpressionTo(valueExpr, rangeTy)))
3041  continue;
3042 
3043  return emitError(
3044  valueExpr->getLoc(),
3045  llvm::formatv(
3046  "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3047  singleTy, rangeTy, valueExprType));
3048  }
3049  return success();
3050 }
3051 
3052 FailureOr<ast::TupleExpr *>
3053 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
3054  ArrayRef<StringRef> elementNames) {
3055  for (const ast::Expr *element : elements) {
3056  ast::Type eleTy = element->getType();
3057  if (isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>(eleTy)) {
3058  return emitError(
3059  element->getLoc(),
3060  llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
3061  }
3062  }
3063  return ast::TupleExpr::create(ctx, loc, elements, elementNames);
3064 }
3065 
3066 //===----------------------------------------------------------------------===//
3067 // Stmts
3068 //===----------------------------------------------------------------------===//
3069 
3070 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
3071  ast::Expr *rootOp) {
3072  // Check that root is an Operation.
3073  ast::Type rootType = rootOp->getType();
3074  if (!isa<ast::OperationType>(rootType))
3075  return emitError(rootOp->getLoc(), "expected `Op` expression");
3076 
3077  return ast::EraseStmt::create(ctx, loc, rootOp);
3078 }
3079 
3080 FailureOr<ast::ReplaceStmt *>
3081 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
3082  MutableArrayRef<ast::Expr *> replValues) {
3083  // Check that root is an Operation.
3084  ast::Type rootType = rootOp->getType();
3085  if (!isa<ast::OperationType>(rootType)) {
3086  return emitError(
3087  rootOp->getLoc(),
3088  llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3089  }
3090 
3091  // If there are multiple replacement values, we implicitly convert any Op
3092  // expressions to the value form.
3093  bool shouldConvertOpToValues = replValues.size() > 1;
3094  for (ast::Expr *&replExpr : replValues) {
3095  ast::Type replType = replExpr->getType();
3096 
3097  // Check that replExpr is an Operation, Value, or ValueRange.
3098  if (isa<ast::OperationType>(replType)) {
3099  if (shouldConvertOpToValues)
3100  replExpr = convertOpToValue(replExpr);
3101  continue;
3102  }
3103 
3104  if (replType != valueTy && replType != valueRangeTy) {
3105  return emitError(replExpr->getLoc(),
3106  llvm::formatv("expected `Op`, `Value` or `ValueRange` "
3107  "expression, but got `{0}`",
3108  replType));
3109  }
3110  }
3111 
3112  return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
3113 }
3114 
3115 FailureOr<ast::RewriteStmt *>
3116 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
3117  ast::CompoundStmt *rewriteBody) {
3118  // Check that root is an Operation.
3119  ast::Type rootType = rootOp->getType();
3120  if (!isa<ast::OperationType>(rootType)) {
3121  return emitError(
3122  rootOp->getLoc(),
3123  llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3124  }
3125 
3126  return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
3127 }
3128 
3129 //===----------------------------------------------------------------------===//
3130 // Code Completion
3131 //===----------------------------------------------------------------------===//
3132 
3133 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3134  ast::Type parentType = parentExpr->getType();
3135  if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType))
3136  codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3137  else if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(parentType))
3138  codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3139  return failure();
3140 }
3141 
3142 LogicalResult
3143 Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3144  if (opName)
3145  codeCompleteContext->codeCompleteOperationAttributeName(*opName);
3146  return failure();
3147 }
3148 
3149 LogicalResult
3150 Parser::codeCompleteConstraintName(ast::Type inferredType,
3151  bool allowInlineTypeConstraints) {
3152  codeCompleteContext->codeCompleteConstraintName(
3153  inferredType, allowInlineTypeConstraints, curDeclScope);
3154  return failure();
3155 }
3156 
3157 LogicalResult Parser::codeCompleteDialectName() {
3158  codeCompleteContext->codeCompleteDialectName();
3159  return failure();
3160 }
3161 
3162 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3163  codeCompleteContext->codeCompleteOperationName(dialectName);
3164  return failure();
3165 }
3166 
3167 LogicalResult Parser::codeCompletePatternMetadata() {
3168  codeCompleteContext->codeCompletePatternMetadata();
3169  return failure();
3170 }
3171 
3172 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3173  codeCompleteContext->codeCompleteIncludeFilename(curPath);
3174  return failure();
3175 }
3176 
3177 void Parser::codeCompleteCallSignature(ast::Node *parent,
3178  unsigned currentNumArgs) {
3179  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent);
3180  if (!callableDecl)
3181  return;
3182 
3183  codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
3184 }
3185 
3186 void Parser::codeCompleteOperationOperandsSignature(
3187  std::optional<StringRef> opName, unsigned currentNumOperands) {
3188  codeCompleteContext->codeCompleteOperationOperandsSignature(
3189  opName, currentNumOperands);
3190 }
3191 
3192 void Parser::codeCompleteOperationResultsSignature(
3193  std::optional<StringRef> opName, unsigned currentNumResults) {
3194  codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3195  currentNumResults);
3196 }
3197 
3198 //===----------------------------------------------------------------------===//
3199 // Parser
3200 //===----------------------------------------------------------------------===//
3201 
3202 FailureOr<ast::Module *>
3203 mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3204  bool enableDocumentation,
3205  CodeCompleteContext *codeCompleteContext) {
3206  Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3207  return parser.parseModule();
3208 }
union mlir::linalg::@1244::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
This class breaks up the current file into a token stream.
Definition: Lexer.h:23
This represents a token in the MLIR syntax.
Definition: Token.h:20
bool isKeyword() const
Return true if this is one of the keyword token kinds (e.g. kw_if).
Definition: Token.cpp:192
SMLoc getLoc() const
Definition: Token.cpp:24
bool isNot(Kind k) const
Definition: Token.h:50
StringRef getSpelling() const
Definition: Token.h:34
This class provides an abstract interface into the parser for hooking in code completion events.
Definition: CodeComplete.h:30
virtual void codeCompleteConstraintName(ast::Type currentType, bool allowInlineTypeConstraints, const ast::DeclScope *scope)
Signal code completion for a constraint name with an optional decl scope.
virtual void codeCompleteOperationAttributeName(StringRef opName)
Signal code completion for an attribute name of the given operation.
Definition: CodeComplete.h:48
virtual void codeCompleteOperationOperandsSignature(std::optional< StringRef > opName, unsigned currentNumOperands)
Signal code completion for the signature of an operation's operands.
Definition: CodeComplete.h:80
virtual void codeCompleteOperationName(StringRef dialectName)
Signal code completion for an operation name in the given dialect.
Definition: CodeComplete.h:62
virtual void codeCompleteOperationResultsSignature(std::optional< StringRef > opName, unsigned currentNumResults)
Signal code completion for the signature of an operation's results.
Definition: CodeComplete.h:85
virtual void codeCompleteDialectName()
Signal code completion for a dialect name.
Definition: CodeComplete.h:59
virtual void codeCompleteOperationMemberAccess(ast::OperationType opType)
Signal code completion for a member access into the given operation type.
virtual void codeCompleteTupleMemberAccess(ast::TupleType tupleType)
Signal code completion for a member access into the given tuple type.
virtual void codeCompletePatternMetadata()
Signal code completion for Pattern metadata.
Definition: CodeComplete.h:65
virtual void codeCompleteCallSignature(const ast::CallableDecl *callable, unsigned currentNumArgs)
Signal code completion for the signature of a callable.
Definition: CodeComplete.h:75
virtual void codeCompleteIncludeFilename(StringRef curPath)
Signal code completion for an include filename.
Definition: CodeComplete.h:68
@ code_complete_string
Token signifying a code completion location within a string.
Definition: Lexer.h:41
@ directive
Tokens.
Definition: Lexer.h:93
@ eof
Markers.
Definition: Lexer.h:36
@ code_complete
Token signifying a code completion location.
Definition: Lexer.h:39
@ arrow
Punctuation.
Definition: Lexer.h:74
@ less
Paired punctuation.
Definition: Lexer.h:82
@ kw_Attr
General keywords.
Definition: Lexer.h:54
static StringRef getMemberName()
Return the member name used for the "all-results" access.
Definition: Nodes.h:487
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
Definition: Nodes.h:489
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
Definition: Nodes.h:750
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:385
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:259
static AttributeType get(Context &context)
Return an instance of the Attribute type.
Definition: Types.cpp:57
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr * > arguments, Type resultType, bool isNegated=false)
Definition: Nodes.cpp:269
This decl represents a shared interface for all callable decls.
Definition: Nodes.h:1194
Type getResultType() const
Return the result type of this decl.
Definition: Nodes.h:1212
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
Definition: Nodes.h:1205
StringRef getCallableType() const
Return the callable type of this decl.
Definition: Nodes.h:1197
This statement represents a compound statement, which contains a collection of other statements.
Definition: Nodes.h:179
ArrayRef< Stmt * >::iterator begin() const
Definition: Nodes.h:191
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
Definition: Nodes.h:185
ArrayRef< Stmt * >::iterator end() const
Definition: Nodes.h:192
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt * > children)
Definition: Nodes.cpp:192
This class represents the base of all AST Constraint decls.
Definition: Nodes.h:704
static ConstraintType get(Context &context)
Return an instance of the Constraint type.
Definition: Types.cpp:65
This class represents the main context of the PDLL AST.
Definition: Context.h:25
DiagnosticEngine & getDiagEngine()
Return the diagnostic engine of this context.
Definition: Context.h:42
ods::Context & getODSContext()
Return the ODS context used by the AST.
Definition: Context.h:38
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
Definition: Nodes.cpp:285
This class represents a scope for named AST decls.
Definition: Nodes.h:64
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
Definition: Nodes.h:70
void add(Decl *decl)
Add a new decl to the scope.
Definition: Nodes.cpp:175
This class represents the base Decl node.
Definition: Nodes.h:669
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition: Nodes.h:672
InFlightDiagnostic emitWarning(SMRange loc, const Twine &msg)
Definition: Diagnostic.h:149
InFlightDiagnostic emitError(SMRange loc, const Twine &msg)
Emit an error to the diagnostic engine.
Definition: Diagnostic.h:145
This class provides a simple implementation of a PDLL diagnostic.
Definition: Diagnostic.h:29
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
Definition: Nodes.cpp:218
This class represents a base AST Expression node.
Definition: Nodes.h:348
Type getType() const
Return the type of this expression.
Definition: Nodes.h:351
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostic.h:83
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
Definition: Nodes.cpp:206
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
Definition: Nodes.cpp:295
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl * > children)
Definition: Nodes.cpp:566
This Decl represents a NamedAttribute, and contains a string name and attribute value.
Definition: Nodes.h:998
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
Definition: Nodes.cpp:492
This class represents a base AST node.
Definition: Nodes.h:108
SMRange getLoc() const
Return the location of this node.
Definition: Nodes.h:131
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
Definition: Nodes.cpp:395
This Decl represents an OperationName.
Definition: Nodes.h:1022
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
Definition: Nodes.h:1028
static OpNameDecl * create(Context &ctx, const Name &name)
Definition: Nodes.cpp:502
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr * > operands, ArrayRef< Expr * > resultTypes, ArrayRef< NamedAttributeDecl * > attributes)
Definition: Nodes.cpp:307
This class represents a PDLL type that corresponds to an mlir::Operation.
Definition: Types.h:134
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
Definition: Types.cpp:87
static OperationType get(Context &context, std::optional< StringRef > name=std::nullopt, const ods::Operation *odsOp=nullptr)
Return an instance of the Operation type with an optional operation name.
Definition: Types.cpp:73
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
Definition: Nodes.cpp:513
static RangeExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, RangeType type)
Definition: Nodes.cpp:335
This class represents a PDLL type that corresponds to a range of elements with a given element type.
Definition: Types.h:159
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr * > replExprs)
Definition: Nodes.cpp:226
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
Definition: Nodes.cpp:250
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
Definition: Nodes.cpp:240
static RewriteType get(Context &context)
Return an instance of the Rewrite type.
Definition: Types.cpp:136
This class represents a base AST Statement node.
Definition: Nodes.h:164
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, ArrayRef< StringRef > elementNames)
Definition: Nodes.cpp:349
This class represents a PDLL tuple type, i.e.
Definition: Types.h:222
size_t size() const
Return the number of elements within this tuple.
Definition: Types.h:238
ArrayRef< Type > getElementTypes() const
Return the element types of this tuple.
Definition: Types.cpp:155
static TupleType get(Context &context, ArrayRef< Type > elementTypes, ArrayRef< StringRef > elementNames)
Return an instance of the Tuple type.
Definition: Types.cpp:144
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:412
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:368
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:421
static TypeRangeType get(Context &context)
Return an instance of the TypeRange type.
Definition: Types.cpp:113
static TypeType get(Context &context)
Return an instance of the Type type.
Definition: Types.cpp:167
Type refineWith(Type other) const
Try to refine this type with the one provided.
Definition: Types.cpp:33
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
Definition: Nodes.h:892
The class represents a Value constraint, and constrains a variable to be a Value.
Definition: Nodes.h:830
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
Definition: Nodes.cpp:431
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
Definition: Nodes.h:853
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:442
static ValueRangeType get(Context &context)
Return an instance of the ValueRange type.
Definition: Types.cpp:127
static ValueType get(Context &context)
Return an instance of the Value type.
Definition: Types.cpp:175
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
Definition: Nodes.cpp:549
This class contains all of the registered ODS operation classes.
Definition: Context.h:32
std::pair< Operation *, bool > insertOperation(StringRef name, StringRef summary, StringRef desc, StringRef nativeClassName, bool supportsResultTypeInferrence, SMLoc loc)
Insert a new operation with the context.
Definition: Context.cpp:63
const TypeConstraint & insertTypeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new type constraint with the context.
Definition: Context.cpp:41
const AttributeConstraint & insertAttributeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new attribute constraint with the context.
Definition: Context.cpp:27
const Operation * lookupOperation(StringRef name) const
Lookup an operation registered with the given name, or null if no operation with that name is registe...
Definition: Context.cpp:72
This class provides an ODS representation of a specific operation operand or result.
Definition: Operation.h:74
This class provides an ODS representation of a specific operation.
Definition: Operation.h:125
SMRange getLoc() const
Return the source location of this operation.
Definition: Operation.h:128
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
Definition: Operation.h:165
bool hasResultTypeInferrence() const
Return if the operation is known to support result type inferrence.
Definition: Operation.h:171
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
Definition: Operation.h:168
This class represents a generic ODS Type constraint.
Definition: Constraint.h:84
raw_ostream subclass that simplifies indention a sequence of code.
raw_indented_ostream & printReindented(StringRef str, StringRef extraPrefix="")
Prints a string re-indented to the current indent.
StringRef getSummary() const
Definition: Constraint.cpp:57
std::string getUniqueDefName() const
Returns a unique name for the TablGen def of this constraint.
Definition: Constraint.cpp:74
StringRef getDescription() const
Definition: Constraint.cpp:64
std::string getConditionTemplate() const
Definition: Constraint.cpp:53
Format context containing substitutions for special placeholders.
Definition: Format.h:40
FmtContext & withSelf(Twine subst)
Definition: Format.cpp:41
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
Definition: Operator.h:77
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
Definition: Parser.cpp:3203
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals) -> FmtObject< decltype(std::make_tuple(llvm::support::detail::build_format_adapter(std::forward< Ts >(vals))...))>
Formats text by substituting placeholders in format string with replacement parameters.
Definition: Format.h:262
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This class represents a reference to a constraint, and contains a constraint and the location of the ...
Definition: Nodes.h:716
const ConstraintDecl * constraint
Definition: Nodes.h:722
This class provides a convenient API for interacting with source names.
Definition: Nodes.h:37
StringRef getName() const
Return the raw string name.
Definition: Nodes.h:41
SMRange getLoc() const
Get the location of this name.
Definition: Nodes.h:44
static const Name & create(Context &ctx, StringRef name, SMRange location)
Definition: Nodes.cpp:33