MLIR  20.0.0git
Parser.cpp
Go to the documentation of this file.
1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "Lexer.h"
12 #include "mlir/TableGen/Argument.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/Operator.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/ManagedStatic.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/ScopedPrinter.h"
31 #include "llvm/TableGen/Error.h"
32 #include "llvm/TableGen/Parser.h"
33 #include <optional>
34 #include <string>
35 
36 using namespace mlir;
37 using namespace mlir::pdll;
38 
39 //===----------------------------------------------------------------------===//
40 // Parser
41 //===----------------------------------------------------------------------===//
42 
43 namespace {
/// Parser for PDLL source. It pulls tokens one at a time from `lexer` (the
/// token under inspection is always `curToken`) and builds `ast` nodes;
/// `parseModule` is the only public entry point.
44 class Parser {
45 public:
 /// Construct a parser over `sourceMgr`. The first token is lexed as part of
 /// construction so that `curToken` is valid before any parse method runs.
46  Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
47  bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
48  : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
49  curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
50  typeTy(ast::TypeType::get(ctx)), valueTy(ast::ValueType::get(ctx)),
51  typeRangeTy(ast::TypeRangeType::get(ctx)),
52  valueRangeTy(ast::ValueRangeType::get(ctx)),
53  attrTy(ast::AttributeType::get(ctx)),
54  codeCompleteContext(codeCompleteContext) {}
55 
56  /// Try to parse a new module. Returns failure if the module body could not
 /// be parsed.
57  FailureOr<ast::Module *> parseModule();
58 
59 private:
60  /// The current context of the parser. It allows for the parser to know a bit
61  /// about the construct it is nested within during parsing. This is used
62  /// specifically to provide additional verification during parsing, e.g. to
63  /// prevent using rewrites within a match context, matcher constraints within
64  /// a rewrite section, etc.
65  enum class ParserContext {
66  /// The parser is in the global context.
67  Global,
68  /// The parser is currently within a Constraint, which disallows all types
69  /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
70  Constraint,
71  /// The parser is currently within the matcher portion of a Pattern, which
72  /// allows a terminal operation rewrite statement but no other rewrite
73  /// transformations.
74  PatternMatch,
75  /// The parser is currently within a Rewrite, which disallows calls to
76  /// constraints, requires operation expressions to have names, etc.
77  Rewrite,
78  };
79 
80  /// The current specification context of an operation's result types. This
81  /// indicates how the result types of an operation may be inferred.
82  enum class OpResultTypeContext {
83  /// The result types of the operation are not known to be inferred.
84  Explicit,
85  /// The result types of the operation are inferred from the root input of a
86  /// `replace` statement.
87  Replacement,
88  /// The result types of the operation are inferred by using the
89  /// `InferTypeOpInterface` interface provided by the operation.
90  Interface,
91  };
92 
93  //===--------------------------------------------------------------------===//
94  // Parsing
95  //===--------------------------------------------------------------------===//
96 
97  /// Push a new decl scope, parented to the current scope, and make it the
 /// current scope. The scope is allocated from `scopeAllocator`.
98  ast::DeclScope *pushDeclScope() {
99  ast::DeclScope *newScope =
100  new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
101  return (curDeclScope = newScope);
102  }
 /// Make an already existing `scope` the current decl scope.
103  void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
104 
105  /// Pop the current decl scope, making its parent the current scope.
106  void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
107 
108  /// Parse the body of an AST module.
109  LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
110 
111  /// Try to convert the given expression to `type`. Returns failure and emits
112  /// an error if a conversion is not viable. On failure, `noteAttachFn` is
113  /// invoked to attach notes to the emitted error diagnostic. On success,
114  /// `expr` is updated to the expression used to convert to `type`.
115  LogicalResult convertExpressionTo(
116  ast::Expr *&expr, ast::Type type,
117  function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
 /// Helper of `convertExpressionTo` for expressions of operation type.
 /// `emitErrorFn` emits the conversion error.
118  LogicalResult
119  convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType,
120  ast::Type type,
121  function_ref<ast::InFlightDiagnostic()> emitErrorFn);
 /// Helper of `convertExpressionTo` for expressions of tuple type.
 /// `emitErrorFn` emits the conversion error; `noteAttachFn` is forwarded to
 /// the per-element conversions.
122  LogicalResult convertTupleExpressionTo(
123  ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
124  function_ref<ast::InFlightDiagnostic()> emitErrorFn,
125  function_ref<void(ast::Diagnostic &diag)> noteAttachFn);
126 
127  /// Given an operation expression, convert it to a Value or ValueRange
128  /// typed expression.
129  ast::Expr *convertOpToValue(const ast::Expr *opExpr);
130 
131  /// Lookup ODS information for the given operation, returns nullptr if no
132  /// information is found (including when `opName` is not set).
133  const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
134  return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
135  }
136 
137  /// Process the given documentation string, or return an empty string if
138  /// documentation isn't enabled.
139  StringRef processDoc(StringRef doc) {
140  return enableDocumentation ? doc : StringRef();
141  }
142 
143  /// Process the given documentation string and format it, or return an empty
144  /// string if documentation isn't enabled.
145  std::string processAndFormatDoc(const Twine &doc) {
146  if (!enableDocumentation)
147  return "";
148  std::string docStr;
149  {
150  llvm::raw_string_ostream docOS(docStr);
 // NOTE(review): the head of the statement that writes `doc` into `docOS`
 // appears to be missing from this listing (only its trailing argument
 // survives on the next line) — restore it before compiling.
152  StringRef(docStr).rtrim(" \t"));
153  }
154  return docStr;
155  }
156 
157  //===--------------------------------------------------------------------===//
158  // Directives
159 
 /// Parse the directive at the current token, appending any resulting decls
 /// to `decls`. Only `#include` is recognized.
160  LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
 /// Parse an `#include` directive for either a `.pdll` or a `.td` file.
161  LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
 // NOTE(review): the parameter list below looks truncated in this listing;
 // the call site passes a third `decls` argument — confirm and restore.
162  LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
164 
165  /// Process the records of a parsed tablegen include file.
 // NOTE(review): the parameter list below looks truncated in this listing —
 // confirm the remaining parameter(s) against the definition.
166  void processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
168 
169  /// Create a user defined native constraint for a constraint imported from
170  /// ODS.
171  template <typename ConstraintT>
172  ast::Decl *
173  createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
174  SMRange loc, ast::Type type,
175  StringRef nativeType, StringRef docString);
176  template <typename ConstraintT>
177  ast::Decl *
178  createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
179  SMRange loc, ast::Type type,
180  StringRef nativeType);
181 
182  //===--------------------------------------------------------------------===//
183  // Decls
184 
185  /// This structure contains the set of pattern metadata that may be parsed.
186  struct ParsedPatternMetadata {
187  std::optional<uint16_t> benefit;
188  bool hasBoundedRecursion = false;
189  };
190 
191  FailureOr<ast::Decl *> parseTopLevelDecl();
192  FailureOr<ast::NamedAttributeDecl *>
193  parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
194 
195  /// Parse an argument variable as part of the signature of a
196  /// UserConstraintDecl or UserRewriteDecl.
197  FailureOr<ast::VariableDecl *> parseArgumentDecl();
198 
199  /// Parse a result variable as part of the signature of a UserConstraintDecl
200  /// or UserRewriteDecl.
201  FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
202 
203  /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
204  /// defined in a non-global context.
205  FailureOr<ast::UserConstraintDecl *>
206  parseUserConstraintDecl(bool isInline = false);
207 
208  /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
209  /// non-global context, such as within a Pattern/Constraint/etc.
210  FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
211 
212  /// Parse a PDLL (i.e. non-native) UserConstraintDecl whose body is defined
213  /// using PDLL constructs.
214  FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
215  const ast::Name &name, bool isInline,
216  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
217  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
218 
219  /// Parse a UserRewriteDecl. `isInline` signals if the rewrite is being
220  /// defined in a non-global context.
221  FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
222 
223  /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
224  /// non-global context, such as within a Pattern/Rewrite/etc.
225  FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
226 
227  /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
228  /// PDLL constructs.
229  FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
230  const ast::Name &name, bool isInline,
231  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
232  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
233 
234  /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
235  /// effectively the same syntax, and only differ on slight semantics (given
236  /// the different parsing contexts).
237  template <typename T, typename ParseUserPDLLDeclFnT>
238  FailureOr<T *> parseUserConstraintOrRewriteDecl(
239  ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
240  StringRef anonymousNamePrefix, bool isInline);
241 
242  /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
243  /// These decls have effectively the same syntax.
 // NOTE(review): a parameter line between `isInline` and `results` appears
 // to be missing from this listing — confirm against the definition.
244  template <typename T>
245  FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
246  const ast::Name &name, bool isInline,
248  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
249 
250  /// Parse the functional signature (i.e. the arguments and results) of a
251  /// UserConstraintDecl or UserRewriteDecl.
 // NOTE(review): the leading parameter lines appear to be missing from this
 // listing — confirm against the definition.
252  LogicalResult parseUserConstraintOrRewriteSignature(
255  ast::DeclScope *&argumentScope, ast::Type &resultType);
256 
257  /// Validate the return (which if present is specified by bodyIt) of a
258  /// UserConstraintDecl or UserRewriteDecl.
 // NOTE(review): parameter lines between `body` and `results` (presumably
 // including the `bodyIt` mentioned above) are missing from this listing —
 // confirm against the definition.
259  LogicalResult validateUserConstraintOrRewriteReturn(
260  StringRef declType, ast::CompoundStmt *body,
263  ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
264 
265  FailureOr<ast::CompoundStmt *>
266  parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
267  bool expectTerminalSemicolon = true);
268  FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
269  FailureOr<ast::Decl *> parsePatternDecl();
270  LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
271 
272  /// Check to see if a decl has already been defined with the given name, if
273  /// one has emit an error and return failure. Returns success otherwise.
274  LogicalResult checkDefineNamedDecl(const ast::Name &name);
275 
276  /// Try to define a variable decl with the given components, returns the
277  /// variable on success.
278  FailureOr<ast::VariableDecl *>
279  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
280  ast::Expr *initExpr,
281  ArrayRef<ast::ConstraintRef> constraints);
282  FailureOr<ast::VariableDecl *>
283  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
284  ArrayRef<ast::ConstraintRef> constraints);
285 
286  /// Parse the constraint reference list for a variable decl.
 // NOTE(review): the parameter list below is missing from this listing —
 // confirm against the definition.
287  LogicalResult parseVariableDeclConstraintList(
289 
290  /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
291  FailureOr<ast::Expr *> parseTypeConstraintExpr();
292 
293  /// Try to parse a single reference to a constraint. `typeConstraint` is the
294  /// location of a previously parsed type constraint for the entity that will
295  /// be constrained by the parsed constraint. `existingConstraints` are any
296  /// existing constraints that have already been parsed for the same entity
297  /// that will be constrained by this constraint. `allowInlineTypeConstraints`
298  /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
299  FailureOr<ast::ConstraintRef>
300  parseConstraint(std::optional<SMRange> &typeConstraint,
301  ArrayRef<ast::ConstraintRef> existingConstraints,
302  bool allowInlineTypeConstraints);
303 
304  /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
305  /// argument or result variable. The constraints for these variables do not
306  /// allow inline type constraints, and only permit a single constraint.
307  FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
308 
309  //===--------------------------------------------------------------------===//
310  // Exprs
311 
312  FailureOr<ast::Expr *> parseExpr();
313 
314  /// Identifier expressions.
315  FailureOr<ast::Expr *> parseAttributeExpr();
316  FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr,
317  bool isNegated = false);
318  FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
319  FailureOr<ast::Expr *> parseIdentifierExpr();
320  FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
321  FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
322  FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
323  FailureOr<ast::Expr *> parseNegatedExpr();
324  FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
325  FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
326  FailureOr<ast::Expr *>
327  parseOperationExpr(OpResultTypeContext inputResultTypeContext =
328  OpResultTypeContext::Explicit);
329  FailureOr<ast::Expr *> parseTupleExpr();
330  FailureOr<ast::Expr *> parseTypeExpr();
331  FailureOr<ast::Expr *> parseUnderscoreExpr();
332 
333  //===--------------------------------------------------------------------===//
334  // Stmts
335 
336  FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
337  FailureOr<ast::CompoundStmt *> parseCompoundStmt();
338  FailureOr<ast::EraseStmt *> parseEraseStmt();
339  FailureOr<ast::LetStmt *> parseLetStmt();
340  FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
341  FailureOr<ast::ReturnStmt *> parseReturnStmt();
342  FailureOr<ast::RewriteStmt *> parseRewriteStmt();
343 
344  //===--------------------------------------------------------------------===//
345  // Creation+Analysis
346  //===--------------------------------------------------------------------===//
347 
348  //===--------------------------------------------------------------------===//
349  // Decls
350 
351  /// Try to extract a callable from the given AST node. Returns nullptr on
352  /// failure.
353  ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
354 
355  /// Try to create a pattern decl with the given components, returning the
356  /// Pattern on success.
357  FailureOr<ast::PatternDecl *>
358  createPatternDecl(SMRange loc, const ast::Name *name,
359  const ParsedPatternMetadata &metadata,
360  ast::CompoundStmt *body);
361 
362  /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
363  /// of results, defined as part of the signature.
364  ast::Type
365  createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
366 
367  /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
368  template <typename T>
369  FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
370  const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
371  ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
372  ast::CompoundStmt *body);
373 
374  /// Try to create a variable decl with the given components, returning the
375  /// Variable on success.
376  FailureOr<ast::VariableDecl *>
377  createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
378  ArrayRef<ast::ConstraintRef> constraints);
379 
380  /// Create a variable for an argument or result defined as part of the
381  /// signature of a UserConstraintDecl/UserRewriteDecl.
382  FailureOr<ast::VariableDecl *>
383  createArgOrResultVariableDecl(StringRef name, SMRange loc,
384  const ast::ConstraintRef &constraint);
385 
386  /// Validate the constraints used to constrain a variable decl.
387  /// `inferredType` is the type of the variable inferred by the constraints
388  /// within the list, and is updated to the most refined type as determined by
389  /// the constraints. Returns success if the constraint list is valid, failure
390  /// otherwise.
391  LogicalResult
392  validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
393  ast::Type &inferredType);
394  /// Validate a single reference to a constraint. `inferredType` contains the
395  /// currently inferred variable type and is refined with the type defined
396  /// by the constraint. Returns success if the constraint is valid, failure
397  /// otherwise.
398  LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
399  ast::Type &inferredType);
400  LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
401  LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
402 
403  //===--------------------------------------------------------------------===//
404  // Exprs
405 
 // NOTE(review): a parameter line between `parentExpr` and `isNegated`
 // appears to be missing from this listing — confirm against the definition.
406  FailureOr<ast::CallExpr *>
407  createCallExpr(SMRange loc, ast::Expr *parentExpr,
409  bool isNegated = false);
410  FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
411  FailureOr<ast::DeclRefExpr *>
412  createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
413  ArrayRef<ast::ConstraintRef> constraints);
414  FailureOr<ast::MemberAccessExpr *>
415  createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
416 
417  /// Validate the member access `name` into the given parent expression. On
418  /// success, this also returns the type of the member accessed.
419  FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
420  StringRef name, SMRange loc);
 // NOTE(review): the trailing parameter lines of `createOperationExpr`
 // appear to be missing from this listing — confirm against the definition.
421  FailureOr<ast::OperationExpr *>
422  createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
423  OpResultTypeContext resultTypeContext,
427  LogicalResult
428  validateOperationOperands(SMRange loc, std::optional<StringRef> name,
429  const ods::Operation *odsOp,
430  SmallVectorImpl<ast::Expr *> &operands);
 // NOTE(review): the final parameter line of `validateOperationResults`
 // appears to be missing from this listing — confirm against the definition.
431  LogicalResult validateOperationResults(SMRange loc,
432  std::optional<StringRef> name,
433  const ods::Operation *odsOp,
435  void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
436  const ods::Operation *odsOp);
437  LogicalResult validateOperationOperandsOrResults(
438  StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
439  std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
440  ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
441  ast::RangeType rangeTy);
442  FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
443  ArrayRef<ast::Expr *> elements,
444  ArrayRef<StringRef> elementNames);
445 
446  //===--------------------------------------------------------------------===//
447  // Stmts
448 
449  FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
450  FailureOr<ast::ReplaceStmt *>
451  createReplaceStmt(SMRange loc, ast::Expr *rootOp,
452  MutableArrayRef<ast::Expr *> replValues);
453  FailureOr<ast::RewriteStmt *>
454  createRewriteStmt(SMRange loc, ast::Expr *rootOp,
455  ast::CompoundStmt *rewriteBody);
456 
457  //===--------------------------------------------------------------------===//
458  // Code Completion
459  //===--------------------------------------------------------------------===//
460 
461  /// The set of various code completion methods. Every completion method
462  /// returns `failure` to stop the parsing process after providing completion
463  /// results.
464 
465  LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
466  LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
467  LogicalResult codeCompleteConstraintName(ast::Type inferredType,
468  bool allowInlineTypeConstraints);
469  LogicalResult codeCompleteDialectName();
470  LogicalResult codeCompleteOperationName(StringRef dialectName);
471  LogicalResult codeCompletePatternMetadata();
472  LogicalResult codeCompleteIncludeFilename(StringRef curPath);
473 
474  void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
475  void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
476  unsigned currentNumOperands);
477  void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
478  unsigned currentNumResults);
479 
480  //===--------------------------------------------------------------------===//
481  // Lexer Utilities
482  //===--------------------------------------------------------------------===//
483 
484  /// If the current token has the specified kind, consume it and return true.
485  /// If not, return false.
486  bool consumeIf(Token::Kind kind) {
487  if (curToken.isNot(kind))
488  return false;
489  consumeToken(kind);
490  return true;
491  }
492 
493  /// Advance the current lexer onto the next token.
494  void consumeToken() {
495  assert(curToken.isNot(Token::eof, Token::error) &&
496  "shouldn't advance past EOF or errors");
497  curToken = lexer.lexToken();
498  }
499 
500  /// Advance the current lexer onto the next token, asserting what the expected
501  /// current token is. This is preferred to the above method because it leads
502  /// to more self-documenting code with better checking.
503  void consumeToken(Token::Kind kind) {
504  assert(curToken.is(kind) && "consumed an unexpected token");
505  consumeToken();
506  }
507 
508  /// Reset the lexer to the start of the given token location and re-lex the
 /// token found there into `curToken`.
509  void resetToken(SMRange tokLoc) {
510  lexer.resetPointer(tokLoc.Start.getPointer());
511  curToken = lexer.lexToken();
512  }
513 
514  /// Consume the specified token if present and return success. On failure,
515  /// output a diagnostic (`msg`, at the current token) and return failure.
516  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
517  if (curToken.getKind() != kind)
518  return emitError(curToken.getLoc(), msg);
519  consumeToken();
520  return success();
521  }
 /// Emit an error with `msg` at `loc` through the lexer. Always returns
 /// failure so callers can write `return emitError(...)`.
522  LogicalResult emitError(SMRange loc, const Twine &msg) {
523  lexer.emitError(loc, msg);
524  return failure();
525  }
 /// Emit an error with `msg` at the location of the current token.
526  LogicalResult emitError(const Twine &msg) {
527  return emitError(curToken.getLoc(), msg);
528  }
 /// Emit an error with `msg` at `loc` together with `note` at `noteLoc`.
 /// Always returns failure.
529  LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
530  const Twine &note) {
531  lexer.emitErrorAndNote(loc, msg, noteLoc, note);
532  return failure();
533  }
534 
535  //===--------------------------------------------------------------------===//
536  // Fields
537  //===--------------------------------------------------------------------===//
538 
539  /// The owning AST context.
540  ast::Context &ctx;
541 
542  /// The lexer of this parser.
543  Lexer lexer;
544 
545  /// The current token within the lexer. Declared after `lexer` because the
 /// constructor initializes it from `lexer.lexToken()`.
546  Token curToken;
547 
548  /// A flag indicating if the parser should add documentation to AST nodes when
549  /// viable.
550  bool enableDocumentation;
551 
552  /// The most recently defined decl scope.
553  ast::DeclScope *curDeclScope = nullptr;
 /// Allocator backing the scopes created by `pushDeclScope()`.
554  llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
555 
556  /// The current context of the parser.
557  ParserContext parserContext = ParserContext::Global;
558 
559  /// Cached types to simplify verification and expression creation.
560  ast::Type typeTy, valueTy;
561  ast::RangeType typeRangeTy, valueRangeTy;
562  ast::Type attrTy;
563 
564  /// A counter used when naming anonymous constraints and rewrites.
565  unsigned anonymousDeclNameCounter = 0;
566 
567  /// The optional code completion context.
568  CodeCompleteContext *codeCompleteContext;
569 };
570 } // namespace
571 
572 FailureOr<ast::Module *> Parser::parseModule() {
573  SMLoc moduleLoc = curToken.getStartLoc();
574  pushDeclScope();
575 
576  // Parse the top-level decls of the module.
578  if (failed(parseModuleBody(decls)))
579  return popDeclScope(), failure();
580 
581  popDeclScope();
582  return ast::Module::create(ctx, moduleLoc, decls);
583 }
584 
585 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
586  while (curToken.isNot(Token::eof)) {
587  if (curToken.is(Token::directive)) {
588  if (failed(parseDirective(decls)))
589  return failure();
590  continue;
591  }
592 
593  FailureOr<ast::Decl *> decl = parseTopLevelDecl();
594  if (failed(decl))
595  return failure();
596  decls.push_back(*decl);
597  }
598  return success();
599 }
600 
601 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
602  return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
603  valueRangeTy);
604 }
605 
606 LogicalResult Parser::convertExpressionTo(
607  ast::Expr *&expr, ast::Type type,
608  function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
609  ast::Type exprType = expr->getType();
610  if (exprType == type)
611  return success();
612 
613  auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
615  expr->getLoc(), llvm::formatv("unable to convert expression of type "
616  "`{0}` to the expected type of "
617  "`{1}`",
618  exprType, type));
619  if (noteAttachFn)
620  noteAttachFn(*diag);
621  return diag;
622  };
623 
624  if (auto exprOpType = dyn_cast<ast::OperationType>(exprType))
625  return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);
626 
627  // FIXME: Decide how to allow/support converting a single result to multiple,
628  // and multiple to a single result. For now, we just allow Single->Range,
629  // but this isn't something really supported in the PDL dialect. We should
630  // figure out some way to support both.
631  if ((exprType == valueTy || exprType == valueRangeTy) &&
632  (type == valueTy || type == valueRangeTy))
633  return success();
634  if ((exprType == typeTy || exprType == typeRangeTy) &&
635  (type == typeTy || type == typeRangeTy))
636  return success();
637 
638  // Handle tuple types.
639  if (auto exprTupleType = dyn_cast<ast::TupleType>(exprType))
640  return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
641  noteAttachFn);
642 
643  return emitConvertError();
644 }
645 
646 LogicalResult Parser::convertOpExpressionTo(
647  ast::Expr *&expr, ast::OperationType exprType, ast::Type type,
648  function_ref<ast::InFlightDiagnostic()> emitErrorFn) {
649  // Two operation types are compatible if they have the same name, or if the
650  // expected type is more general.
651  if (auto opType = dyn_cast<ast::OperationType>(type)) {
652  if (opType.getName())
653  return emitErrorFn();
654  return success();
655  }
656 
657  // An operation can always convert to a ValueRange.
658  if (type == valueRangeTy) {
659  expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
660  valueRangeTy);
661  return success();
662  }
663 
664  // Allow conversion to a single value by constraining the result range.
665  if (type == valueTy) {
666  // If the operation is registered, we can verify if it can ever have a
667  // single result.
668  if (const ods::Operation *odsOp = exprType.getODSOperation()) {
669  if (odsOp->getResults().empty()) {
670  return emitErrorFn()->attachNote(
671  llvm::formatv("see the definition of `{0}`, which was defined "
672  "with zero results",
673  odsOp->getName()),
674  odsOp->getLoc());
675  }
676 
677  unsigned numSingleResults = llvm::count_if(
678  odsOp->getResults(), [](const ods::OperandOrResult &result) {
679  return result.getVariableLengthKind() ==
680  ods::VariableLengthKind::Single;
681  });
682  if (numSingleResults > 1) {
683  return emitErrorFn()->attachNote(
684  llvm::formatv("see the definition of `{0}`, which was defined "
685  "with at least {1} results",
686  odsOp->getName(), numSingleResults),
687  odsOp->getLoc());
688  }
689  }
690 
691  expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
692  valueTy);
693  return success();
694  }
695  return emitErrorFn();
696 }
697 
698 LogicalResult Parser::convertTupleExpressionTo(
699  ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
700  function_ref<ast::InFlightDiagnostic()> emitErrorFn,
701  function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
702  // Handle conversions between tuples.
703  if (auto tupleType = dyn_cast<ast::TupleType>(type)) {
704  if (tupleType.size() != exprType.size())
705  return emitErrorFn();
706 
707  // Build a new tuple expression using each of the elements of the current
708  // tuple.
709  SmallVector<ast::Expr *> newExprs;
710  for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
711  newExprs.push_back(ast::MemberAccessExpr::create(
712  ctx, expr->getLoc(), expr, llvm::to_string(i),
713  exprType.getElementTypes()[i]));
714 
715  auto diagFn = [&](ast::Diagnostic &diag) {
716  diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
717  i, exprType));
718  if (noteAttachFn)
719  noteAttachFn(diag);
720  };
721  if (failed(convertExpressionTo(newExprs.back(),
722  tupleType.getElementTypes()[i], diagFn)))
723  return failure();
724  }
725  expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
726  tupleType.getElementNames());
727  return success();
728  }
729 
730  // Handle conversion to a range.
731  auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes,
732  ast::RangeType resultTy) -> LogicalResult {
733  // TODO: We currently only allow range conversion within a rewrite context.
734  if (parserContext != ParserContext::Rewrite) {
735  return emitErrorFn()->attachNote("Tuple to Range conversion is currently "
736  "only allowed within a rewrite context");
737  }
738 
739  // All of the tuple elements must be allowed types.
740  for (ast::Type elementType : exprType.getElementTypes())
741  if (!llvm::is_contained(allowedElementTypes, elementType))
742  return emitErrorFn();
743 
744  // Build a new tuple expression using each of the elements of the current
745  // tuple.
746  SmallVector<ast::Expr *> newExprs;
747  for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
748  newExprs.push_back(ast::MemberAccessExpr::create(
749  ctx, expr->getLoc(), expr, llvm::to_string(i),
750  exprType.getElementTypes()[i]));
751  }
752  expr = ast::RangeExpr::create(ctx, expr->getLoc(), newExprs, resultTy);
753  return success();
754  };
755  if (type == valueRangeTy)
756  return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
757  if (type == typeRangeTy)
758  return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
759 
760  return emitErrorFn();
761 }
762 
763 //===----------------------------------------------------------------------===//
764 // Directives
765 
766 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
767  StringRef directive = curToken.getSpelling();
768  if (directive == "#include")
769  return parseInclude(decls);
770 
771  return emitError("unknown directive `" + directive + "`");
772 }
773 
774 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
775  SMRange loc = curToken.getLoc();
776  consumeToken(Token::directive);
777 
778  // Handle code completion of the include file path.
779  if (curToken.is(Token::code_complete_string))
780  return codeCompleteIncludeFilename(curToken.getStringValue());
781 
782  // Parse the file being included.
783  if (!curToken.isString())
784  return emitError(loc,
785  "expected string file name after `include` directive");
786  SMRange fileLoc = curToken.getLoc();
787  std::string filenameStr = curToken.getStringValue();
788  StringRef filename = filenameStr;
789  consumeToken();
790 
791  // Check the type of include. If ending with `.pdll`, this is another pdl file
792  // to be parsed along with the current module.
793  if (filename.ends_with(".pdll")) {
794  if (failed(lexer.pushInclude(filename, fileLoc)))
795  return emitError(fileLoc,
796  "unable to open include file `" + filename + "`");
797 
798  // If we added the include successfully, parse it into the current module.
799  // Make sure to update to the next token after we finish parsing the nested
800  // file.
801  curToken = lexer.lexToken();
802  LogicalResult result = parseModuleBody(decls);
803  curToken = lexer.lexToken();
804  return result;
805  }
806 
807  // Otherwise, this must be a `.td` include.
808  if (filename.ends_with(".td"))
809  return parseTdInclude(filename, fileLoc, decls);
810 
811  return emitError(fileLoc,
812  "expected include filename to end with `.pdll` or `.td`");
813 }
814 
815 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
817  llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
818 
819  // Use the source manager to open the file, but don't yet add it.
820  std::string includedFile;
821  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
822  parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
823  if (!includeBuffer)
824  return emitError(fileLoc, "unable to open include file `" + filename + "`");
825 
826  // Setup the source manager for parsing the tablegen file.
827  llvm::SourceMgr tdSrcMgr;
828  tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
829  tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
830 
831  // This class provides a context argument for the llvm::SourceMgr diagnostic
832  // handler.
833  struct DiagHandlerContext {
834  Parser &parser;
835  StringRef filename;
836  llvm::SMRange loc;
837  } handlerContext{*this, filename, fileLoc};
838 
839  // Set the diagnostic handler for the tablegen source manager.
840  tdSrcMgr.setDiagHandler(
841  [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
842  auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
843  (void)ctx->parser.emitError(
844  ctx->loc,
845  llvm::formatv("error while processing include file `{0}`: {1}",
846  ctx->filename, diag.getMessage()));
847  },
848  &handlerContext);
849 
850  // Parse the tablegen file.
851  llvm::RecordKeeper tdRecords;
852  if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
853  return failure();
854 
855  // Process the parsed records.
856  processTdIncludeRecords(tdRecords, decls);
857 
858  // After we are done processing, move all of the tablegen source buffers to
859  // the main parser source mgr. This allows for directly using source locations
860  // from the .td files without needing to remap them.
861  parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
862  return success();
863 }
864 
865 void Parser::processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
867  // Return the length kind of the given value.
868  auto getLengthKind = [](const auto &value) {
869  if (value.isOptional())
871  return value.isVariadic() ? ods::VariableLengthKind::Variadic
873  };
874 
875  // Insert a type constraint into the ODS context.
876  ods::Context &odsContext = ctx.getODSContext();
877  auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
878  -> const ods::TypeConstraint & {
879  return odsContext.insertTypeConstraint(
880  cst.constraint.getUniqueDefName(),
881  processDoc(cst.constraint.getSummary()), cst.constraint.getCppType());
882  };
883  auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
884  return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
885  };
886 
887  // Process the parsed tablegen records to build ODS information.
888  /// Operations.
889  for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
890  tblgen::Operator op(def);
891 
892  // Check to see if this operation is known to support type inferrence.
893  bool supportsResultTypeInferrence =
894  op.getTrait("::mlir::InferTypeOpInterface::Trait");
895 
896  auto [odsOp, inserted] = odsContext.insertOperation(
897  op.getOperationName(), processDoc(op.getSummary()),
898  processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
899  supportsResultTypeInferrence, op.getLoc().front());
900 
901  // Ignore operations that have already been added.
902  if (!inserted)
903  continue;
904 
905  for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
906  odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
907  odsContext.insertAttributeConstraint(
908  attr.attr.getUniqueDefName(),
909  processDoc(attr.attr.getSummary()),
910  attr.attr.getStorageType()));
911  }
912  for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
913  odsOp->appendOperand(operand.name, getLengthKind(operand),
914  addTypeConstraint(operand));
915  }
916  for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
917  odsOp->appendResult(result.name, getLengthKind(result),
918  addTypeConstraint(result));
919  }
920  }
921 
922  auto shouldBeSkipped = [this](const llvm::Record *def) {
923  return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
924  def->isSubClassOf("DeclareInterfaceMethods");
925  };
926 
927  /// Attr constraints.
928  for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
929  if (shouldBeSkipped(def))
930  continue;
931 
932  tblgen::Attribute constraint(def);
933  decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
934  constraint, convertLocToRange(def->getLoc().front()), attrTy,
935  constraint.getStorageType()));
936  }
937  /// Type constraints.
938  for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
939  if (shouldBeSkipped(def))
940  continue;
941 
942  tblgen::TypeConstraint constraint(def);
943  decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
944  constraint, convertLocToRange(def->getLoc().front()), typeTy,
945  constraint.getCppType()));
946  }
947  /// OpInterfaces.
949  for (const llvm::Record *def :
950  tdRecords.getAllDerivedDefinitions("OpInterface")) {
951  if (shouldBeSkipped(def))
952  continue;
953 
954  SMRange loc = convertLocToRange(def->getLoc().front());
955 
956  std::string cppClassName =
957  llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
958  def->getValueAsString("cppInterfaceName"))
959  .str();
960  std::string codeBlock =
961  llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
962  cppClassName)
963  .str();
964 
965  std::string desc =
966  processAndFormatDoc(def->getValueAsString("description"));
967  decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
968  def->getName(), codeBlock, loc, opTy, cppClassName, desc));
969  }
970 }
971 
972 template <typename ConstraintT>
973 ast::Decl *Parser::createODSNativePDLLConstraintDecl(
974  StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
975  StringRef nativeType, StringRef docString) {
976  // Build the single input parameter.
977  ast::DeclScope *argScope = pushDeclScope();
978  auto *paramVar = ast::VariableDecl::create(
979  ctx, ast::Name::create(ctx, "self", loc), type,
980  /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
981  argScope->add(paramVar);
982  popDeclScope();
983 
984  // Build the native constraint.
985  auto *constraintDecl = ast::UserConstraintDecl::createNative(
986  ctx, ast::Name::create(ctx, name, loc), paramVar,
987  /*results=*/std::nullopt, codeBlock, ast::TupleType::get(ctx),
988  nativeType);
989  constraintDecl->setDocComment(ctx, docString);
990  curDeclScope->add(constraintDecl);
991  return constraintDecl;
992 }
993 
994 template <typename ConstraintT>
995 ast::Decl *
996 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
997  SMRange loc, ast::Type type,
998  StringRef nativeType) {
999  // Format the condition template.
1000  tblgen::FmtContext fmtContext;
1001  fmtContext.withSelf("self");
1002  std::string codeBlock = tblgen::tgfmt(
1003  "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
1004  &fmtContext);
1005 
1006  // If documentation was enabled, build the doc string for the generated
1007  // constraint. It would be nice to do this lazily, but TableGen information is
1008  // destroyed after we finish parsing the file.
1009  std::string docString;
1010  if (enableDocumentation) {
1011  StringRef desc = constraint.getDescription();
1012  docString = processAndFormatDoc(
1013  constraint.getSummary() +
1014  (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
1015  }
1016 
1017  return createODSNativePDLLConstraintDecl<ConstraintT>(
1018  constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
1019  docString);
1020 }
1021 
1022 //===----------------------------------------------------------------------===//
1023 // Decls
1024 
1025 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
1026  FailureOr<ast::Decl *> decl;
1027  switch (curToken.getKind()) {
1028  case Token::kw_Constraint:
1029  decl = parseUserConstraintDecl();
1030  break;
1031  case Token::kw_Pattern:
1032  decl = parsePatternDecl();
1033  break;
1034  case Token::kw_Rewrite:
1035  decl = parseUserRewriteDecl();
1036  break;
1037  default:
1038  return emitError("expected top-level declaration, such as a `Pattern`");
1039  }
1040  if (failed(decl))
1041  return failure();
1042 
1043  // If the decl has a name, add it to the current scope.
1044  if (const ast::Name *name = (*decl)->getName()) {
1045  if (failed(checkDefineNamedDecl(*name)))
1046  return failure();
1047  curDeclScope->add(*decl);
1048  }
1049  return decl;
1050 }
1051 
1052 FailureOr<ast::NamedAttributeDecl *>
1053 Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1054  // Check for name code completion.
1055  if (curToken.is(Token::code_complete))
1056  return codeCompleteAttributeName(parentOpName);
1057 
1058  std::string attrNameStr;
1059  if (curToken.isString())
1060  attrNameStr = curToken.getStringValue();
1061  else if (curToken.is(Token::identifier) || curToken.isKeyword())
1062  attrNameStr = curToken.getSpelling().str();
1063  else
1064  return emitError("expected identifier or string attribute name");
1065  const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
1066  consumeToken();
1067 
1068  // Check for a value of the attribute.
1069  ast::Expr *attrValue = nullptr;
1070  if (consumeIf(Token::equal)) {
1071  FailureOr<ast::Expr *> attrExpr = parseExpr();
1072  if (failed(attrExpr))
1073  return failure();
1074  attrValue = *attrExpr;
1075  } else {
1076  // If there isn't a concrete value, create an expression representing a
1077  // UnitAttr.
1078  attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
1079  }
1080 
1081  return ast::NamedAttributeDecl::create(ctx, name, attrValue);
1082 }
1083 
1084 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1085  function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1086  bool expectTerminalSemicolon) {
1087  consumeToken(Token::equal_arrow);
1088 
1089  // Parse the single statement of the lambda body.
1090  SMLoc bodyStartLoc = curToken.getStartLoc();
1091  pushDeclScope();
1092  FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1093  bool failedToParse =
1094  failed(singleStatement) || failed(processStatementFn(*singleStatement));
1095  popDeclScope();
1096  if (failedToParse)
1097  return failure();
1098 
1099  SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1100  return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
1101 }
1102 
1103 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1104  // Ensure that the argument is named.
1105  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
1106  return emitError("expected identifier argument name");
1107 
1108  // Parse the argument similarly to a normal variable.
1109  StringRef name = curToken.getSpelling();
1110  SMRange nameLoc = curToken.getLoc();
1111  consumeToken();
1112 
1113  if (failed(
1114  parseToken(Token::colon, "expected `:` before argument constraint")))
1115  return failure();
1116 
1117  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1118  if (failed(cst))
1119  return failure();
1120 
1121  return createArgOrResultVariableDecl(name, nameLoc, *cst);
1122 }
1123 
1124 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1125  // Check to see if this result is named.
1126  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1127  // Check to see if this name actually refers to a Constraint.
1128  if (!curDeclScope->lookup<ast::ConstraintDecl>(curToken.getSpelling())) {
1129  // If it wasn't a constraint, parse the result similarly to a variable. If
1130  // there is already an existing decl, we will emit an error when defining
1131  // this variable later.
1132  StringRef name = curToken.getSpelling();
1133  SMRange nameLoc = curToken.getLoc();
1134  consumeToken();
1135 
1136  if (failed(parseToken(Token::colon,
1137  "expected `:` before result constraint")))
1138  return failure();
1139 
1140  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1141  if (failed(cst))
1142  return failure();
1143 
1144  return createArgOrResultVariableDecl(name, nameLoc, *cst);
1145  }
1146  }
1147 
1148  // If it isn't named, we parse the constraint directly and create an unnamed
1149  // result variable.
1150  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1151  if (failed(cst))
1152  return failure();
1153 
1154  return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1155 }
1156 
1157 FailureOr<ast::UserConstraintDecl *>
1158 Parser::parseUserConstraintDecl(bool isInline) {
1159  // Constraints and rewrites have very similar formats, dispatch to a shared
1160  // interface for parsing.
1161  return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1162  [&](auto &&...args) {
1163  return this->parseUserPDLLConstraintDecl(args...);
1164  },
1165  ParserContext::Constraint, "constraint", isInline);
1166 }
1167 
1168 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1169  FailureOr<ast::UserConstraintDecl *> decl =
1170  parseUserConstraintDecl(/*isInline=*/true);
1171  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1172  return failure();
1173 
1174  curDeclScope->add(*decl);
1175  return decl;
1176 }
1177 
1178 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1179  const ast::Name &name, bool isInline,
1180  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1181  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1182  // Push the argument scope back onto the list, so that the body can
1183  // reference arguments.
1184  pushDeclScope(argumentScope);
1185 
1186  // Parse the body of the constraint. The body is either defined as a compound
1187  // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1188  ast::CompoundStmt *body;
1189  if (curToken.is(Token::equal_arrow)) {
1190  FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1191  [&](ast::Stmt *&stmt) -> LogicalResult {
1192  ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1193  if (!stmtExpr) {
1194  return emitError(stmt->getLoc(),
1195  "expected `Constraint` lambda body to contain a "
1196  "single expression");
1197  }
1198  stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1199  return success();
1200  },
1201  /*expectTerminalSemicolon=*/!isInline);
1202  if (failed(bodyResult))
1203  return failure();
1204  body = *bodyResult;
1205  } else {
1206  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1207  if (failed(bodyResult))
1208  return failure();
1209  body = *bodyResult;
1210 
1211  // Verify the structure of the body.
1212  auto bodyIt = body->begin(), bodyE = body->end();
1213  for (; bodyIt != bodyE; ++bodyIt)
1214  if (isa<ast::ReturnStmt>(*bodyIt))
1215  break;
1216  if (failed(validateUserConstraintOrRewriteReturn(
1217  "Constraint", body, bodyIt, bodyE, results, resultType)))
1218  return failure();
1219  }
1220  popDeclScope();
1221 
1222  return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1223  name, arguments, results, resultType, body);
1224 }
1225 
1226 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1227  // Constraints and rewrites have very similar formats, dispatch to a shared
1228  // interface for parsing.
1229  return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1230  [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1231  ParserContext::Rewrite, "rewrite", isInline);
1232 }
1233 
1234 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1235  FailureOr<ast::UserRewriteDecl *> decl =
1236  parseUserRewriteDecl(/*isInline=*/true);
1237  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1238  return failure();
1239 
1240  curDeclScope->add(*decl);
1241  return decl;
1242 }
1243 
1244 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1245  const ast::Name &name, bool isInline,
1246  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1247  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1248  // Push the argument scope back onto the list, so that the body can
1249  // reference arguments.
1250  curDeclScope = argumentScope;
1251  ast::CompoundStmt *body;
1252  if (curToken.is(Token::equal_arrow)) {
1253  FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1254  [&](ast::Stmt *&statement) -> LogicalResult {
1255  if (isa<ast::OpRewriteStmt>(statement))
1256  return success();
1257 
1258  ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1259  if (!statementExpr) {
1260  return emitError(
1261  statement->getLoc(),
1262  "expected `Rewrite` lambda body to contain a single expression "
1263  "or an operation rewrite statement; such as `erase`, "
1264  "`replace`, or `rewrite`");
1265  }
1266  statement =
1267  ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1268  return success();
1269  },
1270  /*expectTerminalSemicolon=*/!isInline);
1271  if (failed(bodyResult))
1272  return failure();
1273  body = *bodyResult;
1274  } else {
1275  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1276  if (failed(bodyResult))
1277  return failure();
1278  body = *bodyResult;
1279  }
1280  popDeclScope();
1281 
1282  // Verify the structure of the body.
1283  auto bodyIt = body->begin(), bodyE = body->end();
1284  for (; bodyIt != bodyE; ++bodyIt)
1285  if (isa<ast::ReturnStmt>(*bodyIt))
1286  break;
1287  if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1288  bodyE, results, resultType)))
1289  return failure();
1290  return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1291  name, arguments, results, resultType, body);
1292 }
1293 
1294 template <typename T, typename ParseUserPDLLDeclFnT>
1295 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1296  ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1297  StringRef anonymousNamePrefix, bool isInline) {
1298  SMRange loc = curToken.getLoc();
1299  consumeToken();
1300  llvm::SaveAndRestore saveCtx(parserContext, declContext);
1301 
1302  // Parse the name of the decl.
1303  const ast::Name *name = nullptr;
1304  if (curToken.isNot(Token::identifier)) {
1305  // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1306  // in C++, so being unnamed is fine.
1307  if (!isInline)
1308  return emitError("expected identifier name");
1309 
1310  // Create a unique anonymous name to use, as the name for this decl is not
1311  // important.
1312  std::string anonName =
1313  llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1314  anonymousDeclNameCounter++)
1315  .str();
1316  name = &ast::Name::create(ctx, anonName, loc);
1317  } else {
1318  // If a name was provided, we can use it directly.
1319  name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1320  consumeToken(Token::identifier);
1321  }
1322 
1323  // Parse the functional signature of the decl.
1324  SmallVector<ast::VariableDecl *> arguments, results;
1325  ast::DeclScope *argumentScope;
1326  ast::Type resultType;
1327  if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1328  argumentScope, resultType)))
1329  return failure();
1330 
1331  // Check to see which type of constraint this is. If the constraint contains a
1332  // compound body, this is a PDLL decl.
1333  if (curToken.isAny(Token::l_brace, Token::equal_arrow))
1334  return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1335  resultType);
1336 
1337  // Otherwise, this is a native decl.
1338  return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1339  results, resultType);
1340 }
1341 
1342 template <typename T>
1343 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1344  const ast::Name &name, bool isInline,
1346  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1347  // If followed by a string, the native code body has also been specified.
1348  std::string codeStrStorage;
1349  std::optional<StringRef> optCodeStr;
1350  if (curToken.isString()) {
1351  codeStrStorage = curToken.getStringValue();
1352  optCodeStr = codeStrStorage;
1353  consumeToken();
1354  } else if (isInline) {
1355  return emitError(name.getLoc(),
1356  "external declarations must be declared in global scope");
1357  } else if (curToken.is(Token::error)) {
1358  return failure();
1359  }
1360  if (failed(parseToken(Token::semicolon,
1361  "expected `;` after native declaration")))
1362  return failure();
1363  return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1364 }
1365 
1366 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1369  ast::DeclScope *&argumentScope, ast::Type &resultType) {
1370  // Parse the argument list of the decl.
1371  if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
1372  return failure();
1373 
1374  argumentScope = pushDeclScope();
1375  if (curToken.isNot(Token::r_paren)) {
1376  do {
1377  FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1378  if (failed(argument))
1379  return failure();
1380  arguments.emplace_back(*argument);
1381  } while (consumeIf(Token::comma));
1382  }
1383  popDeclScope();
1384  if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
1385  return failure();
1386 
1387  // Parse the results of the decl.
1388  pushDeclScope();
1389  if (consumeIf(Token::arrow)) {
1390  auto parseResultFn = [&]() -> LogicalResult {
1391  FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1392  if (failed(result))
1393  return failure();
1394  results.emplace_back(*result);
1395  return success();
1396  };
1397 
1398  // Check for a list of results.
1399  if (consumeIf(Token::l_paren)) {
1400  do {
1401  if (failed(parseResultFn()))
1402  return failure();
1403  } while (consumeIf(Token::comma));
1404  if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1405  return failure();
1406 
1407  // Otherwise, there is only one result.
1408  } else if (failed(parseResultFn())) {
1409  return failure();
1410  }
1411  }
1412  popDeclScope();
1413 
1414  // Compute the result type of the decl.
1415  resultType = createUserConstraintRewriteResultType(results);
1416 
1417  // Verify that results are only named if there are more than one.
1418  if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1419  return emitError(
1420  results.front()->getLoc(),
1421  "cannot create a single-element tuple with an element label");
1422  }
1423  return success();
1424 }
1425 
1426 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1427  StringRef declType, ast::CompoundStmt *body,
1430  ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1431  // Handle if a `return` was provided.
1432  if (bodyIt != bodyE) {
1433  // Emit an error if we have trailing statements after the return.
1434  if (std::next(bodyIt) != bodyE) {
1435  return emitError(
1436  (*std::next(bodyIt))->getLoc(),
1437  llvm::formatv("`return` terminated the `{0}` body, but found "
1438  "trailing statements afterwards",
1439  declType));
1440  }
1441 
1442  // Otherwise if a return wasn't provided, check that no results are
1443  // expected.
1444  } else if (!results.empty()) {
1445  return emitError(
1446  {body->getLoc().End, body->getLoc().End},
1447  llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1448  declType, resultType));
1449  }
1450  return success();
1451 }
1452 
1453 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1454  return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1455  if (isa<ast::OpRewriteStmt>(statement))
1456  return success();
1457  return emitError(
1458  statement->getLoc(),
1459  "expected Pattern lambda body to contain a single operation "
1460  "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1461  });
1462 }
1463 
1464 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1465  SMRange loc = curToken.getLoc();
1466  consumeToken(Token::kw_Pattern);
1467  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1468 
1469  // Check for an optional identifier for the pattern name.
1470  const ast::Name *name = nullptr;
1471  if (curToken.is(Token::identifier)) {
1472  name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1473  consumeToken(Token::identifier);
1474  }
1475 
1476  // Parse any pattern metadata.
1477  ParsedPatternMetadata metadata;
1478  if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1479  return failure();
1480 
1481  // Parse the pattern body.
1482  ast::CompoundStmt *body;
1483 
1484  // Handle a lambda body.
1485  if (curToken.is(Token::equal_arrow)) {
1486  FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1487  if (failed(bodyResult))
1488  return failure();
1489  body = *bodyResult;
1490  } else {
1491  if (curToken.isNot(Token::l_brace))
1492  return emitError("expected `{` or `=>` to start pattern body");
1493  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1494  if (failed(bodyResult))
1495  return failure();
1496  body = *bodyResult;
1497 
1498  // Verify the body of the pattern.
1499  auto bodyIt = body->begin(), bodyE = body->end();
1500  for (; bodyIt != bodyE; ++bodyIt) {
1501  if (isa<ast::ReturnStmt>(*bodyIt)) {
1502  return emitError((*bodyIt)->getLoc(),
1503  "`return` statements are only permitted within a "
1504  "`Constraint` or `Rewrite` body");
1505  }
1506  // Break when we've found the rewrite statement.
1507  if (isa<ast::OpRewriteStmt>(*bodyIt))
1508  break;
1509  }
1510  if (bodyIt == bodyE) {
1511  return emitError(loc,
1512  "expected Pattern body to terminate with an operation "
1513  "rewrite statement, such as `erase`");
1514  }
1515  if (std::next(bodyIt) != bodyE) {
1516  return emitError((*std::next(bodyIt))->getLoc(),
1517  "Pattern body was terminated by an operation "
1518  "rewrite statement, but found trailing statements");
1519  }
1520  }
1521 
1522  return createPatternDecl(loc, name, metadata, body);
1523 }
1524 
1525 LogicalResult
1526 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1527  std::optional<SMRange> benefitLoc;
1528  std::optional<SMRange> hasBoundedRecursionLoc;
1529 
1530  do {
1531  // Handle metadata code completion.
1532  if (curToken.is(Token::code_complete))
1533  return codeCompletePatternMetadata();
1534 
1535  if (curToken.isNot(Token::identifier))
1536  return emitError("expected pattern metadata identifier");
1537  StringRef metadataStr = curToken.getSpelling();
1538  SMRange metadataLoc = curToken.getLoc();
1539  consumeToken(Token::identifier);
1540 
1541  // Parse the benefit metadata: benefit(<integer-value>)
1542  if (metadataStr == "benefit") {
1543  if (benefitLoc) {
1544  return emitErrorAndNote(metadataLoc,
1545  "pattern benefit has already been specified",
1546  *benefitLoc, "see previous definition here");
1547  }
1548  if (failed(parseToken(Token::l_paren,
1549  "expected `(` before pattern benefit")))
1550  return failure();
1551 
1552  uint16_t benefitValue = 0;
1553  if (curToken.isNot(Token::integer))
1554  return emitError("expected integral pattern benefit");
1555  if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1556  return emitError(
1557  "expected pattern benefit to fit within a 16-bit integer");
1558  consumeToken(Token::integer);
1559 
1560  metadata.benefit = benefitValue;
1561  benefitLoc = metadataLoc;
1562 
1563  if (failed(
1564  parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1565  return failure();
1566  continue;
1567  }
1568 
1569  // Parse the bounded recursion metadata: recursion
1570  if (metadataStr == "recursion") {
1571  if (hasBoundedRecursionLoc) {
1572  return emitErrorAndNote(
1573  metadataLoc,
1574  "pattern recursion metadata has already been specified",
1575  *hasBoundedRecursionLoc, "see previous definition here");
1576  }
1577  metadata.hasBoundedRecursion = true;
1578  hasBoundedRecursionLoc = metadataLoc;
1579  continue;
1580  }
1581 
1582  return emitError(metadataLoc, "unknown pattern metadata");
1583  } while (consumeIf(Token::comma));
1584 
1585  return success();
1586 }
1587 
1588 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1589  consumeToken(Token::less);
1590 
1591  FailureOr<ast::Expr *> typeExpr = parseExpr();
1592  if (failed(typeExpr) ||
1593  failed(parseToken(Token::greater,
1594  "expected `>` after variable type constraint")))
1595  return failure();
1596  return typeExpr;
1597 }
1598 
1599 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1600  assert(curDeclScope && "defining decl outside of a decl scope");
1601  if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1602  return emitErrorAndNote(
1603  name.getLoc(), "`" + name.getName() + "` has already been defined",
1604  lastDecl->getName()->getLoc(), "see previous definition here");
1605  }
1606  return success();
1607 }
1608 
1609 FailureOr<ast::VariableDecl *>
1610 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1611  ast::Expr *initExpr,
1612  ArrayRef<ast::ConstraintRef> constraints) {
1613  assert(curDeclScope && "defining variable outside of decl scope");
1614  const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1615 
1616  // If the name of the variable indicates a special variable, we don't add it
1617  // to the scope. This variable is local to the definition point.
1618  if (name.empty() || name == "_") {
1619  return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1620  constraints);
1621  }
1622  if (failed(checkDefineNamedDecl(nameDecl)))
1623  return failure();
1624 
1625  auto *varDecl =
1626  ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1627  curDeclScope->add(varDecl);
1628  return varDecl;
1629 }
1630 
1631 FailureOr<ast::VariableDecl *>
1632 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1633  ArrayRef<ast::ConstraintRef> constraints) {
1634  return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1635  constraints);
1636 }
1637 
1638 LogicalResult Parser::parseVariableDeclConstraintList(
1639  SmallVectorImpl<ast::ConstraintRef> &constraints) {
1640  std::optional<SMRange> typeConstraint;
1641  auto parseSingleConstraint = [&] {
1642  FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1643  typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
1644  if (failed(constraint))
1645  return failure();
1646  constraints.push_back(*constraint);
1647  return success();
1648  };
1649 
1650  // Check to see if this is a single constraint, or a list.
1651  if (!consumeIf(Token::l_square))
1652  return parseSingleConstraint();
1653 
1654  do {
1655  if (failed(parseSingleConstraint()))
1656  return failure();
1657  } while (consumeIf(Token::comma));
1658  return parseToken(Token::r_square, "expected `]` after constraint list");
1659 }
1660 
1661 FailureOr<ast::ConstraintRef>
1662 Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
1663  ArrayRef<ast::ConstraintRef> existingConstraints,
1664  bool allowInlineTypeConstraints) {
1665  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1666  if (!allowInlineTypeConstraints) {
1667  return emitError(
1668  curToken.getLoc(),
1669  "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1670  "permitted on arguments or results");
1671  }
1672  if (typeConstraint)
1673  return emitErrorAndNote(
1674  curToken.getLoc(),
1675  "the type of this variable has already been constrained",
1676  *typeConstraint, "see previous constraint location here");
1677  FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1678  if (failed(constraintExpr))
1679  return failure();
1680  typeExpr = *constraintExpr;
1681  typeConstraint = typeExpr->getLoc();
1682  return success();
1683  };
1684 
1685  SMRange loc = curToken.getLoc();
1686  switch (curToken.getKind()) {
1687  case Token::kw_Attr: {
1688  consumeToken(Token::kw_Attr);
1689 
1690  // Check for a type constraint.
1691  ast::Expr *typeExpr = nullptr;
1692  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1693  return failure();
1694  return ast::ConstraintRef(
1695  ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1696  }
1697  case Token::kw_Op: {
1698  consumeToken(Token::kw_Op);
1699 
1700  // Parse an optional operation name. If the name isn't provided, this refers
1701  // to "any" operation.
1702  FailureOr<ast::OpNameDecl *> opName =
1703  parseWrappedOperationName(/*allowEmptyName=*/true);
1704  if (failed(opName))
1705  return failure();
1706 
1707  return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
1708  loc);
1709  }
1710  case Token::kw_Type:
1711  consumeToken(Token::kw_Type);
1712  return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1713  case Token::kw_TypeRange:
1714  consumeToken(Token::kw_TypeRange);
1716  loc);
1717  case Token::kw_Value: {
1718  consumeToken(Token::kw_Value);
1719 
1720  // Check for a type constraint.
1721  ast::Expr *typeExpr = nullptr;
1722  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1723  return failure();
1724 
1725  return ast::ConstraintRef(
1726  ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1727  }
1728  case Token::kw_ValueRange: {
1729  consumeToken(Token::kw_ValueRange);
1730 
1731  // Check for a type constraint.
1732  ast::Expr *typeExpr = nullptr;
1733  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1734  return failure();
1735 
1736  return ast::ConstraintRef(
1737  ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1738  }
1739 
1740  case Token::kw_Constraint: {
1741  // Handle an inline constraint.
1742  FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1743  if (failed(decl))
1744  return failure();
1745  return ast::ConstraintRef(*decl, loc);
1746  }
1747  case Token::identifier: {
1748  StringRef constraintName = curToken.getSpelling();
1749  consumeToken(Token::identifier);
1750 
1751  // Lookup the referenced constraint.
1752  ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
1753  if (!cstDecl) {
1754  return emitError(loc, "unknown reference to constraint `" +
1755  constraintName + "`");
1756  }
1757 
1758  // Handle a reference to a proper constraint.
1759  if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1760  return ast::ConstraintRef(cst, loc);
1761 
1762  return emitErrorAndNote(
1763  loc, "invalid reference to non-constraint", cstDecl->getLoc(),
1764  "see the definition of `" + constraintName + "` here");
1765  }
1766  // Handle single entity constraint code completion.
1767  case Token::code_complete: {
1768  // Try to infer the current type for use by code completion.
1769  ast::Type inferredType;
1770  if (failed(validateVariableConstraints(existingConstraints, inferredType)))
1771  return failure();
1772 
1773  return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
1774  }
1775  default:
1776  break;
1777  }
1778  return emitError(loc, "expected identifier constraint");
1779 }
1780 
1781 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1782  std::optional<SMRange> typeConstraint;
1783  return parseConstraint(typeConstraint, /*existingConstraints=*/std::nullopt,
1784  /*allowInlineTypeConstraints=*/false);
1785 }
1786 
1787 //===----------------------------------------------------------------------===//
1788 // Exprs
1789 
1790 FailureOr<ast::Expr *> Parser::parseExpr() {
1791  if (curToken.is(Token::underscore))
1792  return parseUnderscoreExpr();
1793 
1794  // Parse the LHS expression.
1795  FailureOr<ast::Expr *> lhsExpr;
1796  switch (curToken.getKind()) {
1797  case Token::kw_attr:
1798  lhsExpr = parseAttributeExpr();
1799  break;
1800  case Token::kw_Constraint:
1801  lhsExpr = parseInlineConstraintLambdaExpr();
1802  break;
1803  case Token::kw_not:
1804  lhsExpr = parseNegatedExpr();
1805  break;
1806  case Token::identifier:
1807  lhsExpr = parseIdentifierExpr();
1808  break;
1809  case Token::kw_op:
1810  lhsExpr = parseOperationExpr();
1811  break;
1812  case Token::kw_Rewrite:
1813  lhsExpr = parseInlineRewriteLambdaExpr();
1814  break;
1815  case Token::kw_type:
1816  lhsExpr = parseTypeExpr();
1817  break;
1818  case Token::l_paren:
1819  lhsExpr = parseTupleExpr();
1820  break;
1821  default:
1822  return emitError("expected expression");
1823  }
1824  if (failed(lhsExpr))
1825  return failure();
1826 
1827  // Check for an operator expression.
1828  while (true) {
1829  switch (curToken.getKind()) {
1830  case Token::dot:
1831  lhsExpr = parseMemberAccessExpr(*lhsExpr);
1832  break;
1833  case Token::l_paren:
1834  lhsExpr = parseCallExpr(*lhsExpr);
1835  break;
1836  default:
1837  return lhsExpr;
1838  }
1839  if (failed(lhsExpr))
1840  return failure();
1841  }
1842 }
1843 
1844 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1845  SMRange loc = curToken.getLoc();
1846  consumeToken(Token::kw_attr);
1847 
1848  // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1849  // identifier.
1850  if (!consumeIf(Token::less)) {
1851  resetToken(loc);
1852  return parseIdentifierExpr();
1853  }
1854 
1855  if (!curToken.isString())
1856  return emitError("expected string literal containing MLIR attribute");
1857  std::string attrExpr = curToken.getStringValue();
1858  consumeToken();
1859 
1860  loc.End = curToken.getEndLoc();
1861  if (failed(
1862  parseToken(Token::greater, "expected `>` after attribute literal")))
1863  return failure();
1864  return ast::AttributeExpr::create(ctx, loc, attrExpr);
1865 }
1866 
1867 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr,
1868  bool isNegated) {
1869  consumeToken(Token::l_paren);
1870 
1871  // Parse the arguments of the call.
1872  SmallVector<ast::Expr *> arguments;
1873  if (curToken.isNot(Token::r_paren)) {
1874  do {
1875  // Handle code completion for the call arguments.
1876  if (curToken.is(Token::code_complete)) {
1877  codeCompleteCallSignature(parentExpr, arguments.size());
1878  return failure();
1879  }
1880 
1881  FailureOr<ast::Expr *> argument = parseExpr();
1882  if (failed(argument))
1883  return failure();
1884  arguments.push_back(*argument);
1885  } while (consumeIf(Token::comma));
1886  }
1887 
1888  SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1889  if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1890  return failure();
1891 
1892  return createCallExpr(loc, parentExpr, arguments, isNegated);
1893 }
1894 
1895 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1896  ast::Decl *decl = curDeclScope->lookup(name);
1897  if (!decl)
1898  return emitError(loc, "undefined reference to `" + name + "`");
1899 
1900  return createDeclRefExpr(loc, decl);
1901 }
1902 
1903 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1904  StringRef name = curToken.getSpelling();
1905  SMRange nameLoc = curToken.getLoc();
1906  consumeToken();
1907 
1908  // Check to see if this is a decl ref expression that defines a variable
1909  // inline.
1910  if (consumeIf(Token::colon)) {
1911  SmallVector<ast::ConstraintRef> constraints;
1912  if (failed(parseVariableDeclConstraintList(constraints)))
1913  return failure();
1914  ast::Type type;
1915  if (failed(validateVariableConstraints(constraints, type)))
1916  return failure();
1917  return createInlineVariableExpr(type, name, nameLoc, constraints);
1918  }
1919 
1920  return parseDeclRefExpr(name, nameLoc);
1921 }
1922 
1923 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1924  FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1925  if (failed(decl))
1926  return failure();
1927 
1928  return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1930 }
1931 
1932 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1933  FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1934  if (failed(decl))
1935  return failure();
1936 
1937  return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1938  ast::RewriteType::get(ctx));
1939 }
1940 
1941 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1942  SMRange dotLoc = curToken.getLoc();
1943  consumeToken(Token::dot);
1944 
1945  // Check for code completion of the member name.
1946  if (curToken.is(Token::code_complete))
1947  return codeCompleteMemberAccess(parentExpr);
1948 
1949  // Parse the member name.
1950  Token memberNameTok = curToken;
1951  if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1952  !memberNameTok.isKeyword())
1953  return emitError(dotLoc, "expected identifier or numeric member name");
1954  StringRef memberName = memberNameTok.getSpelling();
1955  SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1956  consumeToken();
1957 
1958  return createMemberAccessExpr(parentExpr, memberName, loc);
1959 }
1960 
1961 FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
1962  consumeToken(Token::kw_not);
1963  // Only native constraints are supported after negation
1964  if (!curToken.is(Token::identifier))
1965  return emitError("expected native constraint");
1966  FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
1967  if (failed(identifierExpr))
1968  return failure();
1969  if (!curToken.is(Token::l_paren))
1970  return emitError("expected `(` after function name");
1971  return parseCallExpr(*identifierExpr, /*isNegated = */ true);
1972 }
1973 
1974 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
1975  SMRange loc = curToken.getLoc();
1976 
1977  // Check for code completion for the dialect name.
1978  if (curToken.is(Token::code_complete))
1979  return codeCompleteDialectName();
1980 
1981  // Handle the case of an no operation name.
1982  if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
1983  if (allowEmptyName)
1984  return ast::OpNameDecl::create(ctx, SMRange());
1985  return emitError("expected dialect namespace");
1986  }
1987  StringRef name = curToken.getSpelling();
1988  consumeToken();
1989 
1990  // Otherwise, this is a literal operation name.
1991  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
1992  return failure();
1993 
1994  // Check for code completion for the operation name.
1995  if (curToken.is(Token::code_complete))
1996  return codeCompleteOperationName(name);
1997 
1998  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
1999  return emitError("expected operation name after dialect namespace");
2000 
2001  name = StringRef(name.data(), name.size() + 1);
2002  do {
2003  name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
2004  loc.End = curToken.getEndLoc();
2005  consumeToken();
2006  } while (curToken.isAny(Token::identifier, Token::dot) ||
2007  curToken.isKeyword());
2008  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
2009 }
2010 
2011 FailureOr<ast::OpNameDecl *>
2012 Parser::parseWrappedOperationName(bool allowEmptyName) {
2013  if (!consumeIf(Token::less))
2014  return ast::OpNameDecl::create(ctx, SMRange());
2015 
2016  FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
2017  if (failed(opNameDecl))
2018  return failure();
2019 
2020  if (failed(parseToken(Token::greater, "expected `>` after operation name")))
2021  return failure();
2022  return opNameDecl;
2023 }
2024 
2025 FailureOr<ast::Expr *>
2026 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
2027  SMRange loc = curToken.getLoc();
2028  consumeToken(Token::kw_op);
2029 
2030  // If it isn't followed by a `<`, the `op` keyword is treated as a normal
2031  // identifier.
2032  if (curToken.isNot(Token::less)) {
2033  resetToken(loc);
2034  return parseIdentifierExpr();
2035  }
2036 
2037  // Parse the operation name. The name may be elided, in which case the
2038  // operation refers to "any" operation(i.e. a difference between `MyOp` and
2039  // `Operation*`). Operation names within a rewrite context must be named.
2040  bool allowEmptyName = parserContext != ParserContext::Rewrite;
2041  FailureOr<ast::OpNameDecl *> opNameDecl =
2042  parseWrappedOperationName(allowEmptyName);
2043  if (failed(opNameDecl))
2044  return failure();
2045  std::optional<StringRef> opName = (*opNameDecl)->getName();
2046 
2047  // Functor used to create an implicit range variable, used for implicit "all"
2048  // operand or results variables.
2049  auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
2050  FailureOr<ast::VariableDecl *> rangeVar =
2051  defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
2052  assert(succeeded(rangeVar) && "expected range variable to be valid");
2053  return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
2054  };
2055 
2056  // Check for the optional list of operands.
2057  SmallVector<ast::Expr *> operands;
2058  if (!consumeIf(Token::l_paren)) {
2059  // If the operand list isn't specified and we are in a match context, define
2060  // an inplace unconstrained operand range corresponding to all of the
2061  // operands of the operation. This avoids treating zero operands the same
2062  // way as "unconstrained operands".
2063  if (parserContext != ParserContext::Rewrite) {
2064  operands.push_back(createImplicitRangeVar(
2065  ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
2066  }
2067  } else if (!consumeIf(Token::r_paren)) {
2068  // If the operand list was specified and non-empty, parse the operands.
2069  do {
2070  // Check for operand signature code completion.
2071  if (curToken.is(Token::code_complete)) {
2072  codeCompleteOperationOperandsSignature(opName, operands.size());
2073  return failure();
2074  }
2075 
2076  FailureOr<ast::Expr *> operand = parseExpr();
2077  if (failed(operand))
2078  return failure();
2079  operands.push_back(*operand);
2080  } while (consumeIf(Token::comma));
2081 
2082  if (failed(parseToken(Token::r_paren,
2083  "expected `)` after operation operand list")))
2084  return failure();
2085  }
2086 
2087  // Check for the optional list of attributes.
2089  if (consumeIf(Token::l_brace)) {
2090  do {
2091  FailureOr<ast::NamedAttributeDecl *> decl =
2092  parseNamedAttributeDecl(opName);
2093  if (failed(decl))
2094  return failure();
2095  attributes.emplace_back(*decl);
2096  } while (consumeIf(Token::comma));
2097 
2098  if (failed(parseToken(Token::r_brace,
2099  "expected `}` after operation attribute list")))
2100  return failure();
2101  }
2102 
2103  // Handle the result types of the operation.
2104  SmallVector<ast::Expr *> resultTypes;
2105  OpResultTypeContext resultTypeContext = inputResultTypeContext;
2106 
2107  // Check for an explicit list of result types.
2108  if (consumeIf(Token::arrow)) {
2109  if (failed(parseToken(Token::l_paren,
2110  "expected `(` before operation result type list")))
2111  return failure();
2112 
2113  // If result types are provided, initially assume that the operation does
2114  // not rely on type inferrence. We don't assert that it isn't, because we
2115  // may be inferring the value of some type/type range variables, but given
2116  // that these variables may be defined in calls we can't always discern when
2117  // this is the case.
2118  resultTypeContext = OpResultTypeContext::Explicit;
2119 
2120  // Handle the case of an empty result list.
2121  if (!consumeIf(Token::r_paren)) {
2122  do {
2123  // Check for result signature code completion.
2124  if (curToken.is(Token::code_complete)) {
2125  codeCompleteOperationResultsSignature(opName, resultTypes.size());
2126  return failure();
2127  }
2128 
2129  FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2130  if (failed(resultTypeExpr))
2131  return failure();
2132  resultTypes.push_back(*resultTypeExpr);
2133  } while (consumeIf(Token::comma));
2134 
2135  if (failed(parseToken(Token::r_paren,
2136  "expected `)` after operation result type list")))
2137  return failure();
2138  }
2139  } else if (parserContext != ParserContext::Rewrite) {
2140  // If the result list isn't specified and we are in a match context, define
2141  // an inplace unconstrained result range corresponding to all of the results
2142  // of the operation. This avoids treating zero results the same way as
2143  // "unconstrained results".
2144  resultTypes.push_back(createImplicitRangeVar(
2145  ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
2146  } else if (resultTypeContext == OpResultTypeContext::Explicit) {
2147  // If the result list isn't specified and we are in a rewrite, try to infer
2148  // them at runtime instead.
2149  resultTypeContext = OpResultTypeContext::Interface;
2150  }
2151 
2152  return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2153  attributes, resultTypes);
2154 }
2155 
2156 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2157  SMRange loc = curToken.getLoc();
2158  consumeToken(Token::l_paren);
2159 
2160  DenseMap<StringRef, SMRange> usedNames;
2161  SmallVector<StringRef> elementNames;
2162  SmallVector<ast::Expr *> elements;
2163  if (curToken.isNot(Token::r_paren)) {
2164  do {
2165  // Check for the optional element name assignment before the value.
2166  StringRef elementName;
2167  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
2168  Token elementNameTok = curToken;
2169  consumeToken();
2170 
2171  // The element name is only present if followed by an `=`.
2172  if (consumeIf(Token::equal)) {
2173  elementName = elementNameTok.getSpelling();
2174 
2175  // Check to see if this name is already used.
2176  auto elementNameIt =
2177  usedNames.try_emplace(elementName, elementNameTok.getLoc());
2178  if (!elementNameIt.second) {
2179  return emitErrorAndNote(
2180  elementNameTok.getLoc(),
2181  llvm::formatv("duplicate tuple element label `{0}`",
2182  elementName),
2183  elementNameIt.first->getSecond(),
2184  "see previous label use here");
2185  }
2186  } else {
2187  // Otherwise, we treat this as part of an expression so reset the
2188  // lexer.
2189  resetToken(elementNameTok.getLoc());
2190  }
2191  }
2192  elementNames.push_back(elementName);
2193 
2194  // Parse the tuple element value.
2195  FailureOr<ast::Expr *> element = parseExpr();
2196  if (failed(element))
2197  return failure();
2198  elements.push_back(*element);
2199  } while (consumeIf(Token::comma));
2200  }
2201  loc.End = curToken.getEndLoc();
2202  if (failed(
2203  parseToken(Token::r_paren, "expected `)` after tuple element list")))
2204  return failure();
2205  return createTupleExpr(loc, elements, elementNames);
2206 }
2207 
2208 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2209  SMRange loc = curToken.getLoc();
2210  consumeToken(Token::kw_type);
2211 
2212  // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2213  // identifier.
2214  if (!consumeIf(Token::less)) {
2215  resetToken(loc);
2216  return parseIdentifierExpr();
2217  }
2218 
2219  if (!curToken.isString())
2220  return emitError("expected string literal containing MLIR type");
2221  std::string attrExpr = curToken.getStringValue();
2222  consumeToken();
2223 
2224  loc.End = curToken.getEndLoc();
2225  if (failed(parseToken(Token::greater, "expected `>` after type literal")))
2226  return failure();
2227  return ast::TypeExpr::create(ctx, loc, attrExpr);
2228 }
2229 
2230 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2231  StringRef name = curToken.getSpelling();
2232  SMRange nameLoc = curToken.getLoc();
2233  consumeToken(Token::underscore);
2234 
2235  // Underscore expressions require a constraint list.
2236  if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2237  return failure();
2238 
2239  // Parse the constraints for the expression.
2240  SmallVector<ast::ConstraintRef> constraints;
2241  if (failed(parseVariableDeclConstraintList(constraints)))
2242  return failure();
2243 
2244  ast::Type type;
2245  if (failed(validateVariableConstraints(constraints, type)))
2246  return failure();
2247  return createInlineVariableExpr(type, name, nameLoc, constraints);
2248 }
2249 
2250 //===----------------------------------------------------------------------===//
2251 // Stmts
2252 
2253 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2254  FailureOr<ast::Stmt *> stmt;
2255  switch (curToken.getKind()) {
2256  case Token::kw_erase:
2257  stmt = parseEraseStmt();
2258  break;
2259  case Token::kw_let:
2260  stmt = parseLetStmt();
2261  break;
2262  case Token::kw_replace:
2263  stmt = parseReplaceStmt();
2264  break;
2265  case Token::kw_return:
2266  stmt = parseReturnStmt();
2267  break;
2268  case Token::kw_rewrite:
2269  stmt = parseRewriteStmt();
2270  break;
2271  default:
2272  stmt = parseExpr();
2273  break;
2274  }
2275  if (failed(stmt) ||
2276  (expectTerminalSemicolon &&
2277  failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2278  return failure();
2279  return stmt;
2280 }
2281 
2282 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2283  SMLoc startLoc = curToken.getStartLoc();
2284  consumeToken(Token::l_brace);
2285 
2286  // Push a new block scope and parse any nested statements.
2287  pushDeclScope();
2288  SmallVector<ast::Stmt *> statements;
2289  while (curToken.isNot(Token::r_brace)) {
2290  FailureOr<ast::Stmt *> statement = parseStmt();
2291  if (failed(statement))
2292  return popDeclScope(), failure();
2293  statements.push_back(*statement);
2294  }
2295  popDeclScope();
2296 
2297  // Consume the end brace.
2298  SMRange location(startLoc, curToken.getEndLoc());
2299  consumeToken(Token::r_brace);
2300 
2301  return ast::CompoundStmt::create(ctx, location, statements);
2302 }
2303 
2304 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2305  if (parserContext == ParserContext::Constraint)
2306  return emitError("`erase` cannot be used within a Constraint");
2307  SMRange loc = curToken.getLoc();
2308  consumeToken(Token::kw_erase);
2309 
2310  // Parse the root operation expression.
2311  FailureOr<ast::Expr *> rootOp = parseExpr();
2312  if (failed(rootOp))
2313  return failure();
2314 
2315  return createEraseStmt(loc, *rootOp);
2316 }
2317 
2318 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2319  SMRange loc = curToken.getLoc();
2320  consumeToken(Token::kw_let);
2321 
2322  // Parse the name of the new variable.
2323  SMRange varLoc = curToken.getLoc();
2324  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
2325  // `_` is a reserved variable name.
2326  if (curToken.is(Token::underscore)) {
2327  return emitError(varLoc,
2328  "`_` may only be used to define \"inline\" variables");
2329  }
2330  return emitError(varLoc,
2331  "expected identifier after `let` to name a new variable");
2332  }
2333  StringRef varName = curToken.getSpelling();
2334  consumeToken();
2335 
2336  // Parse the optional set of constraints.
2337  SmallVector<ast::ConstraintRef> constraints;
2338  if (consumeIf(Token::colon) &&
2339  failed(parseVariableDeclConstraintList(constraints)))
2340  return failure();
2341 
2342  // Parse the optional initializer expression.
2343  ast::Expr *initializer = nullptr;
2344  if (consumeIf(Token::equal)) {
2345  FailureOr<ast::Expr *> initOrFailure = parseExpr();
2346  if (failed(initOrFailure))
2347  return failure();
2348  initializer = *initOrFailure;
2349 
2350  // Check that the constraints are compatible with having an initializer,
2351  // e.g. type constraints cannot be used with initializers.
2352  for (ast::ConstraintRef constraint : constraints) {
2353  LogicalResult result =
2354  TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
2356  ast::ValueRangeConstraintDecl>([&](const auto *cst) {
2357  if (cst->getTypeExpr()) {
2358  return this->emitError(
2359  constraint.referenceLoc,
2360  "type constraints are not permitted on variables with "
2361  "initializers");
2362  }
2363  return success();
2364  })
2365  .Default(success());
2366  if (failed(result))
2367  return failure();
2368  }
2369  }
2370 
2371  FailureOr<ast::VariableDecl *> varDecl =
2372  createVariableDecl(varName, varLoc, initializer, constraints);
2373  if (failed(varDecl))
2374  return failure();
2375  return ast::LetStmt::create(ctx, loc, *varDecl);
2376 }
2377 
2378 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2379  if (parserContext == ParserContext::Constraint)
2380  return emitError("`replace` cannot be used within a Constraint");
2381  SMRange loc = curToken.getLoc();
2382  consumeToken(Token::kw_replace);
2383 
2384  // Parse the root operation expression.
2385  FailureOr<ast::Expr *> rootOp = parseExpr();
2386  if (failed(rootOp))
2387  return failure();
2388 
2389  if (failed(
2390  parseToken(Token::kw_with, "expected `with` after root operation")))
2391  return failure();
2392 
2393  // The replacement portion of this statement is within a rewrite context.
2394  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2395 
2396  // Parse the replacement values.
2397  SmallVector<ast::Expr *> replValues;
2398  if (consumeIf(Token::l_paren)) {
2399  if (consumeIf(Token::r_paren)) {
2400  return emitError(
2401  loc, "expected at least one replacement value, consider using "
2402  "`erase` if no replacement values are desired");
2403  }
2404 
2405  do {
2406  FailureOr<ast::Expr *> replExpr = parseExpr();
2407  if (failed(replExpr))
2408  return failure();
2409  replValues.emplace_back(*replExpr);
2410  } while (consumeIf(Token::comma));
2411 
2412  if (failed(parseToken(Token::r_paren,
2413  "expected `)` after replacement values")))
2414  return failure();
2415  } else {
2416  // Handle replacement with an operation uniquely, as the replacement
2417  // operation supports type inferrence from the root operation.
2418  FailureOr<ast::Expr *> replExpr;
2419  if (curToken.is(Token::kw_op))
2420  replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2421  else
2422  replExpr = parseExpr();
2423  if (failed(replExpr))
2424  return failure();
2425  replValues.emplace_back(*replExpr);
2426  }
2427 
2428  return createReplaceStmt(loc, *rootOp, replValues);
2429 }
2430 
2431 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2432  SMRange loc = curToken.getLoc();
2433  consumeToken(Token::kw_return);
2434 
2435  // Parse the result value.
2436  FailureOr<ast::Expr *> resultExpr = parseExpr();
2437  if (failed(resultExpr))
2438  return failure();
2439 
2440  return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2441 }
2442 
2443 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2444  if (parserContext == ParserContext::Constraint)
2445  return emitError("`rewrite` cannot be used within a Constraint");
2446  SMRange loc = curToken.getLoc();
2447  consumeToken(Token::kw_rewrite);
2448 
2449  // Parse the root operation.
2450  FailureOr<ast::Expr *> rootOp = parseExpr();
2451  if (failed(rootOp))
2452  return failure();
2453 
2454  if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2455  return failure();
2456 
2457  if (curToken.isNot(Token::l_brace))
2458  return emitError("expected `{` to start rewrite body");
2459 
2460  // The rewrite body of this statement is within a rewrite context.
2461  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2462 
2463  FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2464  if (failed(rewriteBody))
2465  return failure();
2466 
2467  // Verify the rewrite body.
2468  for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2469  if (isa<ast::ReturnStmt>(stmt)) {
2470  return emitError(stmt->getLoc(),
2471  "`return` statements are only permitted within a "
2472  "`Constraint` or `Rewrite` body");
2473  }
2474  }
2475 
2476  return createRewriteStmt(loc, *rootOp, *rewriteBody);
2477 }
2478 
2479 //===----------------------------------------------------------------------===//
2480 // Creation+Analysis
2481 //===----------------------------------------------------------------------===//
2482 
2483 //===----------------------------------------------------------------------===//
2484 // Decls
2485 
2486 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2487  // Unwrap reference expressions.
2488  if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2489  node = init->getDecl();
2490  return dyn_cast<ast::CallableDecl>(node);
2491 }
2492 
2493 FailureOr<ast::PatternDecl *>
2494 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2495  const ParsedPatternMetadata &metadata,
2496  ast::CompoundStmt *body) {
2497  return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2498  metadata.hasBoundedRecursion, body);
2499 }
2500 
2501 ast::Type Parser::createUserConstraintRewriteResultType(
2503  // Single result decls use the type of the single result.
2504  if (results.size() == 1)
2505  return results[0]->getType();
2506 
2507  // Multiple results use a tuple type, with the types and names grabbed from
2508  // the result variable decls.
2509  auto resultTypes = llvm::map_range(
2510  results, [&](const auto *result) { return result->getType(); });
2511  auto resultNames = llvm::map_range(
2512  results, [&](const auto *result) { return result->getName().getName(); });
2513  return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2514  llvm::to_vector(resultNames));
2515 }
2516 
2517 template <typename T>
2518 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2519  const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2520  ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2521  ast::CompoundStmt *body) {
2522  if (!body->getChildren().empty()) {
2523  if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2524  ast::Expr *resultExpr = retStmt->getResultExpr();
2525 
2526  // Process the result of the decl. If no explicit signature results
2527  // were provided, check for return type inference. Otherwise, check that
2528  // the return expression can be converted to the expected type.
2529  if (results.empty())
2530  resultType = resultExpr->getType();
2531  else if (failed(convertExpressionTo(resultExpr, resultType)))
2532  return failure();
2533  else
2534  retStmt->setResultExpr(resultExpr);
2535  }
2536  }
2537  return T::createPDLL(ctx, name, arguments, results, body, resultType);
2538 }
2539 
2540 FailureOr<ast::VariableDecl *>
2541 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2542  ArrayRef<ast::ConstraintRef> constraints) {
2543  // The type of the variable, which is expected to be inferred by either a
2544  // constraint or an initializer expression.
2545  ast::Type type;
2546  if (failed(validateVariableConstraints(constraints, type)))
2547  return failure();
2548 
2549  if (initializer) {
2550  // Update the variable type based on the initializer, or try to convert the
2551  // initializer to the existing type.
2552  if (!type)
2553  type = initializer->getType();
2554  else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2555  type = mergedType;
2556  else if (failed(convertExpressionTo(initializer, type)))
2557  return failure();
2558 
2559  // Otherwise, if there is no initializer check that the type has already
2560  // been resolved from the constraint list.
2561  } else if (!type) {
2562  return emitErrorAndNote(
2563  loc, "unable to infer type for variable `" + name + "`", loc,
2564  "the type of a variable must be inferable from the constraint "
2565  "list or the initializer");
2566  }
2567 
2568  // Constraint types cannot be used when defining variables.
2569  if (isa<ast::ConstraintType, ast::RewriteType>(type)) {
2570  return emitError(
2571  loc, llvm::formatv("unable to define variable of `{0}` type", type));
2572  }
2573 
2574  // Try to define a variable with the given name.
2575  FailureOr<ast::VariableDecl *> varDecl =
2576  defineVariableDecl(name, loc, type, initializer, constraints);
2577  if (failed(varDecl))
2578  return failure();
2579 
2580  return *varDecl;
2581 }
2582 
2583 FailureOr<ast::VariableDecl *>
2584 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2585  const ast::ConstraintRef &constraint) {
2586  ast::Type argType;
2587  if (failed(validateVariableConstraint(constraint, argType)))
2588  return failure();
2589  return defineVariableDecl(name, loc, argType, constraint);
2590 }
2591 
2592 LogicalResult
2593 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2594  ast::Type &inferredType) {
2595  for (const ast::ConstraintRef &ref : constraints)
2596  if (failed(validateVariableConstraint(ref, inferredType)))
2597  return failure();
2598  return success();
2599 }
2600 
2601 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2602  ast::Type &inferredType) {
2603  ast::Type constraintType;
2604  if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
2605  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2606  if (failed(validateTypeConstraintExpr(typeExpr)))
2607  return failure();
2608  }
2609  constraintType = ast::AttributeType::get(ctx);
2610  } else if (const auto *cst =
2611  dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
2612  constraintType = ast::OperationType::get(
2613  ctx, cst->getName(), lookupODSOperation(cst->getName()));
2614  } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
2615  constraintType = typeTy;
2616  } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
2617  constraintType = typeRangeTy;
2618  } else if (const auto *cst =
2619  dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
2620  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2621  if (failed(validateTypeConstraintExpr(typeExpr)))
2622  return failure();
2623  }
2624  constraintType = valueTy;
2625  } else if (const auto *cst =
2626  dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
2627  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2628  if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2629  return failure();
2630  }
2631  constraintType = valueRangeTy;
2632  } else if (const auto *cst =
2633  dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
2634  ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2635  if (inputs.size() != 1) {
2636  return emitErrorAndNote(ref.referenceLoc,
2637  "`Constraint`s applied via a variable constraint "
2638  "list must take a single input, but got " +
2639  Twine(inputs.size()),
2640  cst->getLoc(),
2641  "see definition of constraint here");
2642  }
2643  constraintType = inputs.front()->getType();
2644  } else {
2645  llvm_unreachable("unknown constraint type");
2646  }
2647 
2648  // Check that the constraint type is compatible with the current inferred
2649  // type.
2650  if (!inferredType) {
2651  inferredType = constraintType;
2652  } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
2653  inferredType = mergedTy;
2654  } else {
2655  return emitError(ref.referenceLoc,
2656  llvm::formatv("constraint type `{0}` is incompatible "
2657  "with the previously inferred type `{1}`",
2658  constraintType, inferredType));
2659  }
2660  return success();
2661 }
2662 
2663 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2664  ast::Type typeExprType = typeExpr->getType();
2665  if (typeExprType != typeTy) {
2666  return emitError(typeExpr->getLoc(),
2667  "expected expression of `Type` in type constraint");
2668  }
2669  return success();
2670 }
2671 
2672 LogicalResult
2673 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2674  ast::Type typeExprType = typeExpr->getType();
2675  if (typeExprType != typeRangeTy) {
2676  return emitError(typeExpr->getLoc(),
2677  "expected expression of `TypeRange` in type constraint");
2678  }
2679  return success();
2680 }
2681 
2682 //===----------------------------------------------------------------------===//
2683 // Exprs
2684 
2685 FailureOr<ast::CallExpr *>
2686 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2687  MutableArrayRef<ast::Expr *> arguments, bool isNegated) {
2688  ast::Type parentType = parentExpr->getType();
2689 
2690  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2691  if (!callableDecl) {
2692  return emitError(loc,
2693  llvm::formatv("expected a reference to a callable "
2694  "`Constraint` or `Rewrite`, but got: `{0}`",
2695  parentType));
2696  }
2697  if (parserContext == ParserContext::Rewrite) {
2698  if (isa<ast::UserConstraintDecl>(callableDecl))
2699  return emitError(
2700  loc, "unable to invoke `Constraint` within a rewrite section");
2701  if (isNegated)
2702  return emitError(loc, "unable to negate a Rewrite");
2703  } else {
2704  if (isa<ast::UserRewriteDecl>(callableDecl))
2705  return emitError(loc,
2706  "unable to invoke `Rewrite` within a match section");
2707  if (isNegated && cast<ast::UserConstraintDecl>(callableDecl)->getBody())
2708  return emitError(loc, "unable to negate non native constraints");
2709  }
2710 
2711  // Verify the arguments of the call.
2712  /// Handle size mismatch.
2713  ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2714  if (callArgs.size() != arguments.size()) {
2715  return emitErrorAndNote(
2716  loc,
2717  llvm::formatv("invalid number of arguments for {0} call; expected "
2718  "{1}, but got {2}",
2719  callableDecl->getCallableType(), callArgs.size(),
2720  arguments.size()),
2721  callableDecl->getLoc(),
2722  llvm::formatv("see the definition of {0} here",
2723  callableDecl->getName()->getName()));
2724  }
2725 
2726  /// Handle argument type mismatch.
2727  auto attachDiagFn = [&](ast::Diagnostic &diag) {
2728  diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2729  callableDecl->getName()->getName()),
2730  callableDecl->getLoc());
2731  };
2732  for (auto it : llvm::zip(callArgs, arguments)) {
2733  if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2734  attachDiagFn)))
2735  return failure();
2736  }
2737 
2738  return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2739  callableDecl->getResultType(), isNegated);
2740 }
2741 
2742 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2743  ast::Decl *decl) {
2744  // Check the type of decl being referenced.
2745  ast::Type declType;
2746  if (isa<ast::ConstraintDecl>(decl))
2747  declType = ast::ConstraintType::get(ctx);
2748  else if (isa<ast::UserRewriteDecl>(decl))
2749  declType = ast::RewriteType::get(ctx);
2750  else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2751  declType = varDecl->getType();
2752  else
2753  return emitError(loc, "invalid reference to `" +
2754  decl->getName()->getName() + "`");
2755 
2756  return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2757 }
2758 
2759 FailureOr<ast::DeclRefExpr *>
2760 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2761  ArrayRef<ast::ConstraintRef> constraints) {
2762  FailureOr<ast::VariableDecl *> decl =
2763  defineVariableDecl(name, loc, type, constraints);
2764  if (failed(decl))
2765  return failure();
2766  return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2767 }
2768 
2769 FailureOr<ast::MemberAccessExpr *>
2770 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2771  SMRange loc) {
2772  // Validate the member name for the given parent expression.
2773  FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2774  if (failed(memberType))
2775  return failure();
2776 
2777  return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2778 }
2779 
2780 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2781  StringRef name, SMRange loc) {
2782  ast::Type parentType = parentExpr->getType();
2783  if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) {
2785  return valueRangeTy;
2786 
2787  // Verify member access based on the operation type.
2788  if (const ods::Operation *odsOp = opType.getODSOperation()) {
2789  auto results = odsOp->getResults();
2790 
2791  // Handle indexed results.
2792  unsigned index = 0;
2793  if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2794  index < results.size()) {
2795  return results[index].isVariadic() ? valueRangeTy : valueTy;
2796  }
2797 
2798  // Handle named results.
2799  const auto *it = llvm::find_if(results, [&](const auto &result) {
2800  return result.getName() == name;
2801  });
2802  if (it != results.end())
2803  return it->isVariadic() ? valueRangeTy : valueTy;
2804  } else if (llvm::isDigit(name[0])) {
2805  // Allow unchecked numeric indexing of the results of unregistered
2806  // operations. It returns a single value.
2807  return valueTy;
2808  }
2809  } else if (auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
2810  // Handle indexed results.
2811  unsigned index = 0;
2812  if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2813  index < tupleType.size()) {
2814  return tupleType.getElementTypes()[index];
2815  }
2816 
2817  // Handle named results.
2818  auto elementNames = tupleType.getElementNames();
2819  const auto *it = llvm::find(elementNames, name);
2820  if (it != elementNames.end())
2821  return tupleType.getElementTypes()[it - elementNames.begin()];
2822  }
2823  return emitError(
2824  loc,
2825  llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2826  name, parentType));
2827 }
2828 
2829 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2830  SMRange loc, const ast::OpNameDecl *name,
2831  OpResultTypeContext resultTypeContext,
2832  SmallVectorImpl<ast::Expr *> &operands,
2834  SmallVectorImpl<ast::Expr *> &results) {
2835  std::optional<StringRef> opNameRef = name->getName();
2836  const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2837 
2838  // Verify the inputs operands.
2839  if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2840  return failure();
2841 
2842  // Verify the attribute list.
2843  for (ast::NamedAttributeDecl *attr : attributes) {
2844  // Check for an attribute type, or a type awaiting resolution.
2845  ast::Type attrType = attr->getValue()->getType();
2846  if (!isa<ast::AttributeType>(attrType)) {
2847  return emitError(
2848  attr->getValue()->getLoc(),
2849  llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2850  }
2851  }
2852 
2853  assert(
2854  (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2855  "unexpected inferrence when results were explicitly specified");
2856 
2857  // If we aren't relying on type inferrence, or explicit results were provided,
2858  // validate them.
2859  if (resultTypeContext == OpResultTypeContext::Explicit) {
2860  if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2861  return failure();
2862 
2863  // Validate the use of interface based type inferrence for this operation.
2864  } else if (resultTypeContext == OpResultTypeContext::Interface) {
2865  assert(opNameRef &&
2866  "expected valid operation name when inferring operation results");
2867  checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2868  }
2869 
2870  return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
2871  attributes);
2872 }
2873 
2874 LogicalResult
2875 Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2876  const ods::Operation *odsOp,
2877  SmallVectorImpl<ast::Expr *> &operands) {
2878  return validateOperationOperandsOrResults(
2879  "operand", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2880  operands, odsOp ? odsOp->getOperands() : std::nullopt, valueTy,
2881  valueRangeTy);
2882 }
2883 
2884 LogicalResult
2885 Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2886  const ods::Operation *odsOp,
2887  SmallVectorImpl<ast::Expr *> &results) {
2888  return validateOperationOperandsOrResults(
2889  "result", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2890  results, odsOp ? odsOp->getResults() : std::nullopt, typeTy, typeRangeTy);
2891 }
2892 
2893 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2894  const ods::Operation *odsOp) {
2895  // If the operation might not have inferrence support, emit a warning to the
2896  // user. We don't emit an error because the interface might be added to the
2897  // operation at runtime. It's rare, but it could still happen. We emit a
2898  // warning here instead.
2899 
2900  // Handle inferrence warnings for unknown operations.
2901  if (!odsOp) {
2902  ctx.getDiagEngine().emitWarning(
2903  loc, llvm::formatv(
2904  "operation result types are marked to be inferred, but "
2905  "`{0}` is unknown. Ensure that `{0}` supports zero "
2906  "results or implements `InferTypeOpInterface`. Include "
2907  "the ODS definition of this operation to remove this warning.",
2908  opName));
2909  return;
2910  }
2911 
2912  // Handle inferrence warnings for known operations that expected at least one
2913  // result, but don't have inference support. An elided results list can mean
2914  // "zero-results", and we don't want to warn when that is the expected
2915  // behavior.
2916  bool requiresInferrence =
2917  llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
2918  return !result.isVariableLength();
2919  });
2920  if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2922  loc,
2923  llvm::formatv("operation result types are marked to be inferred, but "
2924  "`{0}` does not provide an implementation of "
2925  "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2926  "`InferTypeOpInterface` at runtime, or add support to "
2927  "the ODS definition to remove this warning.",
2928  opName));
2929  diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
2930  odsOp->getLoc());
2931  return;
2932  }
2933 }
2934 
2935 LogicalResult Parser::validateOperationOperandsOrResults(
2936  StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2937  std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
2938  ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2939  ast::RangeType rangeTy) {
2940  // All operation types accept a single range parameter.
2941  if (values.size() == 1) {
2942  if (failed(convertExpressionTo(values[0], rangeTy)))
2943  return failure();
2944  return success();
2945  }
2946 
2947  /// If the operation has ODS information, we can more accurately verify the
2948  /// values.
2949  if (odsOpLoc) {
2950  auto emitSizeMismatchError = [&] {
2951  return emitErrorAndNote(
2952  loc,
2953  llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2954  "{2}, but got {3}",
2955  groupName, *name, odsValues.size(), values.size()),
2956  *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
2957  };
2958 
2959  // Handle the case where no values were provided.
2960  if (values.empty()) {
2961  // If we don't expect any on the ODS side, we are done.
2962  if (odsValues.empty())
2963  return success();
2964 
2965  // If we do, check if we actually need to provide values (i.e. if any of
2966  // the values are actually required).
2967  unsigned numVariadic = 0;
2968  for (const auto &odsValue : odsValues) {
2969  if (!odsValue.isVariableLength())
2970  return emitSizeMismatchError();
2971  ++numVariadic;
2972  }
2973 
2974  // If we are in a non-rewrite context, we don't need to do anything more.
2975  // Zero-values is a valid constraint on the operation.
2976  if (parserContext != ParserContext::Rewrite)
2977  return success();
2978 
2979  // Otherwise, when in a rewrite we may need to provide values to match the
2980  // ODS signature of the operation to create.
2981 
2982  // If we only have one variadic value, just use an empty list.
2983  if (numVariadic == 1)
2984  return success();
2985 
2986  // Otherwise, create dummy values for each of the entries so that we
2987  // adhere to the ODS signature.
2988  for (unsigned i = 0, e = odsValues.size(); i < e; ++i) {
2989  values.push_back(ast::RangeExpr::create(
2990  ctx, loc, /*elements=*/std::nullopt, rangeTy));
2991  }
2992  return success();
2993  }
2994 
2995  // Verify that the number of values provided matches the number of value
2996  // groups ODS expects.
2997  if (odsValues.size() != values.size())
2998  return emitSizeMismatchError();
2999 
3000  auto diagFn = [&](ast::Diagnostic &diag) {
3001  diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
3002  *odsOpLoc);
3003  };
3004  for (unsigned i = 0, e = values.size(); i < e; ++i) {
3005  ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
3006  if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
3007  return failure();
3008  }
3009  return success();
3010  }
3011 
3012  // Otherwise, accept the value groups as they have been defined and just
3013  // ensure they are one of the expected types.
3014  for (ast::Expr *&valueExpr : values) {
3015  ast::Type valueExprType = valueExpr->getType();
3016 
3017  // Check if this is one of the expected types.
3018  if (valueExprType == rangeTy || valueExprType == singleTy)
3019  continue;
3020 
3021  // If the operand is an Operation, allow converting to a Value or
3022  // ValueRange. This situations arises quite often with nested operation
3023  // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
3024  if (singleTy == valueTy) {
3025  if (isa<ast::OperationType>(valueExprType)) {
3026  valueExpr = convertOpToValue(valueExpr);
3027  continue;
3028  }
3029  }
3030 
3031  // Otherwise, try to convert the expression to a range.
3032  if (succeeded(convertExpressionTo(valueExpr, rangeTy)))
3033  continue;
3034 
3035  return emitError(
3036  valueExpr->getLoc(),
3037  llvm::formatv(
3038  "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3039  singleTy, rangeTy, valueExprType));
3040  }
3041  return success();
3042 }
3043 
3044 FailureOr<ast::TupleExpr *>
3045 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
3046  ArrayRef<StringRef> elementNames) {
3047  for (const ast::Expr *element : elements) {
3048  ast::Type eleTy = element->getType();
3049  if (isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>(eleTy)) {
3050  return emitError(
3051  element->getLoc(),
3052  llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
3053  }
3054  }
3055  return ast::TupleExpr::create(ctx, loc, elements, elementNames);
3056 }
3057 
3058 //===----------------------------------------------------------------------===//
3059 // Stmts
3060 
3061 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
3062  ast::Expr *rootOp) {
3063  // Check that root is an Operation.
3064  ast::Type rootType = rootOp->getType();
3065  if (!isa<ast::OperationType>(rootType))
3066  return emitError(rootOp->getLoc(), "expected `Op` expression");
3067 
3068  return ast::EraseStmt::create(ctx, loc, rootOp);
3069 }
3070 
3071 FailureOr<ast::ReplaceStmt *>
3072 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
3073  MutableArrayRef<ast::Expr *> replValues) {
3074  // Check that root is an Operation.
3075  ast::Type rootType = rootOp->getType();
3076  if (!isa<ast::OperationType>(rootType)) {
3077  return emitError(
3078  rootOp->getLoc(),
3079  llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3080  }
3081 
3082  // If there are multiple replacement values, we implicitly convert any Op
3083  // expressions to the value form.
3084  bool shouldConvertOpToValues = replValues.size() > 1;
3085  for (ast::Expr *&replExpr : replValues) {
3086  ast::Type replType = replExpr->getType();
3087 
3088  // Check that replExpr is an Operation, Value, or ValueRange.
3089  if (isa<ast::OperationType>(replType)) {
3090  if (shouldConvertOpToValues)
3091  replExpr = convertOpToValue(replExpr);
3092  continue;
3093  }
3094 
3095  if (replType != valueTy && replType != valueRangeTy) {
3096  return emitError(replExpr->getLoc(),
3097  llvm::formatv("expected `Op`, `Value` or `ValueRange` "
3098  "expression, but got `{0}`",
3099  replType));
3100  }
3101  }
3102 
3103  return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
3104 }
3105 
3106 FailureOr<ast::RewriteStmt *>
3107 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
3108  ast::CompoundStmt *rewriteBody) {
3109  // Check that root is an Operation.
3110  ast::Type rootType = rootOp->getType();
3111  if (!isa<ast::OperationType>(rootType)) {
3112  return emitError(
3113  rootOp->getLoc(),
3114  llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3115  }
3116 
3117  return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
3118 }
3119 
3120 //===----------------------------------------------------------------------===//
3121 // Code Completion
3122 //===----------------------------------------------------------------------===//
3123 
3124 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3125  ast::Type parentType = parentExpr->getType();
3126  if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType))
3127  codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3128  else if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(parentType))
3129  codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3130  return failure();
3131 }
3132 
3133 LogicalResult
3134 Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3135  if (opName)
3136  codeCompleteContext->codeCompleteOperationAttributeName(*opName);
3137  return failure();
3138 }
3139 
3140 LogicalResult
3141 Parser::codeCompleteConstraintName(ast::Type inferredType,
3142  bool allowInlineTypeConstraints) {
3143  codeCompleteContext->codeCompleteConstraintName(
3144  inferredType, allowInlineTypeConstraints, curDeclScope);
3145  return failure();
3146 }
3147 
3148 LogicalResult Parser::codeCompleteDialectName() {
3149  codeCompleteContext->codeCompleteDialectName();
3150  return failure();
3151 }
3152 
3153 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3154  codeCompleteContext->codeCompleteOperationName(dialectName);
3155  return failure();
3156 }
3157 
3158 LogicalResult Parser::codeCompletePatternMetadata() {
3159  codeCompleteContext->codeCompletePatternMetadata();
3160  return failure();
3161 }
3162 
3163 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3164  codeCompleteContext->codeCompleteIncludeFilename(curPath);
3165  return failure();
3166 }
3167 
3168 void Parser::codeCompleteCallSignature(ast::Node *parent,
3169  unsigned currentNumArgs) {
3170  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent);
3171  if (!callableDecl)
3172  return;
3173 
3174  codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
3175 }
3176 
3177 void Parser::codeCompleteOperationOperandsSignature(
3178  std::optional<StringRef> opName, unsigned currentNumOperands) {
3179  codeCompleteContext->codeCompleteOperationOperandsSignature(
3180  opName, currentNumOperands);
3181 }
3182 
3183 void Parser::codeCompleteOperationResultsSignature(
3184  std::optional<StringRef> opName, unsigned currentNumResults) {
3185  codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3186  currentNumResults);
3187 }
3188 
3189 //===----------------------------------------------------------------------===//
3190 // Parser
3191 //===----------------------------------------------------------------------===//
3192 
3193 FailureOr<ast::Module *>
3194 mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3195  bool enableDocumentation,
3196  CodeCompleteContext *codeCompleteContext) {
3197  Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3198  return parser.parseModule();
3199 }
static std::string diag(const llvm::Value &value)
This class breaks up the current file into a token stream.
Definition: Lexer.h:23
This represents a token in the MLIR syntax.
Definition: Token.h:20
bool isKeyword() const
Return true if this is one of the keyword token kinds (e.g. kw_if).
Definition: Token.cpp:192
SMLoc getLoc() const
Definition: Token.cpp:24
bool isNot(Kind k) const
Definition: Token.h:50
StringRef getSpelling() const
Definition: Token.h:34
This class provides an abstract interface into the parser for hooking in code completion events.
Definition: CodeComplete.h:30
virtual void codeCompleteConstraintName(ast::Type currentType, bool allowInlineTypeConstraints, const ast::DeclScope *scope)
Signal code completion for a constraint name with an optional decl scope.
virtual void codeCompleteOperationAttributeName(StringRef opName)
Signal code completion for an attribute name of the given operation.
Definition: CodeComplete.h:48
virtual void codeCompleteOperationOperandsSignature(std::optional< StringRef > opName, unsigned currentNumOperands)
Signal code completion for the signature of an operation's operands.
Definition: CodeComplete.h:80
virtual void codeCompleteOperationName(StringRef dialectName)
Signal code completion for an operation name in the given dialect.
Definition: CodeComplete.h:62
virtual void codeCompleteOperationResultsSignature(std::optional< StringRef > opName, unsigned currentNumResults)
Signal code completion for the signature of an operation's results.
Definition: CodeComplete.h:85
virtual void codeCompleteDialectName()
Signal code completion for a dialect name.
Definition: CodeComplete.h:59
virtual void codeCompleteOperationMemberAccess(ast::OperationType opType)
Signal code completion for a member access into the given operation type.
virtual void codeCompleteTupleMemberAccess(ast::TupleType tupleType)
Signal code completion for a member access into the given tuple type.
virtual void codeCompletePatternMetadata()
Signal code completion for Pattern metadata.
Definition: CodeComplete.h:65
virtual void codeCompleteCallSignature(const ast::CallableDecl *callable, unsigned currentNumArgs)
Signal code completion for the signature of a callable.
Definition: CodeComplete.h:75
virtual void codeCompleteIncludeFilename(StringRef curPath)
Signal code completion for an include filename.
Definition: CodeComplete.h:68
@ code_complete_string
Token signifying a code completion location within a string.
Definition: Lexer.h:41
@ directive
Tokens.
Definition: Lexer.h:93
@ eof
Markers.
Definition: Lexer.h:36
@ code_complete
Token signifying a code completion location.
Definition: Lexer.h:39
@ arrow
Punctuation.
Definition: Lexer.h:74
@ less
Paired punctuation.
Definition: Lexer.h:82
@ kw_Attr
General keywords.
Definition: Lexer.h:54
static StringRef getMemberName()
Return the member name used for the "all-results" access.
Definition: Nodes.h:487
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
Definition: Nodes.h:489
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
Definition: Nodes.h:749
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:390
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:258
static AttributeType get(Context &context)
Return an instance of the Attribute type.
Definition: Types.cpp:57
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr * > arguments, Type resultType, bool isNegated=false)
Definition: Nodes.cpp:268
This decl represents a shared interface for all callable decls.
Definition: Nodes.h:1188
Type getResultType() const
Return the result type of this decl.
Definition: Nodes.h:1206
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
Definition: Nodes.h:1199
StringRef getCallableType() const
Return the callable type of this decl.
Definition: Nodes.h:1191
This statement represents a compound statement, which contains a collection of other statements.
Definition: Nodes.h:179
ArrayRef< Stmt * >::iterator begin() const
Definition: Nodes.h:191
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
Definition: Nodes.h:185
ArrayRef< Stmt * >::iterator end() const
Definition: Nodes.h:192
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt * > children)
Definition: Nodes.cpp:192
This class represents the base of all AST Constraint decls.
Definition: Nodes.h:704
static ConstraintType get(Context &context)
Return an instance of the Constraint type.
Definition: Types.cpp:65
This class represents the main context of the PDLL AST.
Definition: Context.h:25
DiagnosticEngine & getDiagEngine()
Return the diagnostic engine of this context.
Definition: Context.h:42
ods::Context & getODSContext()
Return the ODS context used by the AST.
Definition: Context.h:38
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
Definition: Nodes.cpp:285
This class represents a scope for named AST decls.
Definition: Nodes.h:64
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
Definition: Nodes.h:70
void add(Decl *decl)
Add a new decl to the scope.
Definition: Nodes.cpp:175
This class represents the base Decl node.
Definition: Nodes.h:669
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition: Nodes.h:672
InFlightDiagnostic emitWarning(SMRange loc, const Twine &msg)
Definition: Diagnostic.h:149
InFlightDiagnostic emitError(SMRange loc, const Twine &msg)
Emit an error to the diagnostic engine.
Definition: Diagnostic.h:145
This class provides a simple implementation of a PDLL diagnostic.
Definition: Diagnostic.h:29
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
Definition: Nodes.cpp:218
This class represents a base AST Expression node.
Definition: Nodes.h:345
Type getType() const
Return the type of this expression.
Definition: Nodes.h:348
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostic.h:83
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
Definition: Nodes.cpp:207
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
Definition: Nodes.cpp:295
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl * > children)
Definition: Nodes.cpp:576
This Decl represents a NamedAttribute, and contains a string name and attribute value.
Definition: Nodes.h:992
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
Definition: Nodes.cpp:499
This class represents a base AST node.
Definition: Nodes.h:108
SMRange getLoc() const
Return the location of this node.
Definition: Nodes.h:131
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
Definition: Nodes.cpp:400
This Decl represents an OperationName.
Definition: Nodes.h:1016
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
Definition: Nodes.h:1022
static OpNameDecl * create(Context &ctx, const Name &name)
Definition: Nodes.cpp:509
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr * > operands, ArrayRef< Expr * > resultTypes, ArrayRef< NamedAttributeDecl * > attributes)
Definition: Nodes.cpp:307
This class represents a PDLL type that corresponds to an mlir::Operation.
Definition: Types.h:163
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
Definition: Types.cpp:87
static OperationType get(Context &context, std::optional< StringRef > name=std::nullopt, const ods::Operation *odsOp=nullptr)
Return an instance of the Operation type with an optional operation name.
Definition: Types.cpp:73
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
Definition: Nodes.cpp:520
static RangeExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, RangeType type)
Definition: Nodes.cpp:338
This class represents a PDLL type that corresponds to a range of elements with a given element type.
Definition: Types.h:188
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr * > replExprs)
Definition: Nodes.cpp:225
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
Definition: Nodes.cpp:249
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
Definition: Nodes.cpp:239
static RewriteType get(Context &context)
Return an instance of the Rewrite type.
Definition: Types.cpp:134
This class represents a base AST Statement node.
Definition: Nodes.h:164
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, ArrayRef< StringRef > elementNames)
Definition: Nodes.cpp:353
This class represents a PDLL tuple type, i.e. an ordered set of element types with optional names.
Definition: Types.h:249
size_t size() const
Return the number of elements within this tuple.
Definition: Types.h:266
ArrayRef< Type > getElementTypes() const
Return the element types of this tuple.
Definition: Types.cpp:153
static TupleType get(Context &context, ArrayRef< Type > elementTypes, ArrayRef< StringRef > elementNames)
Return an instance of the Tuple type.
Definition: Types.cpp:142
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:417
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:373
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:426
static TypeRangeType get(Context &context)
Return an instance of the TypeRange type.
Definition: Types.cpp:112
static TypeType get(Context &context)
Return an instance of the Type type.
Definition: Types.cpp:165
Type refineWith(Type other) const
Try to refine this type with the one provided.
Definition: Types.cpp:33
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
Definition: Nodes.h:886
The class represents a Value constraint, and constrains a variable to be a Value.
Definition: Nodes.h:825
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
Definition: Nodes.cpp:436
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
Definition: Nodes.h:847
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:447
static ValueRangeType get(Context &context)
Return an instance of the ValueRange type.
Definition: Types.cpp:125
static ValueType get(Context &context)
Return an instance of the Value type.
Definition: Types.cpp:173
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
Definition: Nodes.cpp:558
This class contains all of the registered ODS operation classes.
Definition: Context.h:32
std::pair< Operation *, bool > insertOperation(StringRef name, StringRef summary, StringRef desc, StringRef nativeClassName, bool supportsResultTypeInferrence, SMLoc loc)
Insert a new operation with the context.
Definition: Context.cpp:64
const TypeConstraint & insertTypeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new type constraint with the context.
Definition: Context.cpp:42
const AttributeConstraint & insertAttributeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new attribute constraint with the context.
Definition: Context.cpp:28
const Operation * lookupOperation(StringRef name) const
Look up an operation registered with the given name, or null if no operation with that name is registered.
Definition: Context.cpp:73
This class provides an ODS representation of a specific operation operand or result.
Definition: Operation.h:74
This class provides an ODS representation of a specific operation.
Definition: Operation.h:125
SMRange getLoc() const
Return the source location of this operation.
Definition: Operation.h:128
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
Definition: Operation.h:165
bool hasResultTypeInferrence() const
Return if the operation is known to support result type inference.
Definition: Operation.h:171
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
Definition: Operation.h:168
This class represents a generic ODS Type constraint.
Definition: Constraint.h:84
raw_ostream subclass that simplifies indenting a sequence of code.
raw_indented_ostream & printReindented(StringRef str, StringRef extraPrefix="")
Prints a string re-indented to the current indent.
StringRef getSummary() const
Definition: Constraint.cpp:55
std::string getUniqueDefName() const
Returns a unique name for the TableGen def of this constraint.
Definition: Constraint.cpp:72
StringRef getDescription() const
Definition: Constraint.cpp:62
std::string getConditionTemplate() const
Definition: Constraint.cpp:51
Format context containing substitutions for special placeholders.
Definition: Format.h:40
FmtContext & withSelf(Twine subst)
Definition: Format.cpp:41
Wrapper class that contains an MLIR op's information (e.g., operands, attributes) defined in TableGen and provides helper methods for accessing them.
Definition: Operator.h:77
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
Definition: Parser.cpp:3194
auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals) -> FmtObject< decltype(std::make_tuple(llvm::support::detail::build_format_adapter(std::forward< Ts >(vals))...))>
Formats text by substituting placeholders in format string with replacement parameters.
Definition: Format.h:262
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This class represents a reference to a constraint, and contains a constraint and the location of the reference.
Definition: Nodes.h:716
const ConstraintDecl * constraint
Definition: Nodes.h:722
This class provides a convenient API for interacting with source names.
Definition: Nodes.h:37
StringRef getName() const
Return the raw string name.
Definition: Nodes.h:41
SMRange getLoc() const
Get the location of this name.
Definition: Nodes.h:44
static const Name & create(Context &ctx, StringRef name, SMRange location)
Definition: Nodes.cpp:33