MLIR  18.0.0git
Parser.cpp
Go to the documentation of this file.
1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "Lexer.h"
13 #include "mlir/TableGen/Argument.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/Operator.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/ManagedStatic.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/ScopedPrinter.h"
32 #include "llvm/TableGen/Error.h"
33 #include "llvm/TableGen/Parser.h"
34 #include <string>
35 #include <optional>
36 
37 using namespace mlir;
38 using namespace mlir::pdll;
39 
40 //===----------------------------------------------------------------------===//
41 // Parser
42 //===----------------------------------------------------------------------===//
43 
44 namespace {
// NOTE(review): this section appears to be a Doxygen source listing: every
// line carries a leading source line number, and a number of original lines
// are absent (for example the return-type lines of convertOpExpressionTo and
// parseNamedAttributeDecl, the trailing parameter lines of parseTdInclude and
// processTdIncludeRecords, and a statement inside processAndFormatDoc).
// Restore this section from the upstream source before compiling it.

/// Recursive-descent parser for PDLL. It pulls tokens from `lexer` (built over
/// the given source manager) and creates AST nodes in the `ast::Context` it
/// was constructed with. parseModule() is the only public entry point.
45 class Parser {
46 public:
  /// Create a parser over `sourceMgr`. The first token is lexed eagerly, and
  /// the commonly used AST types (Type, Value, TypeRange, ValueRange,
  /// Attribute) are cached from `ctx`. `codeCompleteContext` is the optional
  /// code completion context; it is also forwarded to the lexer.
47  Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
48  bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
49  : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
50  curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
51  typeTy(ast::TypeType::get(ctx)), valueTy(ast::ValueType::get(ctx)),
52  typeRangeTy(ast::TypeRangeType::get(ctx)),
53  valueRangeTy(ast::ValueRangeType::get(ctx)),
54  attrTy(ast::AttributeType::get(ctx)),
55  codeCompleteContext(codeCompleteContext) {}
56 
57  /// Try to parse a new module. Returns failure if the module could not be parsed.
58  FailureOr<ast::Module *> parseModule();
59 
60 private:
61  /// The current context of the parser. It allows for the parser to know a bit
62  /// about the construct it is nested within during parsing. This is used
63  /// specifically to provide additional verification during parsing, e.g. to
64  /// prevent using rewrites within a match context, matcher constraints within
65  /// a rewrite section, etc.
66  enum class ParserContext {
67  /// The parser is in the global context.
68  Global,
69  /// The parser is currently within a Constraint, which disallows all types
70  /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
71  Constraint,
72  /// The parser is currently within the matcher portion of a Pattern, which
73  /// allows a terminal operation rewrite statement but no other rewrite
74  /// transformations.
75  PatternMatch,
76  /// The parser is currently within a Rewrite, which disallows calls to
77  /// constraints, requires operation expressions to have names, etc.
78  Rewrite,
79  };
80 
81  /// The current specification context of an operation's result type. This
82  /// indicates how the result types of an operation may be inferred.
83  enum class OpResultTypeContext {
84  /// The result types of the operation are not known to be inferred.
85  Explicit,
86  /// The result types of the operation are inferred from the root input of a
87  /// `replace` statement.
88  Replacement,
89  /// The result types of the operation are inferred by using the
90  /// `InferTypeOpInterface` interface provided by the operation.
91  Interface,
92  };
93 
94  //===--------------------------------------------------------------------===//
95  // Parsing
96  //===--------------------------------------------------------------------===//
97 
98  /// Push a new decl scope, whose parent is the current scope, and make it
  /// the current scope. Returns the new scope.
99  ast::DeclScope *pushDeclScope() {
100  ast::DeclScope *newScope =
101  new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
102  return (curDeclScope = newScope);
103  }
  /// Make the existing `scope` the current decl scope.
104  void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
105 
106  /// Pop the current decl scope, making its parent scope current again.
107  void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
108 
109  /// Parse the body of an AST module.
110  LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
111 
112  /// Try to convert the given expression to `type`. Returns failure and emits
113  /// an error if a conversion is not viable. On failure, `noteAttachFn` is
114  /// invoked to attach notes to the emitted error diagnostic. On success,
115  /// `expr` is updated to the expression used to convert to `type`.
116  LogicalResult convertExpressionTo(
117  ast::Expr *&expr, ast::Type type,
118  function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
  // NOTE(review): the return-type line of the next declaration (original line
  // 119) is missing from this listing; its definition returns LogicalResult.
120  convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType,
121  ast::Type type,
122  function_ref<ast::InFlightDiagnostic()> emitErrorFn);
123  LogicalResult convertTupleExpressionTo(
124  ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
125  function_ref<ast::InFlightDiagnostic()> emitErrorFn,
126  function_ref<void(ast::Diagnostic &diag)> noteAttachFn);
127 
128  /// Given an operation expression, convert it to a Value or ValueRange
129  /// typed expression.
130  ast::Expr *convertOpToValue(const ast::Expr *opExpr);
131 
132  /// Lookup ODS information for the given operation, returns nullptr if no
133  /// information is found.
134  const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
135  return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
136  }
137 
138  /// Process the given documentation string, or return an empty string if
139  /// documentation isn't enabled.
140  StringRef processDoc(StringRef doc) {
141  return enableDocumentation ? doc : StringRef();
142  }
143 
144  /// Process the given documentation string and format it, or return an empty
145  /// string if documentation isn't enabled.
146  std::string processAndFormatDoc(const Twine &doc) {
147  if (!enableDocumentation)
148  return "";
149  std::string docStr;
150  {
151  llvm::raw_string_ostream docOS(docStr);
152  std::string tmpDocStr = doc.str();
  // NOTE(review): original line 153 is missing here; presumably it wrote the
  // trimmed `tmpDocStr` into `docOS`. As listed, the `);` below has no
  // opening call and `docOS` is never written to.
154  StringRef(tmpDocStr).rtrim(" \t"));
155  }
156  return docStr;
157  }
158 
159  //===--------------------------------------------------------------------===//
160  // Directives
161 
162  LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
163  LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
164  LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
166 
167  /// Process the records of a parsed tablegen include file.
168  void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
170 
171  /// Create a user defined native constraint for a constraint imported from
172  /// ODS.
173  template <typename ConstraintT>
174  ast::Decl *
175  createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
176  SMRange loc, ast::Type type,
177  StringRef nativeType, StringRef docString);
178  template <typename ConstraintT>
179  ast::Decl *
180  createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
181  SMRange loc, ast::Type type,
182  StringRef nativeType);
183 
184  //===--------------------------------------------------------------------===//
185  // Decls
186 
187  /// This structure contains the set of pattern metadata that may be parsed.
188  struct ParsedPatternMetadata {
189  std::optional<uint16_t> benefit;
190  bool hasBoundedRecursion = false;
191  };
192 
193  FailureOr<ast::Decl *> parseTopLevelDecl();
195  parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
196 
197  /// Parse an argument variable as part of the signature of a
198  /// UserConstraintDecl or UserRewriteDecl.
199  FailureOr<ast::VariableDecl *> parseArgumentDecl();
200 
201  /// Parse a result variable as part of the signature of a UserConstraintDecl
202  /// or UserRewriteDecl.
203  FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
204 
205  /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
206  /// defined in a non-global context.
208  parseUserConstraintDecl(bool isInline = false);
209 
210  /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
211  /// non-global context, such as within a Pattern/Constraint/etc.
212  FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
213 
214  /// Parse a PDLL (i.e. non-native) UserConstraintDecl whose body is defined using
215  /// PDLL constructs.
216  FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
217  const ast::Name &name, bool isInline,
218  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
219  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
220 
221  /// Parse a UserRewriteDecl. `isInline` signals if the rewrite is being
222  /// defined in a non-global context.
223  FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
224 
225  /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
226  /// non-global context, such as within a Pattern/Rewrite/etc.
227  FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
228 
229  /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
230  /// PDLL constructs.
231  FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
232  const ast::Name &name, bool isInline,
233  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
234  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
235 
236  /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
237  /// effectively the same syntax, and only differ on slight semantics (given
238  /// the different parsing contexts).
239  template <typename T, typename ParseUserPDLLDeclFnT>
240  FailureOr<T *> parseUserConstraintOrRewriteDecl(
241  ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
242  StringRef anonymousNamePrefix, bool isInline);
243 
244  /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
245  /// These decls have effectively the same syntax.
246  template <typename T>
247  FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
248  const ast::Name &name, bool isInline,
250  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
251 
252  /// Parse the functional signature (i.e. the arguments and results) of a
253  /// UserConstraintDecl or UserRewriteDecl.
254  LogicalResult parseUserConstraintOrRewriteSignature(
257  ast::DeclScope *&argumentScope, ast::Type &resultType);
258 
259  /// Validate the return (which if present is specified by bodyIt) of a
260  /// UserConstraintDecl or UserRewriteDecl.
261  LogicalResult validateUserConstraintOrRewriteReturn(
262  StringRef declType, ast::CompoundStmt *body,
265  ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
266 
268  parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
269  bool expectTerminalSemicolon = true);
270  FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
271  FailureOr<ast::Decl *> parsePatternDecl();
272  LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
273 
274  /// Check to see if a decl has already been defined with the given name; if
275  /// one has, emit an error and return failure. Returns success otherwise.
276  LogicalResult checkDefineNamedDecl(const ast::Name &name);
277 
278  /// Try to define a variable decl with the given components, returns the
279  /// variable on success.
281  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
282  ast::Expr *initExpr,
283  ArrayRef<ast::ConstraintRef> constraints);
285  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
286  ArrayRef<ast::ConstraintRef> constraints);
287 
288  /// Parse the constraint reference list for a variable decl.
289  LogicalResult parseVariableDeclConstraintList(
291 
292  /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
293  FailureOr<ast::Expr *> parseTypeConstraintExpr();
294 
295  /// Try to parse a single reference to a constraint. `typeConstraint` is the
296  /// location of a previously parsed type constraint for the entity that will
297  /// be constrained by the parsed constraint. `existingConstraints` are any
298  /// existing constraints that have already been parsed for the same entity
299  /// that will be constrained by this constraint. `allowInlineTypeConstraints`
300  /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
302  parseConstraint(std::optional<SMRange> &typeConstraint,
303  ArrayRef<ast::ConstraintRef> existingConstraints,
304  bool allowInlineTypeConstraints);
305 
306  /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
307  /// argument or result variable. The constraints for these variables do not
308  /// allow inline type constraints, and only permit a single constraint.
309  FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
310 
311  //===--------------------------------------------------------------------===//
312  // Exprs
313 
314  FailureOr<ast::Expr *> parseExpr();
315 
316  /// Identifier expressions.
317  FailureOr<ast::Expr *> parseAttributeExpr();
318  FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr,
319  bool isNegated = false);
320  FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
321  FailureOr<ast::Expr *> parseIdentifierExpr();
322  FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
323  FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
324  FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
325  FailureOr<ast::Expr *> parseNegatedExpr();
326  FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
327  FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
329  parseOperationExpr(OpResultTypeContext inputResultTypeContext =
330  OpResultTypeContext::Explicit);
331  FailureOr<ast::Expr *> parseTupleExpr();
332  FailureOr<ast::Expr *> parseTypeExpr();
333  FailureOr<ast::Expr *> parseUnderscoreExpr();
334 
335  //===--------------------------------------------------------------------===//
336  // Stmts
337 
338  FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
339  FailureOr<ast::CompoundStmt *> parseCompoundStmt();
340  FailureOr<ast::EraseStmt *> parseEraseStmt();
341  FailureOr<ast::LetStmt *> parseLetStmt();
342  FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
343  FailureOr<ast::ReturnStmt *> parseReturnStmt();
344  FailureOr<ast::RewriteStmt *> parseRewriteStmt();
345 
346  //===--------------------------------------------------------------------===//
347  // Creation+Analysis
348  //===--------------------------------------------------------------------===//
349 
350  //===--------------------------------------------------------------------===//
351  // Decls
352 
353  /// Try to extract a callable from the given AST node. Returns nullptr on
354  /// failure.
355  ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
356 
357  /// Try to create a pattern decl with the given components, returning the
358  /// Pattern on success.
360  createPatternDecl(SMRange loc, const ast::Name *name,
361  const ParsedPatternMetadata &metadata,
362  ast::CompoundStmt *body);
363 
364  /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
365  /// of results, defined as part of the signature.
366  ast::Type
367  createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
368 
369  /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
370  template <typename T>
371  FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
372  const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
373  ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
374  ast::CompoundStmt *body);
375 
376  /// Try to create a variable decl with the given components, returning the
377  /// Variable on success.
379  createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
380  ArrayRef<ast::ConstraintRef> constraints);
381 
382  /// Create a variable for an argument or result defined as part of the
383  /// signature of a UserConstraintDecl/UserRewriteDecl.
385  createArgOrResultVariableDecl(StringRef name, SMRange loc,
386  const ast::ConstraintRef &constraint);
387 
388  /// Validate the constraints used to constrain a variable decl.
389  /// `inferredType` is the type of the variable inferred by the constraints
390  /// within the list, and is updated to the most refined type as determined by
391  /// the constraints. Returns success if the constraint list is valid, failure
392  /// otherwise.
394  validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
395  ast::Type &inferredType);
396  /// Validate a single reference to a constraint. `inferredType` contains the
397  /// currently inferred variable type and is refined within the type defined
398  /// by the constraint. Returns success if the constraint is valid, failure
399  /// otherwise.
400  LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
401  ast::Type &inferredType);
402  LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
403  LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
404 
405  //===--------------------------------------------------------------------===//
406  // Exprs
407 
409  createCallExpr(SMRange loc, ast::Expr *parentExpr,
411  bool isNegated = false);
412  FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
414  createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
415  ArrayRef<ast::ConstraintRef> constraints);
417  createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
418 
419  /// Validate the member access `name` into the given parent expression. On
420  /// success, this also returns the type of the member accessed.
421  FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
422  StringRef name, SMRange loc);
424  createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
425  OpResultTypeContext resultTypeContext,
430  validateOperationOperands(SMRange loc, std::optional<StringRef> name,
431  const ods::Operation *odsOp,
432  SmallVectorImpl<ast::Expr *> &operands);
433  LogicalResult validateOperationResults(SMRange loc,
434  std::optional<StringRef> name,
435  const ods::Operation *odsOp,
437  void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
438  const ods::Operation *odsOp);
439  LogicalResult validateOperationOperandsOrResults(
440  StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
441  std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
442  ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
443  ast::RangeType rangeTy);
444  FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
445  ArrayRef<ast::Expr *> elements,
446  ArrayRef<StringRef> elementNames);
447 
448  //===--------------------------------------------------------------------===//
449  // Stmts
450 
451  FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
453  createReplaceStmt(SMRange loc, ast::Expr *rootOp,
454  MutableArrayRef<ast::Expr *> replValues);
456  createRewriteStmt(SMRange loc, ast::Expr *rootOp,
457  ast::CompoundStmt *rewriteBody);
458 
459  //===--------------------------------------------------------------------===//
460  // Code Completion
461  //===--------------------------------------------------------------------===//
462 
463  /// The set of various code completion methods. Every completion method
464  /// returns `failure` to stop the parsing process after providing completion
465  /// results.
466 
467  LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
468  LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
469  LogicalResult codeCompleteConstraintName(ast::Type inferredType,
470  bool allowInlineTypeConstraints);
471  LogicalResult codeCompleteDialectName();
472  LogicalResult codeCompleteOperationName(StringRef dialectName);
473  LogicalResult codeCompletePatternMetadata();
474  LogicalResult codeCompleteIncludeFilename(StringRef curPath);
475 
476  void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
477  void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
478  unsigned currentNumOperands);
479  void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
480  unsigned currentNumResults);
481 
482  //===--------------------------------------------------------------------===//
483  // Lexer Utilities
484  //===--------------------------------------------------------------------===//
485 
486  /// If the current token has the specified kind, consume it and return true.
487  /// If not, return false.
488  bool consumeIf(Token::Kind kind) {
489  if (curToken.isNot(kind))
490  return false;
491  consumeToken(kind);
492  return true;
493  }
494 
495  /// Advance the current lexer onto the next token.
496  void consumeToken() {
497  assert(curToken.isNot(Token::eof, Token::error) &&
498  "shouldn't advance past EOF or errors");
499  curToken = lexer.lexToken();
500  }
501 
502  /// Advance the current lexer onto the next token, asserting what the expected
503  /// current token is. This is preferred to the above method because it leads
504  /// to more self-documenting code with better checking.
505  void consumeToken(Token::Kind kind) {
506  assert(curToken.is(kind) && "consumed an unexpected token");
507  consumeToken();
508  }
509 
510  /// Reset the lexer to the location at the given position.
511  void resetToken(SMRange tokLoc) {
512  lexer.resetPointer(tokLoc.Start.getPointer());
513  curToken = lexer.lexToken();
514  }
515 
516  /// Consume the specified token if present and return success. On failure,
517  /// output a diagnostic and return failure.
518  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
519  if (curToken.getKind() != kind)
520  return emitError(curToken.getLoc(), msg);
521  consumeToken();
522  return success();
523  }
  /// Emit an error at `loc` through the lexer and return failure.
524  LogicalResult emitError(SMRange loc, const Twine &msg) {
525  lexer.emitError(loc, msg);
526  return failure();
527  }
  /// Emit an error at the location of the current token and return failure.
528  LogicalResult emitError(const Twine &msg) {
529  return emitError(curToken.getLoc(), msg);
530  }
  /// Emit an error at `loc` with an attached note at `noteLoc` through the
  /// lexer, and return failure.
531  LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
532  const Twine &note) {
533  lexer.emitErrorAndNote(loc, msg, noteLoc, note);
534  return failure();
535  }
536 
537  //===--------------------------------------------------------------------===//
538  // Fields
539  //===--------------------------------------------------------------------===//
540 
541  /// The owning AST context.
542  ast::Context &ctx;
543 
544  /// The lexer of this parser.
545  Lexer lexer;
546 
547  /// The current token within the lexer.
548  Token curToken;
549 
550  /// A flag indicating if the parser should add documentation to AST nodes when
551  /// viable.
552  bool enableDocumentation;
553 
554  /// The most recently defined decl scope.
555  ast::DeclScope *curDeclScope = nullptr;
  /// Allocator providing the storage for scopes created by pushDeclScope().
556  llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
557 
558  /// The current context of the parser.
559  ParserContext parserContext = ParserContext::Global;
560 
561  /// Cached types to simplify verification and expression creation.
562  ast::Type typeTy, valueTy;
563  ast::RangeType typeRangeTy, valueRangeTy;
564  ast::Type attrTy;
565 
566  /// A counter used when naming anonymous constraints and rewrites.
567  unsigned anonymousDeclNameCounter = 0;
568 
569  /// The optional code completion context.
570  CodeCompleteContext *codeCompleteContext;
571 };
572 } // namespace
573 
574 FailureOr<ast::Module *> Parser::parseModule() {
575  SMLoc moduleLoc = curToken.getStartLoc();
576  pushDeclScope();
577 
578  // Parse the top-level decls of the module.
580  if (failed(parseModuleBody(decls)))
581  return popDeclScope(), failure();
582 
583  popDeclScope();
584  return ast::Module::create(ctx, moduleLoc, decls);
585 }
586 
587 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
588  while (curToken.isNot(Token::eof)) {
589  if (curToken.is(Token::directive)) {
590  if (failed(parseDirective(decls)))
591  return failure();
592  continue;
593  }
594 
595  FailureOr<ast::Decl *> decl = parseTopLevelDecl();
596  if (failed(decl))
597  return failure();
598  decls.push_back(*decl);
599  }
600  return success();
601 }
602 
603 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
604  return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
605  valueRangeTy);
606 }
607 
608 LogicalResult Parser::convertExpressionTo(
609  ast::Expr *&expr, ast::Type type,
610  function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
611  ast::Type exprType = expr->getType();
612  if (exprType == type)
613  return success();
614 
615  auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
616  ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
617  expr->getLoc(), llvm::formatv("unable to convert expression of type "
618  "`{0}` to the expected type of "
619  "`{1}`",
620  exprType, type));
621  if (noteAttachFn)
622  noteAttachFn(*diag);
623  return diag;
624  };
625 
626  if (auto exprOpType = exprType.dyn_cast<ast::OperationType>())
627  return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);
628 
629  // FIXME: Decide how to allow/support converting a single result to multiple,
630  // and multiple to a single result. For now, we just allow Single->Range,
631  // but this isn't something really supported in the PDL dialect. We should
632  // figure out some way to support both.
633  if ((exprType == valueTy || exprType == valueRangeTy) &&
634  (type == valueTy || type == valueRangeTy))
635  return success();
636  if ((exprType == typeTy || exprType == typeRangeTy) &&
637  (type == typeTy || type == typeRangeTy))
638  return success();
639 
640  // Handle tuple types.
641  if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>())
642  return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
643  noteAttachFn);
644 
645  return emitConvertError();
646 }
647 
648 LogicalResult Parser::convertOpExpressionTo(
649  ast::Expr *&expr, ast::OperationType exprType, ast::Type type,
650  function_ref<ast::InFlightDiagnostic()> emitErrorFn) {
651  // Two operation types are compatible if they have the same name, or if the
652  // expected type is more general.
653  if (auto opType = type.dyn_cast<ast::OperationType>()) {
654  if (opType.getName())
655  return emitErrorFn();
656  return success();
657  }
658 
659  // An operation can always convert to a ValueRange.
660  if (type == valueRangeTy) {
661  expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
662  valueRangeTy);
663  return success();
664  }
665 
666  // Allow conversion to a single value by constraining the result range.
667  if (type == valueTy) {
668  // If the operation is registered, we can verify if it can ever have a
669  // single result.
670  if (const ods::Operation *odsOp = exprType.getODSOperation()) {
671  if (odsOp->getResults().empty()) {
672  return emitErrorFn()->attachNote(
673  llvm::formatv("see the definition of `{0}`, which was defined "
674  "with zero results",
675  odsOp->getName()),
676  odsOp->getLoc());
677  }
678 
679  unsigned numSingleResults = llvm::count_if(
680  odsOp->getResults(), [](const ods::OperandOrResult &result) {
681  return result.getVariableLengthKind() ==
682  ods::VariableLengthKind::Single;
683  });
684  if (numSingleResults > 1) {
685  return emitErrorFn()->attachNote(
686  llvm::formatv("see the definition of `{0}`, which was defined "
687  "with at least {1} results",
688  odsOp->getName(), numSingleResults),
689  odsOp->getLoc());
690  }
691  }
692 
693  expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
694  valueTy);
695  return success();
696  }
697  return emitErrorFn();
698 }
699 
700 LogicalResult Parser::convertTupleExpressionTo(
701  ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
702  function_ref<ast::InFlightDiagnostic()> emitErrorFn,
703  function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
704  // Handle conversions between tuples.
705  if (auto tupleType = type.dyn_cast<ast::TupleType>()) {
706  if (tupleType.size() != exprType.size())
707  return emitErrorFn();
708 
709  // Build a new tuple expression using each of the elements of the current
710  // tuple.
711  SmallVector<ast::Expr *> newExprs;
712  for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
713  newExprs.push_back(ast::MemberAccessExpr::create(
714  ctx, expr->getLoc(), expr, llvm::to_string(i),
715  exprType.getElementTypes()[i]));
716 
717  auto diagFn = [&](ast::Diagnostic &diag) {
718  diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
719  i, exprType));
720  if (noteAttachFn)
721  noteAttachFn(diag);
722  };
723  if (failed(convertExpressionTo(newExprs.back(),
724  tupleType.getElementTypes()[i], diagFn)))
725  return failure();
726  }
727  expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
728  tupleType.getElementNames());
729  return success();
730  }
731 
732  // Handle conversion to a range.
733  auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes,
734  ast::RangeType resultTy) -> LogicalResult {
735  // TODO: We currently only allow range conversion within a rewrite context.
736  if (parserContext != ParserContext::Rewrite) {
737  return emitErrorFn()->attachNote("Tuple to Range conversion is currently "
738  "only allowed within a rewrite context");
739  }
740 
741  // All of the tuple elements must be allowed types.
742  for (ast::Type elementType : exprType.getElementTypes())
743  if (!llvm::is_contained(allowedElementTypes, elementType))
744  return emitErrorFn();
745 
746  // Build a new tuple expression using each of the elements of the current
747  // tuple.
748  SmallVector<ast::Expr *> newExprs;
749  for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
750  newExprs.push_back(ast::MemberAccessExpr::create(
751  ctx, expr->getLoc(), expr, llvm::to_string(i),
752  exprType.getElementTypes()[i]));
753  }
754  expr = ast::RangeExpr::create(ctx, expr->getLoc(), newExprs, resultTy);
755  return success();
756  };
757  if (type == valueRangeTy)
758  return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
759  if (type == typeRangeTy)
760  return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
761 
762  return emitErrorFn();
763 }
764 
765 //===----------------------------------------------------------------------===//
766 // Directives
767 
768 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
769  StringRef directive = curToken.getSpelling();
770  if (directive == "#include")
771  return parseInclude(decls);
772 
773  return emitError("unknown directive `" + directive + "`");
774 }
775 
776 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
777  SMRange loc = curToken.getLoc();
778  consumeToken(Token::directive);
779 
780  // Handle code completion of the include file path.
781  if (curToken.is(Token::code_complete_string))
782  return codeCompleteIncludeFilename(curToken.getStringValue());
783 
784  // Parse the file being included.
785  if (!curToken.isString())
786  return emitError(loc,
787  "expected string file name after `include` directive");
788  SMRange fileLoc = curToken.getLoc();
789  std::string filenameStr = curToken.getStringValue();
790  StringRef filename = filenameStr;
791  consumeToken();
792 
793  // Check the type of include. If ending with `.pdll`, this is another pdl file
794  // to be parsed along with the current module.
795  if (filename.endswith(".pdll")) {
796  if (failed(lexer.pushInclude(filename, fileLoc)))
797  return emitError(fileLoc,
798  "unable to open include file `" + filename + "`");
799 
800  // If we added the include successfully, parse it into the current module.
801  // Make sure to update to the next token after we finish parsing the nested
802  // file.
803  curToken = lexer.lexToken();
804  LogicalResult result = parseModuleBody(decls);
805  curToken = lexer.lexToken();
806  return result;
807  }
808 
809  // Otherwise, this must be a `.td` include.
810  if (filename.endswith(".td"))
811  return parseTdInclude(filename, fileLoc, decls);
812 
813  return emitError(fileLoc,
814  "expected include filename to end with `.pdll` or `.td`");
815 }
816 
/// Parse the TableGen (`.td`) file `filename`, named by the include directive
/// at `fileLoc`, and turn its records into declarations appended to `decls`
/// (via processTdIncludeRecords).
/// NOTE(review): the continuation of this signature is missing from this
/// listing; `decls`, used below, is presumably declared there.
817 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
819  llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
820 
821  // Use the source manager to open the file, but don't yet add it.
// The file is opened through the parser's own source manager;
// `includedFile` is an out-parameter of OpenIncludeFile and is not used
// afterwards.
822  std::string includedFile;
823  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
824  parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
825  if (!includeBuffer)
826  return emitError(fileLoc, "unable to open include file `" + filename + "`");
827 
828  // Setup the source manager for parsing the tablegen file.
// A separate SourceMgr is used for TableGen. It inherits the parser's
// include directories, so nested includes in the .td file are searched in
// the same places.
829  llvm::SourceMgr tdSrcMgr;
830  tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
831  tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
832 
833  // This class provides a context argument for the llvm::SourceMgr diagnostic
834  // handler.
835  struct DiagHandlerContext {
836  Parser &parser;
837  StringRef filename;
838  llvm::SMRange loc;
839  } handlerContext{*this, filename, fileLoc};
840 
841  // Set the diagnostic handler for the tablegen source manager.
// Every diagnostic routed through tdSrcMgr, whatever its severity, is
// re-reported as a PDLL error at the include directive's location and is
// prefixed with the included file name. The LogicalResult of emitError is
// discarded because the handler itself returns void.
842  tdSrcMgr.setDiagHandler(
843  [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
844  auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
845  (void)ctx->parser.emitError(
846  ctx->loc,
847  llvm::formatv("error while processing include file `{0}`: {1}",
848  ctx->filename, diag.getMessage()));
849  },
850  &handlerContext);
851 
852  // Parse the tablegen file.
// A true return from TableGenParseFile means failure. No extra diagnostic
// is emitted here.
// NOTE(review): this assumes the failure was already reported through the
// handler above.
853  llvm::RecordKeeper tdRecords;
854  if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
855  return failure();
856 
857  // Process the parsed records.
858  processTdIncludeRecords(tdRecords, decls);
859 
860  // After we are done processing, move all of the tablegen source buffers to
861  // the main parser source mgr. This allows for directly using source locations
862  // from the .td files without needing to remap them.
863  parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
864  return success();
865 }
866 
/// Convert the records of a parsed TableGen file into PDLL state:
///  - every `Op` record is registered as an operation in the ODS context,
///    together with its attributes, operands and results;
///  - `Attr`, `Type` and `OpInterface` records that are not skipped become
///    native PDLL constraint decls, which are appended to `decls`.
/// NOTE(review): the continuation of this signature, which presumably
/// declares `decls`, is missing from this listing. Several body lines are
/// also missing (see the notes below).
867 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
869  // Return the length kind of the given value.
// NOTE(review): two lines of this lambda are missing from this listing: the
// statement guarded by `isOptional()` and the false arm of the ternary.
870  auto getLengthKind = [](const auto &value) {
871  if (value.isOptional())
873  return value.isVariadic() ? ods::VariableLengthKind::Variadic
875  };
876 
877  // Insert a type constraint into the ODS context.
878  ods::Context &odsContext = ctx.getODSContext();
879  auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
880  -> const ods::TypeConstraint & {
881  return odsContext.insertTypeConstraint(
882  cst.constraint.getUniqueDefName(),
883  processDoc(cst.constraint.getSummary()),
884  cst.constraint.getCPPClassName());
885  };
// Build a one-character source range that starts at `loc`.
886  auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
887  return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
888  };
889 
890  // Process the parsed tablegen records to build ODS information.
891  /// Operations.
892  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
893  tblgen::Operator op(def);
894 
895  // Check to see if this operation is known to support type inference.
896  bool supportsResultTypeInferrence =
897  op.getTrait("::mlir::InferTypeOpInterface::Trait");
898 
899  auto [odsOp, inserted] = odsContext.insertOperation(
900  op.getOperationName(), processDoc(op.getSummary()),
901  processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
902  supportsResultTypeInferrence, op.getLoc().front());
903 
904  // Ignore operations that have already been added.
// Otherwise their attributes, operands and results would be appended a
// second time.
905  if (!inserted)
906  continue;
907 
908  for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
909  odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
910  odsContext.insertAttributeConstraint(
911  attr.attr.getUniqueDefName(),
912  processDoc(attr.attr.getSummary()),
913  attr.attr.getStorageType()));
914  }
915  for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
916  odsOp->appendOperand(operand.name, getLengthKind(operand),
917  addTypeConstraint(operand));
918  }
919  for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
920  odsOp->appendResult(result.name, getLengthKind(result),
921  addTypeConstraint(result));
922  }
923  }
924 
// Skip a constraint record if it is anonymous, if its name already resolves
// in the current decl scope, or if it is a `DeclareInterfaceMethods` record.
925  auto shouldBeSkipped = [this](llvm::Record *def) {
926  return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
927  def->isSubClassOf("DeclareInterfaceMethods");
928  };
929 
930  /// Attr constraints.
931  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
932  if (shouldBeSkipped(def))
933  continue;
934 
935  tblgen::Attribute constraint(def);
936  decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
937  constraint, convertLocToRange(def->getLoc().front()), attrTy,
938  constraint.getStorageType()));
939  }
940  /// Type constraints.
941  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
942  if (shouldBeSkipped(def))
943  continue;
944 
945  tblgen::TypeConstraint constraint(def);
946  decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
947  constraint, convertLocToRange(def->getLoc().front()), typeTy,
948  constraint.getCPPClassName()));
949  }
950  /// OpInterfaces.
// Each interface becomes a native constraint whose body is
// `return ::mlir::success(llvm::isa<Namespace::Interface>(self));`.
// NOTE(review): `opTy` is not declared in the visible lines; a line appears to
// be missing from this listing right after the `OpInterfaces` comment.
952  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("OpInterface")) {
953  if (shouldBeSkipped(def))
954  continue;
955 
956  SMRange loc = convertLocToRange(def->getLoc().front());
957 
958  std::string cppClassName =
959  llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
960  def->getValueAsString("cppInterfaceName"))
961  .str();
962  std::string codeBlock =
963  llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
964  cppClassName)
965  .str();
966 
967  std::string desc =
968  processAndFormatDoc(def->getValueAsString("description"));
969  decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
970  def->getName(), codeBlock, loc, opTy, cppClassName, desc));
971  }
972 }
973 
974 template <typename ConstraintT>
975 ast::Decl *Parser::createODSNativePDLLConstraintDecl(
976  StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
977  StringRef nativeType, StringRef docString) {
978  // Build the single input parameter.
979  ast::DeclScope *argScope = pushDeclScope();
980  auto *paramVar = ast::VariableDecl::create(
981  ctx, ast::Name::create(ctx, "self", loc), type,
982  /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
983  argScope->add(paramVar);
984  popDeclScope();
985 
986  // Build the native constraint.
987  auto *constraintDecl = ast::UserConstraintDecl::createNative(
988  ctx, ast::Name::create(ctx, name, loc), paramVar,
989  /*results=*/std::nullopt, codeBlock, ast::TupleType::get(ctx),
990  nativeType);
991  constraintDecl->setDocComment(ctx, docString);
992  curDeclScope->add(constraintDecl);
993  return constraintDecl;
994 }
995 
996 template <typename ConstraintT>
997 ast::Decl *
998 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
999  SMRange loc, ast::Type type,
1000  StringRef nativeType) {
1001  // Format the condition template.
1002  tblgen::FmtContext fmtContext;
1003  fmtContext.withSelf("self");
1004  std::string codeBlock = tblgen::tgfmt(
1005  "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
1006  &fmtContext);
1007 
1008  // If documentation was enabled, build the doc string for the generated
1009  // constraint. It would be nice to do this lazily, but TableGen information is
1010  // destroyed after we finish parsing the file.
1011  std::string docString;
1012  if (enableDocumentation) {
1013  StringRef desc = constraint.getDescription();
1014  docString = processAndFormatDoc(
1015  constraint.getSummary() +
1016  (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
1017  }
1018 
1019  return createODSNativePDLLConstraintDecl<ConstraintT>(
1020  constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
1021  docString);
1022 }
1023 
1024 //===----------------------------------------------------------------------===//
1025 // Decls
1026 
/// Parse a single top-level declaration, dispatching on its leading keyword:
/// `Constraint`, `Pattern` or `Rewrite`. Any other token is an error. If the
/// parsed decl has a name, it is checked against existing definitions in the
/// current scope and then added to that scope.
/// NOTE(review): the declaration of `decl`, which the switch below assigns,
/// is missing from this listing.
1027 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
1029  switch (curToken.getKind()) {
1030  case Token::kw_Constraint:
1031  decl = parseUserConstraintDecl();
1032  break;
1033  case Token::kw_Pattern:
1034  decl = parsePatternDecl();
1035  break;
1036  case Token::kw_Rewrite:
1037  decl = parseUserRewriteDecl();
1038  break;
1039  default:
1040  return emitError("expected top-level declaration, such as a `Pattern`");
1041  }
1042  if (failed(decl))
1043  return failure();
1044 
1045  // If the decl has a name, add it to the current scope.
// Unnamed decls (e.g. a `Pattern` without a name) are returned without being
// registered in any scope.
1046  if (const ast::Name *name = (*decl)->getName()) {
1047  if (failed(checkDefineNamedDecl(*name)))
1048  return failure();
1049  curDeclScope->add(*decl);
1050  }
1051  return decl;
1052 }
1053 
1055 Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1056  // Check for name code completion.
1057  if (curToken.is(Token::code_complete))
1058  return codeCompleteAttributeName(parentOpName);
1059 
1060  std::string attrNameStr;
1061  if (curToken.isString())
1062  attrNameStr = curToken.getStringValue();
1063  else if (curToken.is(Token::identifier) || curToken.isKeyword())
1064  attrNameStr = curToken.getSpelling().str();
1065  else
1066  return emitError("expected identifier or string attribute name");
1067  const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
1068  consumeToken();
1069 
1070  // Check for a value of the attribute.
1071  ast::Expr *attrValue = nullptr;
1072  if (consumeIf(Token::equal)) {
1073  FailureOr<ast::Expr *> attrExpr = parseExpr();
1074  if (failed(attrExpr))
1075  return failure();
1076  attrValue = *attrExpr;
1077  } else {
1078  // If there isn't a concrete value, create an expression representing a
1079  // UnitAttr.
1080  attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
1081  }
1082 
1083  return ast::NamedAttributeDecl::create(ctx, name, attrValue);
1084 }
1085 
1086 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1087  function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1088  bool expectTerminalSemicolon) {
1089  consumeToken(Token::equal_arrow);
1090 
1091  // Parse the single statement of the lambda body.
1092  SMLoc bodyStartLoc = curToken.getStartLoc();
1093  pushDeclScope();
1094  FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1095  bool failedToParse =
1096  failed(singleStatement) || failed(processStatementFn(*singleStatement));
1097  popDeclScope();
1098  if (failedToParse)
1099  return failure();
1100 
1101  SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1102  return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
1103 }
1104 
1105 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1106  // Ensure that the argument is named.
1107  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
1108  return emitError("expected identifier argument name");
1109 
1110  // Parse the argument similarly to a normal variable.
1111  StringRef name = curToken.getSpelling();
1112  SMRange nameLoc = curToken.getLoc();
1113  consumeToken();
1114 
1115  if (failed(
1116  parseToken(Token::colon, "expected `:` before argument constraint")))
1117  return failure();
1118 
1119  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1120  if (failed(cst))
1121  return failure();
1122 
1123  return createArgOrResultVariableDecl(name, nameLoc, *cst);
1124 }
1125 
1126 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1127  // Check to see if this result is named.
1128  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1129  // Check to see if this name actually refers to a Constraint.
1130  if (!curDeclScope->lookup<ast::ConstraintDecl>(curToken.getSpelling())) {
1131  // If it wasn't a constraint, parse the result similarly to a variable. If
1132  // there is already an existing decl, we will emit an error when defining
1133  // this variable later.
1134  StringRef name = curToken.getSpelling();
1135  SMRange nameLoc = curToken.getLoc();
1136  consumeToken();
1137 
1138  if (failed(parseToken(Token::colon,
1139  "expected `:` before result constraint")))
1140  return failure();
1141 
1142  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1143  if (failed(cst))
1144  return failure();
1145 
1146  return createArgOrResultVariableDecl(name, nameLoc, *cst);
1147  }
1148  }
1149 
1150  // If it isn't named, we parse the constraint directly and create an unnamed
1151  // result variable.
1152  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1153  if (failed(cst))
1154  return failure();
1155 
1156  return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1157 }
1158 
1160 Parser::parseUserConstraintDecl(bool isInline) {
1161  // Constraints and rewrites have very similar formats, dispatch to a shared
1162  // interface for parsing.
1163  return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1164  [&](auto &&...args) {
1165  return this->parseUserPDLLConstraintDecl(args...);
1166  },
1167  ParserContext::Constraint, "constraint", isInline);
1168 }
1169 
/// Parse an inline `Constraint` decl, which (unlike a top-level decl) may be
/// unnamed. Reject it if its name is already defined in the current scope;
/// otherwise add it to that scope.
/// NOTE(review): the line declaring `decl`, which receives the result of the
/// parseUserConstraintDecl call below, is missing from this listing.
1170 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1172  parseUserConstraintDecl(/*isInline=*/true);
1173  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1174  return failure();
1175 
1176  curDeclScope->add(*decl);
1177  return decl;
1178 }
1179 
1180 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1181  const ast::Name &name, bool isInline,
1182  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1183  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1184  // Push the argument scope back onto the list, so that the body can
1185  // reference arguments.
1186  pushDeclScope(argumentScope);
1187 
1188  // Parse the body of the constraint. The body is either defined as a compound
1189  // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1190  ast::CompoundStmt *body;
1191  if (curToken.is(Token::equal_arrow)) {
1192  FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1193  [&](ast::Stmt *&stmt) -> LogicalResult {
1194  ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1195  if (!stmtExpr) {
1196  return emitError(stmt->getLoc(),
1197  "expected `Constraint` lambda body to contain a "
1198  "single expression");
1199  }
1200  stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1201  return success();
1202  },
1203  /*expectTerminalSemicolon=*/!isInline);
1204  if (failed(bodyResult))
1205  return failure();
1206  body = *bodyResult;
1207  } else {
1208  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1209  if (failed(bodyResult))
1210  return failure();
1211  body = *bodyResult;
1212 
1213  // Verify the structure of the body.
1214  auto bodyIt = body->begin(), bodyE = body->end();
1215  for (; bodyIt != bodyE; ++bodyIt)
1216  if (isa<ast::ReturnStmt>(*bodyIt))
1217  break;
1218  if (failed(validateUserConstraintOrRewriteReturn(
1219  "Constraint", body, bodyIt, bodyE, results, resultType)))
1220  return failure();
1221  }
1222  popDeclScope();
1223 
1224  return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1225  name, arguments, results, resultType, body);
1226 }
1227 
1228 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1229  // Constraints and rewrites have very similar formats, dispatch to a shared
1230  // interface for parsing.
1231  return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1232  [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1233  ParserContext::Rewrite, "rewrite", isInline);
1234 }
1235 
/// Parse an inline `Rewrite` decl, which (unlike a top-level decl) may be
/// unnamed. Reject it if its name is already defined in the current scope;
/// otherwise add it to that scope.
/// NOTE(review): the line declaring `decl`, which receives the result of the
/// parseUserRewriteDecl call below, is missing from this listing.
1236 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1238  parseUserRewriteDecl(/*isInline=*/true);
1239  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1240  return failure();
1241 
1242  curDeclScope->add(*decl);
1243  return decl;
1244 }
1245 
1246 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1247  const ast::Name &name, bool isInline,
1248  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1249  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1250  // Push the argument scope back onto the list, so that the body can
1251  // reference arguments.
1252  curDeclScope = argumentScope;
1253  ast::CompoundStmt *body;
1254  if (curToken.is(Token::equal_arrow)) {
1255  FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1256  [&](ast::Stmt *&statement) -> LogicalResult {
1257  if (isa<ast::OpRewriteStmt>(statement))
1258  return success();
1259 
1260  ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1261  if (!statementExpr) {
1262  return emitError(
1263  statement->getLoc(),
1264  "expected `Rewrite` lambda body to contain a single expression "
1265  "or an operation rewrite statement; such as `erase`, "
1266  "`replace`, or `rewrite`");
1267  }
1268  statement =
1269  ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1270  return success();
1271  },
1272  /*expectTerminalSemicolon=*/!isInline);
1273  if (failed(bodyResult))
1274  return failure();
1275  body = *bodyResult;
1276  } else {
1277  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1278  if (failed(bodyResult))
1279  return failure();
1280  body = *bodyResult;
1281  }
1282  popDeclScope();
1283 
1284  // Verify the structure of the body.
1285  auto bodyIt = body->begin(), bodyE = body->end();
1286  for (; bodyIt != bodyE; ++bodyIt)
1287  if (isa<ast::ReturnStmt>(*bodyIt))
1288  break;
1289  if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1290  bodyE, results, resultType)))
1291  return failure();
1292  return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1293  name, arguments, results, resultType, body);
1294 }
1295 
1296 template <typename T, typename ParseUserPDLLDeclFnT>
1297 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1298  ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1299  StringRef anonymousNamePrefix, bool isInline) {
1300  SMRange loc = curToken.getLoc();
1301  consumeToken();
1302  llvm::SaveAndRestore saveCtx(parserContext, declContext);
1303 
1304  // Parse the name of the decl.
1305  const ast::Name *name = nullptr;
1306  if (curToken.isNot(Token::identifier)) {
1307  // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1308  // in C++, so being unnamed is fine.
1309  if (!isInline)
1310  return emitError("expected identifier name");
1311 
1312  // Create a unique anonymous name to use, as the name for this decl is not
1313  // important.
1314  std::string anonName =
1315  llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1316  anonymousDeclNameCounter++)
1317  .str();
1318  name = &ast::Name::create(ctx, anonName, loc);
1319  } else {
1320  // If a name was provided, we can use it directly.
1321  name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1322  consumeToken(Token::identifier);
1323  }
1324 
1325  // Parse the functional signature of the decl.
1326  SmallVector<ast::VariableDecl *> arguments, results;
1327  ast::DeclScope *argumentScope;
1328  ast::Type resultType;
1329  if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1330  argumentScope, resultType)))
1331  return failure();
1332 
1333  // Check to see which type of constraint this is. If the constraint contains a
1334  // compound body, this is a PDLL decl.
1335  if (curToken.isAny(Token::l_brace, Token::equal_arrow))
1336  return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1337  resultType);
1338 
1339  // Otherwise, this is a native decl.
1340  return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1341  results, resultType);
1342 }
1343 
/// Parse the remainder of a native `Constraint`/`Rewrite` decl, i.e. one
/// without a PDLL body. This is an optional string literal holding the native
/// code, followed by `;`. Inline decls must supply the code string, and
/// native Constraints may not declare results.
/// NOTE(review): the parameter line between `isInline` and `results`
/// (presumably declaring `arguments`) is missing from this listing.
1344 template <typename T>
1345 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1346  const ast::Name &name, bool isInline,
1348  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1349  // If followed by a string, the native code body has also been specified.
// `optCodeStr` refers into `codeStrStorage`, which keeps the decoded string
// alive until the decl is created.
1350  std::string codeStrStorage;
1351  std::optional<StringRef> optCodeStr;
1352  if (curToken.isString()) {
1353  codeStrStorage = curToken.getStringValue();
1354  optCodeStr = codeStrStorage;
1355  consumeToken();
1356  } else if (isInline) {
1357  return emitError(name.getLoc(),
1358  "external declarations must be declared in global scope");
1359  } else if (curToken.is(Token::error)) {
// Fail without emitting a new diagnostic.
// NOTE(review): this presumably relies on the lexer having already reported
// the error token.
1360  return failure();
1361  }
1362  if (failed(parseToken(Token::semicolon,
1363  "expected `;` after native declaration")))
1364  return failure();
1365  // TODO: PDL should be able to support constraint results in certain
1366  // situations, we should revise this.
// `std::is_same<...>::value` is a compile-time constant, so this check only
// applies when T is ast::UserConstraintDecl.
1367  if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
1368  return emitError(
1369  "native Constraints currently do not support returning results");
1370  }
1371  return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1372 }
1373 
/// Parse the signature of a `Constraint` or `Rewrite` decl:
///   `(` [arg (`,` arg)*] `)` [`->` (result | `(` result (`,` result)* `)`)]
/// - Arguments are appended to `arguments` and declared in a new scope. That
///   scope is returned through `argumentScope` and is popped again before
///   this function returns.
/// - Results are appended to `results` and parsed in a temporary scope.
/// - `resultType` is computed from the parsed results.
/// - A single result may not carry a name label.
/// NOTE(review): the parameter lines declaring `arguments` and `results` are
/// missing from this listing.
1374 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1377  ast::DeclScope *&argumentScope, ast::Type &resultType) {
1378  // Parse the argument list of the decl.
1379  if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
1380  return failure();
1381 
1382  argumentScope = pushDeclScope();
1383  if (curToken.isNot(Token::r_paren)) {
1384  do {
1385  FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1386  if (failed(argument))
1387  return failure();
1388  arguments.emplace_back(*argument);
1389  } while (consumeIf(Token::comma));
1390  }
1391  popDeclScope();
1392  if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
1393  return failure();
1394 
1395  // Parse the results of the decl.
1396  pushDeclScope();
1397  if (consumeIf(Token::arrow)) {
1398  auto parseResultFn = [&]() -> LogicalResult {
1399  FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1400  if (failed(result))
1401  return failure();
1402  results.emplace_back(*result);
1403  return success();
1404  };
1405 
1406  // Check for a list of results.
// The do/while requires at least one result inside the parentheses.
1407  if (consumeIf(Token::l_paren)) {
1408  do {
1409  if (failed(parseResultFn()))
1410  return failure();
1411  } while (consumeIf(Token::comma));
1412  if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1413  return failure();
1414 
1415  // Otherwise, there is only one result.
1416  } else if (failed(parseResultFn())) {
1417  return failure();
1418  }
1419  }
1420  popDeclScope();
1421 
1422  // Compute the result type of the decl.
1423  resultType = createUserConstraintRewriteResultType(results);
1424 
1425  // Verify that results are only named if there are more than one.
1426  if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1427  return emitError(
1428  results.front()->getLoc(),
1429  "cannot create a single-element tuple with an element label");
1430  }
1431  return success();
1432 }
1433 
/// Validate the placement of `return` in a Constraint/Rewrite body.
/// `bodyIt` is the first `return` found by the caller, or equal to `bodyE` if
/// there is none. Two cases are errors:
/// - any statement follows that `return`;
/// - no `return` is present although results are declared. The error range
///   is empty and sits at the end of the body.
/// `declType` ("Constraint" or "Rewrite") is only used in diagnostics.
/// NOTE(review): the parameter lines declaring `bodyIt` and `bodyE` are
/// missing from this listing.
1434 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1435  StringRef declType, ast::CompoundStmt *body,
1438  ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1439  // Handle if a `return` was provided.
1440  if (bodyIt != bodyE) {
1441  // Emit an error if we have trailing statements after the return.
1442  if (std::next(bodyIt) != bodyE) {
1443  return emitError(
1444  (*std::next(bodyIt))->getLoc(),
1445  llvm::formatv("`return` terminated the `{0}` body, but found "
1446  "trailing statements afterwards",
1447  declType));
1448  }
1449 
1450  // Otherwise if a return wasn't provided, check that no results are
1451  // expected.
1452  } else if (!results.empty()) {
1453  return emitError(
1454  {body->getLoc().End, body->getLoc().End},
1455  llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1456  declType, resultType));
1457  }
1458  return success();
1459 }
1460 
1461 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1462  return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1463  if (isa<ast::OpRewriteStmt>(statement))
1464  return success();
1465  return emitError(
1466  statement->getLoc(),
1467  "expected Pattern lambda body to contain a single operation "
1468  "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1469  });
1470 }
1471 
1472 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1473  SMRange loc = curToken.getLoc();
1474  consumeToken(Token::kw_Pattern);
1475  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1476 
1477  // Check for an optional identifier for the pattern name.
1478  const ast::Name *name = nullptr;
1479  if (curToken.is(Token::identifier)) {
1480  name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1481  consumeToken(Token::identifier);
1482  }
1483 
1484  // Parse any pattern metadata.
1485  ParsedPatternMetadata metadata;
1486  if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1487  return failure();
1488 
1489  // Parse the pattern body.
1490  ast::CompoundStmt *body;
1491 
1492  // Handle a lambda body.
1493  if (curToken.is(Token::equal_arrow)) {
1494  FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1495  if (failed(bodyResult))
1496  return failure();
1497  body = *bodyResult;
1498  } else {
1499  if (curToken.isNot(Token::l_brace))
1500  return emitError("expected `{` or `=>` to start pattern body");
1501  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1502  if (failed(bodyResult))
1503  return failure();
1504  body = *bodyResult;
1505 
1506  // Verify the body of the pattern.
1507  auto bodyIt = body->begin(), bodyE = body->end();
1508  for (; bodyIt != bodyE; ++bodyIt) {
1509  if (isa<ast::ReturnStmt>(*bodyIt)) {
1510  return emitError((*bodyIt)->getLoc(),
1511  "`return` statements are only permitted within a "
1512  "`Constraint` or `Rewrite` body");
1513  }
1514  // Break when we've found the rewrite statement.
1515  if (isa<ast::OpRewriteStmt>(*bodyIt))
1516  break;
1517  }
1518  if (bodyIt == bodyE) {
1519  return emitError(loc,
1520  "expected Pattern body to terminate with an operation "
1521  "rewrite statement, such as `erase`");
1522  }
1523  if (std::next(bodyIt) != bodyE) {
1524  return emitError((*std::next(bodyIt))->getLoc(),
1525  "Pattern body was terminated by an operation "
1526  "rewrite statement, but found trailing statements");
1527  }
1528  }
1529 
1530  return createPatternDecl(loc, name, metadata, body);
1531 }
1532 
1534 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1535  std::optional<SMRange> benefitLoc;
1536  std::optional<SMRange> hasBoundedRecursionLoc;
1537 
1538  do {
1539  // Handle metadata code completion.
1540  if (curToken.is(Token::code_complete))
1541  return codeCompletePatternMetadata();
1542 
1543  if (curToken.isNot(Token::identifier))
1544  return emitError("expected pattern metadata identifier");
1545  StringRef metadataStr = curToken.getSpelling();
1546  SMRange metadataLoc = curToken.getLoc();
1547  consumeToken(Token::identifier);
1548 
1549  // Parse the benefit metadata: benefit(<integer-value>)
1550  if (metadataStr == "benefit") {
1551  if (benefitLoc) {
1552  return emitErrorAndNote(metadataLoc,
1553  "pattern benefit has already been specified",
1554  *benefitLoc, "see previous definition here");
1555  }
1556  if (failed(parseToken(Token::l_paren,
1557  "expected `(` before pattern benefit")))
1558  return failure();
1559 
1560  uint16_t benefitValue = 0;
1561  if (curToken.isNot(Token::integer))
1562  return emitError("expected integral pattern benefit");
1563  if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1564  return emitError(
1565  "expected pattern benefit to fit within a 16-bit integer");
1566  consumeToken(Token::integer);
1567 
1568  metadata.benefit = benefitValue;
1569  benefitLoc = metadataLoc;
1570 
1571  if (failed(
1572  parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1573  return failure();
1574  continue;
1575  }
1576 
1577  // Parse the bounded recursion metadata: recursion
1578  if (metadataStr == "recursion") {
1579  if (hasBoundedRecursionLoc) {
1580  return emitErrorAndNote(
1581  metadataLoc,
1582  "pattern recursion metadata has already been specified",
1583  *hasBoundedRecursionLoc, "see previous definition here");
1584  }
1585  metadata.hasBoundedRecursion = true;
1586  hasBoundedRecursionLoc = metadataLoc;
1587  continue;
1588  }
1589 
1590  return emitError(metadataLoc, "unknown pattern metadata");
1591  } while (consumeIf(Token::comma));
1592 
1593  return success();
1594 }
1595 
1596 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1597  consumeToken(Token::less);
1598 
1599  FailureOr<ast::Expr *> typeExpr = parseExpr();
1600  if (failed(typeExpr) ||
1601  failed(parseToken(Token::greater,
1602  "expected `>` after variable type constraint")))
1603  return failure();
1604  return typeExpr;
1605 }
1606 
1607 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1608  assert(curDeclScope && "defining decl outside of a decl scope");
1609  if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1610  return emitErrorAndNote(
1611  name.getLoc(), "`" + name.getName() + "` has already been defined",
1612  lastDecl->getName()->getLoc(), "see previous definition here");
1613  }
1614  return success();
1615 }
1616 
1618 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1619  ast::Expr *initExpr,
1620  ArrayRef<ast::ConstraintRef> constraints) {
1621  assert(curDeclScope && "defining variable outside of decl scope");
1622  const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1623 
1624  // If the name of the variable indicates a special variable, we don't add it
1625  // to the scope. This variable is local to the definition point.
1626  if (name.empty() || name == "_") {
1627  return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1628  constraints);
1629  }
1630  if (failed(checkDefineNamedDecl(nameDecl)))
1631  return failure();
1632 
1633  auto *varDecl =
1634  ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1635  curDeclScope->add(varDecl);
1636  return varDecl;
1637 }
1638 
1640 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1641  ArrayRef<ast::ConstraintRef> constraints) {
1642  return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1643  constraints);
1644 }
1645 
1646 LogicalResult Parser::parseVariableDeclConstraintList(
1647  SmallVectorImpl<ast::ConstraintRef> &constraints) {
1648  std::optional<SMRange> typeConstraint;
1649  auto parseSingleConstraint = [&] {
1650  FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1651  typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
1652  if (failed(constraint))
1653  return failure();
1654  constraints.push_back(*constraint);
1655  return success();
1656  };
1657 
1658  // Check to see if this is a single constraint, or a list.
1659  if (!consumeIf(Token::l_square))
1660  return parseSingleConstraint();
1661 
1662  do {
1663  if (failed(parseSingleConstraint()))
1664  return failure();
1665  } while (consumeIf(Token::comma));
1666  return parseToken(Token::r_square, "expected `]` after constraint list");
1667 }
1668 
/// Parse a single constraint reference. The supported forms, as dispatched by
/// the switch below, are:
///   * `Attr`, `Value`, `ValueRange`, each optionally followed by an inline
///     `<...>` type constraint (rejected unless `allowInlineTypeConstraints`),
///   * `Op` with an optional wrapped operation name,
///   * `Type` and `TypeRange`,
///   * an inline `Constraint` declaration,
///   * an identifier naming a previously defined ConstraintDecl.
/// `typeConstraint` holds the location of the inline type constraint seen so
/// far; a second one is diagnosed. `existingConstraints` is used only to infer
/// a type for code completion.
/// NOTE(review): this listing is missing the return-type line (1669,
/// presumably `FailureOr<ast::ConstraintRef>`) and two interior lines (see
/// the notes inside the body).
1670 Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
1671  ArrayRef<ast::ConstraintRef> existingConstraints,
1672  bool allowInlineTypeConstraints) {
// Parse an inline `<...>` type constraint into `typeExpr`. This checks that
// inline type constraints are allowed here and that none was seen before.
1673  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1674  if (!allowInlineTypeConstraints) {
1675  return emitError(
1676  curToken.getLoc(),
1677  "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1678  "permitted on arguments or results");
1679  }
1680  if (typeConstraint)
1681  return emitErrorAndNote(
1682  curToken.getLoc(),
1683  "the type of this variable has already been constrained",
1684  *typeConstraint, "see previous constraint location here");
1685  FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1686  if (failed(constraintExpr))
1687  return failure();
1688  typeExpr = *constraintExpr;
// Remember where the type was constrained for later duplicate diagnostics.
1689  typeConstraint = typeExpr->getLoc();
1690  return success();
1691  };
1692 
1693  SMRange loc = curToken.getLoc();
1694  switch (curToken.getKind()) {
1695  case Token::kw_Attr: {
1696  consumeToken(Token::kw_Attr);
1697 
1698  // Check for a type constraint.
1699  ast::Expr *typeExpr = nullptr;
1700  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1701  return failure();
1702  return ast::ConstraintRef(
1703  ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1704  }
1705  case Token::kw_Op: {
1706  consumeToken(Token::kw_Op);
1707 
1708  // Parse an optional operation name. If the name isn't provided, this refers
1709  // to "any" operation.
// NOTE(review): the line declaring `opName` (1710, presumably
// `FailureOr<ast::OpNameDecl *> opName =`) is missing from this listing.
1711  parseWrappedOperationName(/*allowEmptyName=*/true);
1712  if (failed(opName))
1713  return failure();
1714 
1715  return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
1716  loc);
1717  }
1718  case Token::kw_Type:
1719  consumeToken(Token::kw_Type);
1720  return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1721  case Token::kw_TypeRange:
1722  consumeToken(Token::kw_TypeRange);
// NOTE(review): the start of this return statement (1723, presumably
// `return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),`)
// is missing from this listing.
1724  loc);
1725  case Token::kw_Value: {
1726  consumeToken(Token::kw_Value);
1727 
1728  // Check for a type constraint.
1729  ast::Expr *typeExpr = nullptr;
1730  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1731  return failure();
1732 
1733  return ast::ConstraintRef(
1734  ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1735  }
1736  case Token::kw_ValueRange: {
1737  consumeToken(Token::kw_ValueRange);
1738 
1739  // Check for a type constraint.
1740  ast::Expr *typeExpr = nullptr;
1741  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1742  return failure();
1743 
1744  return ast::ConstraintRef(
1745  ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1746  }
1747 
1748  case Token::kw_Constraint: {
1749  // Handle an inline constraint.
1750  FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1751  if (failed(decl))
1752  return failure();
1753  return ast::ConstraintRef(*decl, loc);
1754  }
1755  case Token::identifier: {
1756  StringRef constraintName = curToken.getSpelling();
1757  consumeToken(Token::identifier);
1758 
1759  // Lookup the referenced constraint.
1760  ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
1761  if (!cstDecl) {
1762  return emitError(loc, "unknown reference to constraint `" +
1763  constraintName + "`");
1764  }
1765 
1766  // Handle a reference to a proper constraint.
1767  if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1768  return ast::ConstraintRef(cst, loc);
1769 
// The name resolved to something that is not a ConstraintDecl.
1770  return emitErrorAndNote(
1771  loc, "invalid reference to non-constraint", cstDecl->getLoc(),
1772  "see the definition of `" + constraintName + "` here");
1773  }
1774  // Handle single entity constraint code completion.
1775  case Token::code_complete: {
1776  // Try to infer the current type for use by code completion.
1777  ast::Type inferredType;
1778  if (failed(validateVariableConstraints(existingConstraints, inferredType)))
1779  return failure();
1780 
1781  return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
1782  }
1783  default:
1784  break;
1785  }
1786  return emitError(loc, "expected identifier constraint");
1787 }
1788 
1789 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1790  std::optional<SMRange> typeConstraint;
1791  return parseConstraint(typeConstraint, /*existingConstraints=*/std::nullopt,
1792  /*allowInlineTypeConstraints=*/false);
1793 }
1794 
1795 //===----------------------------------------------------------------------===//
1796 // Exprs
1797 
1798 FailureOr<ast::Expr *> Parser::parseExpr() {
1799  if (curToken.is(Token::underscore))
1800  return parseUnderscoreExpr();
1801 
1802  // Parse the LHS expression.
1803  FailureOr<ast::Expr *> lhsExpr;
1804  switch (curToken.getKind()) {
1805  case Token::kw_attr:
1806  lhsExpr = parseAttributeExpr();
1807  break;
1808  case Token::kw_Constraint:
1809  lhsExpr = parseInlineConstraintLambdaExpr();
1810  break;
1811  case Token::kw_not:
1812  lhsExpr = parseNegatedExpr();
1813  break;
1814  case Token::identifier:
1815  lhsExpr = parseIdentifierExpr();
1816  break;
1817  case Token::kw_op:
1818  lhsExpr = parseOperationExpr();
1819  break;
1820  case Token::kw_Rewrite:
1821  lhsExpr = parseInlineRewriteLambdaExpr();
1822  break;
1823  case Token::kw_type:
1824  lhsExpr = parseTypeExpr();
1825  break;
1826  case Token::l_paren:
1827  lhsExpr = parseTupleExpr();
1828  break;
1829  default:
1830  return emitError("expected expression");
1831  }
1832  if (failed(lhsExpr))
1833  return failure();
1834 
1835  // Check for an operator expression.
1836  while (true) {
1837  switch (curToken.getKind()) {
1838  case Token::dot:
1839  lhsExpr = parseMemberAccessExpr(*lhsExpr);
1840  break;
1841  case Token::l_paren:
1842  lhsExpr = parseCallExpr(*lhsExpr);
1843  break;
1844  default:
1845  return lhsExpr;
1846  }
1847  if (failed(lhsExpr))
1848  return failure();
1849  }
1850 }
1851 
1852 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1853  SMRange loc = curToken.getLoc();
1854  consumeToken(Token::kw_attr);
1855 
1856  // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1857  // identifier.
1858  if (!consumeIf(Token::less)) {
1859  resetToken(loc);
1860  return parseIdentifierExpr();
1861  }
1862 
1863  if (!curToken.isString())
1864  return emitError("expected string literal containing MLIR attribute");
1865  std::string attrExpr = curToken.getStringValue();
1866  consumeToken();
1867 
1868  loc.End = curToken.getEndLoc();
1869  if (failed(
1870  parseToken(Token::greater, "expected `>` after attribute literal")))
1871  return failure();
1872  return ast::AttributeExpr::create(ctx, loc, attrExpr);
1873 }
1874 
1875 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr,
1876  bool isNegated) {
1877  consumeToken(Token::l_paren);
1878 
1879  // Parse the arguments of the call.
1880  SmallVector<ast::Expr *> arguments;
1881  if (curToken.isNot(Token::r_paren)) {
1882  do {
1883  // Handle code completion for the call arguments.
1884  if (curToken.is(Token::code_complete)) {
1885  codeCompleteCallSignature(parentExpr, arguments.size());
1886  return failure();
1887  }
1888 
1889  FailureOr<ast::Expr *> argument = parseExpr();
1890  if (failed(argument))
1891  return failure();
1892  arguments.push_back(*argument);
1893  } while (consumeIf(Token::comma));
1894  }
1895 
1896  SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1897  if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1898  return failure();
1899 
1900  return createCallExpr(loc, parentExpr, arguments, isNegated);
1901 }
1902 
1903 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1904  ast::Decl *decl = curDeclScope->lookup(name);
1905  if (!decl)
1906  return emitError(loc, "undefined reference to `" + name + "`");
1907 
1908  return createDeclRefExpr(loc, decl);
1909 }
1910 
1911 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1912  StringRef name = curToken.getSpelling();
1913  SMRange nameLoc = curToken.getLoc();
1914  consumeToken();
1915 
1916  // Check to see if this is a decl ref expression that defines a variable
1917  // inline.
1918  if (consumeIf(Token::colon)) {
1919  SmallVector<ast::ConstraintRef> constraints;
1920  if (failed(parseVariableDeclConstraintList(constraints)))
1921  return failure();
1922  ast::Type type;
1923  if (failed(validateVariableConstraints(constraints, type)))
1924  return failure();
1925  return createInlineVariableExpr(type, name, nameLoc, constraints);
1926  }
1927 
1928  return parseDeclRefExpr(name, nameLoc);
1929 }
1930 
/// Parse an inline `Constraint` lambda and wrap the resulting declaration in
/// a DeclRefExpr located at the declaration.
/// NOTE(review): the final argument line (1937, presumably
/// `ast::ConstraintType::get(ctx));`, mirroring parseInlineRewriteLambdaExpr)
/// is missing from this listing.
1931 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1932  FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1933  if (failed(decl))
1934  return failure();
1935 
1936  return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1938 }
1939 
1940 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1941  FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1942  if (failed(decl))
1943  return failure();
1944 
1945  return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1946  ast::RewriteType::get(ctx));
1947 }
1948 
1949 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1950  SMRange dotLoc = curToken.getLoc();
1951  consumeToken(Token::dot);
1952 
1953  // Check for code completion of the member name.
1954  if (curToken.is(Token::code_complete))
1955  return codeCompleteMemberAccess(parentExpr);
1956 
1957  // Parse the member name.
1958  Token memberNameTok = curToken;
1959  if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1960  !memberNameTok.isKeyword())
1961  return emitError(dotLoc, "expected identifier or numeric member name");
1962  StringRef memberName = memberNameTok.getSpelling();
1963  SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1964  consumeToken();
1965 
1966  return createMemberAccessExpr(parentExpr, memberName, loc);
1967 }
1968 
1969 FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
1970  consumeToken(Token::kw_not);
1971  // Only native constraints are supported after negation
1972  if (!curToken.is(Token::identifier))
1973  return emitError("expected native constraint");
1974  FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
1975  if (failed(identifierExpr))
1976  return failure();
1977  return parseCallExpr(*identifierExpr, /*isNegated = */ true);
1978 }
1979 
1980 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
1981  SMRange loc = curToken.getLoc();
1982 
1983  // Check for code completion for the dialect name.
1984  if (curToken.is(Token::code_complete))
1985  return codeCompleteDialectName();
1986 
1987  // Handle the case of an no operation name.
1988  if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
1989  if (allowEmptyName)
1990  return ast::OpNameDecl::create(ctx, SMRange());
1991  return emitError("expected dialect namespace");
1992  }
1993  StringRef name = curToken.getSpelling();
1994  consumeToken();
1995 
1996  // Otherwise, this is a literal operation name.
1997  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
1998  return failure();
1999 
2000  // Check for code completion for the operation name.
2001  if (curToken.is(Token::code_complete))
2002  return codeCompleteOperationName(name);
2003 
2004  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
2005  return emitError("expected operation name after dialect namespace");
2006 
2007  name = StringRef(name.data(), name.size() + 1);
2008  do {
2009  name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
2010  loc.End = curToken.getEndLoc();
2011  consumeToken();
2012  } while (curToken.isAny(Token::identifier, Token::dot) ||
2013  curToken.isKeyword());
2014  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
2015 }
2016 
/// Parse an optional operation name wrapped in angle brackets: `<dialect.op>`.
/// Without a leading `<`, an OpNameDecl with no name is returned whatever the
/// value of `allowEmptyName`. Otherwise the name is parsed with
/// parseOperationName(allowEmptyName), and a closing `>` is required.
/// NOTE(review): the return-type line (2017, presumably
/// `FailureOr<ast::OpNameDecl *>`) is missing from this listing.
2018 Parser::parseWrappedOperationName(bool allowEmptyName) {
2019  if (!consumeIf(Token::less))
2020  return ast::OpNameDecl::create(ctx, SMRange());
2021 
2022  FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
2023  if (failed(opNameDecl))
2024  return failure();
2025 
2026  if (failed(parseToken(Token::greater, "expected `>` after operation name")))
2027  return failure();
2028  return opNameDecl;
2029 }
2030 
/// Parse an operation expression:
///   op<dialect.name>(operands) {attributes} -> (result types)
/// The operand list, the attribute list and the result list are each optional.
/// Outside a rewrite context:
///   * the operation name may be elided;
///   * an omitted operand or result list becomes an implicit, unconstrained
///     `_` range variable, so it is not confused with "zero operands/results".
/// Result type context:
///   * `inputResultTypeContext` is the starting value;
///   * an explicit `-> (...)` list switches it to `Explicit`;
///   * in a rewrite context, an omitted list with an `Explicit` context
///     switches it to `Interface`.
/// If `op` is not followed by `<`, it is reparsed as an identifier expression.
/// NOTE(review): several lines are missing from this listing: the return type
/// (2031), the `rangeVar` declaration (2056), the `attributes` vector (2094)
/// and the `decl` declaration (2097).
2032 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
2033  SMRange loc = curToken.getLoc();
2034  consumeToken(Token::kw_op);
2035 
2036  // If it isn't followed by a `<`, the `op` keyword is treated as a normal
2037  // identifier.
2038  if (curToken.isNot(Token::less)) {
2039  resetToken(loc);
2040  return parseIdentifierExpr();
2041  }
2042 
2043  // Parse the operation name. The name may be elided, in which case the
2044  // operation refers to "any" operation (i.e. a difference between `MyOp` and
2045  // `Operation*`). Operation names within a rewrite context must be named.
2046  bool allowEmptyName = parserContext != ParserContext::Rewrite;
2047  FailureOr<ast::OpNameDecl *> opNameDecl =
2048  parseWrappedOperationName(allowEmptyName);
2049  if (failed(opNameDecl))
2050  return failure();
2051  std::optional<StringRef> opName = (*opNameDecl)->getName();
2052 
2053  // Functor used to create an implicit range variable, used for implicit "all"
2054  // operand or results variables.
2055  auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
// `_` variables are not added to the scope (see defineVariableDecl), so this
// cannot clash with an existing name.
2057  defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
2058  assert(succeeded(rangeVar) && "expected range variable to be valid");
2059  return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
2060  };
2061 
2062  // Check for the optional list of operands.
2063  SmallVector<ast::Expr *> operands;
2064  if (!consumeIf(Token::l_paren)) {
2065  // If the operand list isn't specified and we are in a match context, define
2066  // an inplace unconstrained operand range corresponding to all of the
2067  // operands of the operation. This avoids treating zero operands the same
2068  // way as "unconstrained operands".
2069  if (parserContext != ParserContext::Rewrite) {
2070  operands.push_back(createImplicitRangeVar(
2071  ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
2072  }
2073  } else if (!consumeIf(Token::r_paren)) {
2074  // If the operand list was specified and non-empty, parse the operands.
2075  do {
2076  // Check for operand signature code completion.
2077  if (curToken.is(Token::code_complete)) {
2078  codeCompleteOperationOperandsSignature(opName, operands.size());
2079  return failure();
2080  }
2081 
2082  FailureOr<ast::Expr *> operand = parseExpr();
2083  if (failed(operand))
2084  return failure();
2085  operands.push_back(*operand);
2086  } while (consumeIf(Token::comma));
2087 
2088  if (failed(parseToken(Token::r_paren,
2089  "expected `)` after operation operand list")))
2090  return failure();
2091  }
2092 
2093  // Check for the optional list of attributes.
2095  if (consumeIf(Token::l_brace)) {
2096  do {
2098  parseNamedAttributeDecl(opName);
2099  if (failed(decl))
2100  return failure();
2101  attributes.emplace_back(*decl);
2102  } while (consumeIf(Token::comma));
2103 
2104  if (failed(parseToken(Token::r_brace,
2105  "expected `}` after operation attribute list")))
2106  return failure();
2107  }
2108 
2109  // Handle the result types of the operation.
2110  SmallVector<ast::Expr *> resultTypes;
2111  OpResultTypeContext resultTypeContext = inputResultTypeContext;
2112 
2113  // Check for an explicit list of result types.
2114  if (consumeIf(Token::arrow)) {
2115  if (failed(parseToken(Token::l_paren,
2116  "expected `(` before operation result type list")))
2117  return failure();
2118 
2119  // If result types are provided, initially assume that the operation does
2120  // not rely on type inference. We don't assert that it isn't, because we
2121  // may be inferring the value of some type/type range variables, but given
2122  // that these variables may be defined in calls we can't always discern when
2123  // this is the case.
2124  resultTypeContext = OpResultTypeContext::Explicit;
2125 
2126  // Handle the case of an empty result list.
2127  if (!consumeIf(Token::r_paren)) {
2128  do {
2129  // Check for result signature code completion.
2130  if (curToken.is(Token::code_complete)) {
2131  codeCompleteOperationResultsSignature(opName, resultTypes.size());
2132  return failure();
2133  }
2134 
2135  FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2136  if (failed(resultTypeExpr))
2137  return failure();
2138  resultTypes.push_back(*resultTypeExpr);
2139  } while (consumeIf(Token::comma));
2140 
2141  if (failed(parseToken(Token::r_paren,
2142  "expected `)` after operation result type list")))
2143  return failure();
2144  }
2145  } else if (parserContext != ParserContext::Rewrite) {
2146  // If the result list isn't specified and we are in a match context, define
2147  // an inplace unconstrained result range corresponding to all of the results
2148  // of the operation. This avoids treating zero results the same way as
2149  // "unconstrained results".
2150  resultTypes.push_back(createImplicitRangeVar(
2151  ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
2152  } else if (resultTypeContext == OpResultTypeContext::Explicit) {
2153  // If the result list isn't specified and we are in a rewrite, try to infer
2154  // them at runtime instead.
2155  resultTypeContext = OpResultTypeContext::Interface;
2156  }
2157 
2158  return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2159  attributes, resultTypes);
2160 }
2161 
2162 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2163  SMRange loc = curToken.getLoc();
2164  consumeToken(Token::l_paren);
2165 
2166  DenseMap<StringRef, SMRange> usedNames;
2167  SmallVector<StringRef> elementNames;
2168  SmallVector<ast::Expr *> elements;
2169  if (curToken.isNot(Token::r_paren)) {
2170  do {
2171  // Check for the optional element name assignment before the value.
2172  StringRef elementName;
2173  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
2174  Token elementNameTok = curToken;
2175  consumeToken();
2176 
2177  // The element name is only present if followed by an `=`.
2178  if (consumeIf(Token::equal)) {
2179  elementName = elementNameTok.getSpelling();
2180 
2181  // Check to see if this name is already used.
2182  auto elementNameIt =
2183  usedNames.try_emplace(elementName, elementNameTok.getLoc());
2184  if (!elementNameIt.second) {
2185  return emitErrorAndNote(
2186  elementNameTok.getLoc(),
2187  llvm::formatv("duplicate tuple element label `{0}`",
2188  elementName),
2189  elementNameIt.first->getSecond(),
2190  "see previous label use here");
2191  }
2192  } else {
2193  // Otherwise, we treat this as part of an expression so reset the
2194  // lexer.
2195  resetToken(elementNameTok.getLoc());
2196  }
2197  }
2198  elementNames.push_back(elementName);
2199 
2200  // Parse the tuple element value.
2201  FailureOr<ast::Expr *> element = parseExpr();
2202  if (failed(element))
2203  return failure();
2204  elements.push_back(*element);
2205  } while (consumeIf(Token::comma));
2206  }
2207  loc.End = curToken.getEndLoc();
2208  if (failed(
2209  parseToken(Token::r_paren, "expected `)` after tuple element list")))
2210  return failure();
2211  return createTupleExpr(loc, elements, elementNames);
2212 }
2213 
2214 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2215  SMRange loc = curToken.getLoc();
2216  consumeToken(Token::kw_type);
2217 
2218  // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2219  // identifier.
2220  if (!consumeIf(Token::less)) {
2221  resetToken(loc);
2222  return parseIdentifierExpr();
2223  }
2224 
2225  if (!curToken.isString())
2226  return emitError("expected string literal containing MLIR type");
2227  std::string attrExpr = curToken.getStringValue();
2228  consumeToken();
2229 
2230  loc.End = curToken.getEndLoc();
2231  if (failed(parseToken(Token::greater, "expected `>` after type literal")))
2232  return failure();
2233  return ast::TypeExpr::create(ctx, loc, attrExpr);
2234 }
2235 
2236 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2237  StringRef name = curToken.getSpelling();
2238  SMRange nameLoc = curToken.getLoc();
2239  consumeToken(Token::underscore);
2240 
2241  // Underscore expressions require a constraint list.
2242  if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2243  return failure();
2244 
2245  // Parse the constraints for the expression.
2246  SmallVector<ast::ConstraintRef> constraints;
2247  if (failed(parseVariableDeclConstraintList(constraints)))
2248  return failure();
2249 
2250  ast::Type type;
2251  if (failed(validateVariableConstraints(constraints, type)))
2252  return failure();
2253  return createInlineVariableExpr(type, name, nameLoc, constraints);
2254 }
2255 
2256 //===----------------------------------------------------------------------===//
2257 // Stmts
2258 
/// Parse a single statement: `erase`, `let`, `replace`, `return`, `rewrite`,
/// or, for any other leading token, an expression statement. When
/// `expectTerminalSemicolon` is set, a trailing `;` is also required.
/// NOTE(review): the declaration of `stmt` (2260, presumably
/// `FailureOr<ast::Stmt *> stmt;`) is missing from this listing.
2259 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2261  switch (curToken.getKind()) {
2262  case Token::kw_erase:
2263  stmt = parseEraseStmt();
2264  break;
2265  case Token::kw_let:
2266  stmt = parseLetStmt();
2267  break;
2268  case Token::kw_replace:
2269  stmt = parseReplaceStmt();
2270  break;
2271  case Token::kw_return:
2272  stmt = parseReturnStmt();
2273  break;
2274  case Token::kw_rewrite:
2275  stmt = parseRewriteStmt();
2276  break;
2277  default:
2278  stmt = parseExpr();
2279  break;
2280  }
// The `;` is only checked if the statement itself parsed successfully.
2281  if (failed(stmt) ||
2282  (expectTerminalSemicolon &&
2283  failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2284  return failure();
2285  return stmt;
2286 }
2287 
2288 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2289  SMLoc startLoc = curToken.getStartLoc();
2290  consumeToken(Token::l_brace);
2291 
2292  // Push a new block scope and parse any nested statements.
2293  pushDeclScope();
2294  SmallVector<ast::Stmt *> statements;
2295  while (curToken.isNot(Token::r_brace)) {
2296  FailureOr<ast::Stmt *> statement = parseStmt();
2297  if (failed(statement))
2298  return popDeclScope(), failure();
2299  statements.push_back(*statement);
2300  }
2301  popDeclScope();
2302 
2303  // Consume the end brace.
2304  SMRange location(startLoc, curToken.getEndLoc());
2305  consumeToken(Token::r_brace);
2306 
2307  return ast::CompoundStmt::create(ctx, location, statements);
2308 }
2309 
2310 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2311  if (parserContext == ParserContext::Constraint)
2312  return emitError("`erase` cannot be used within a Constraint");
2313  SMRange loc = curToken.getLoc();
2314  consumeToken(Token::kw_erase);
2315 
2316  // Parse the root operation expression.
2317  FailureOr<ast::Expr *> rootOp = parseExpr();
2318  if (failed(rootOp))
2319  return failure();
2320 
2321  return createEraseStmt(loc, *rootOp);
2322 }
2323 
/// Parse `let name (: constraints)? (= initializer)?`.
/// The name must be an identifier or a dependent keyword; `_` is rejected
/// with a dedicated error. When an initializer is present, any constraint
/// matched by the TypeSwitch below that carries an inline type expression
/// (`getTypeExpr()`) is rejected. The variable is then built by
/// createVariableDecl.
/// NOTE(review): two lines are missing from this listing: the first `.Case<`
/// line of the TypeSwitch (2361) and the `varDecl` declaration (2377,
/// presumably `FailureOr<ast::VariableDecl *> varDecl =`).
2324 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2325  SMRange loc = curToken.getLoc();
2326  consumeToken(Token::kw_let);
2327 
2328  // Parse the name of the new variable.
2329  SMRange varLoc = curToken.getLoc();
2330  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
2331  // `_` is a reserved variable name.
2332  if (curToken.is(Token::underscore)) {
2333  return emitError(varLoc,
2334  "`_` may only be used to define \"inline\" variables");
2335  }
2336  return emitError(varLoc,
2337  "expected identifier after `let` to name a new variable");
2338  }
2339  StringRef varName = curToken.getSpelling();
2340  consumeToken();
2341 
2342  // Parse the optional set of constraints.
2343  SmallVector<ast::ConstraintRef> constraints;
2344  if (consumeIf(Token::colon) &&
2345  failed(parseVariableDeclConstraintList(constraints)))
2346  return failure();
2347 
2348  // Parse the optional initializer expression.
2349  ast::Expr *initializer = nullptr;
2350  if (consumeIf(Token::equal)) {
2351  FailureOr<ast::Expr *> initOrFailure = parseExpr();
2352  if (failed(initOrFailure))
2353  return failure();
2354  initializer = *initOrFailure;
2355 
2356  // Check that the constraints are compatible with having an initializer,
2357  // e.g. type constraints cannot be used with initializers.
2358  for (ast::ConstraintRef constraint : constraints) {
2359  LogicalResult result =
2360  TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
// NOTE(review): the opening `.Case<...,` line (2361) is missing here.
2362  ast::ValueRangeConstraintDecl>([&](const auto *cst) {
2363  if (cst->getTypeExpr()) {
2364  return this->emitError(
2365  constraint.referenceLoc,
2366  "type constraints are not permitted on variables with "
2367  "initializers");
2368  }
2369  return success();
2370  })
2371  .Default(success());
2372  if (failed(result))
2373  return failure();
2374  }
2375  }
2376 
2378  createVariableDecl(varName, varLoc, initializer, constraints);
2379  if (failed(varDecl))
2380  return failure();
2381  return ast::LetStmt::create(ctx, loc, *varDecl);
2382 }
2383 
2384 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2385  if (parserContext == ParserContext::Constraint)
2386  return emitError("`replace` cannot be used within a Constraint");
2387  SMRange loc = curToken.getLoc();
2388  consumeToken(Token::kw_replace);
2389 
2390  // Parse the root operation expression.
2391  FailureOr<ast::Expr *> rootOp = parseExpr();
2392  if (failed(rootOp))
2393  return failure();
2394 
2395  if (failed(
2396  parseToken(Token::kw_with, "expected `with` after root operation")))
2397  return failure();
2398 
2399  // The replacement portion of this statement is within a rewrite context.
2400  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2401 
2402  // Parse the replacement values.
2403  SmallVector<ast::Expr *> replValues;
2404  if (consumeIf(Token::l_paren)) {
2405  if (consumeIf(Token::r_paren)) {
2406  return emitError(
2407  loc, "expected at least one replacement value, consider using "
2408  "`erase` if no replacement values are desired");
2409  }
2410 
2411  do {
2412  FailureOr<ast::Expr *> replExpr = parseExpr();
2413  if (failed(replExpr))
2414  return failure();
2415  replValues.emplace_back(*replExpr);
2416  } while (consumeIf(Token::comma));
2417 
2418  if (failed(parseToken(Token::r_paren,
2419  "expected `)` after replacement values")))
2420  return failure();
2421  } else {
2422  // Handle replacement with an operation uniquely, as the replacement
2423  // operation supports type inferrence from the root operation.
2424  FailureOr<ast::Expr *> replExpr;
2425  if (curToken.is(Token::kw_op))
2426  replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2427  else
2428  replExpr = parseExpr();
2429  if (failed(replExpr))
2430  return failure();
2431  replValues.emplace_back(*replExpr);
2432  }
2433 
2434  return createReplaceStmt(loc, *rootOp, replValues);
2435 }
2436 
2437 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2438  SMRange loc = curToken.getLoc();
2439  consumeToken(Token::kw_return);
2440 
2441  // Parse the result value.
2442  FailureOr<ast::Expr *> resultExpr = parseExpr();
2443  if (failed(resultExpr))
2444  return failure();
2445 
2446  return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2447 }
2448 
2449 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2450  if (parserContext == ParserContext::Constraint)
2451  return emitError("`rewrite` cannot be used within a Constraint");
2452  SMRange loc = curToken.getLoc();
2453  consumeToken(Token::kw_rewrite);
2454 
2455  // Parse the root operation.
2456  FailureOr<ast::Expr *> rootOp = parseExpr();
2457  if (failed(rootOp))
2458  return failure();
2459 
2460  if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2461  return failure();
2462 
2463  if (curToken.isNot(Token::l_brace))
2464  return emitError("expected `{` to start rewrite body");
2465 
2466  // The rewrite body of this statement is within a rewrite context.
2467  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2468 
2469  FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2470  if (failed(rewriteBody))
2471  return failure();
2472 
2473  // Verify the rewrite body.
2474  for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2475  if (isa<ast::ReturnStmt>(stmt)) {
2476  return emitError(stmt->getLoc(),
2477  "`return` statements are only permitted within a "
2478  "`Constraint` or `Rewrite` body");
2479  }
2480  }
2481 
2482  return createRewriteStmt(loc, *rootOp, *rewriteBody);
2483 }
2484 
2485 //===----------------------------------------------------------------------===//
2486 // Creation+Analysis
2487 //===----------------------------------------------------------------------===//
2488 
2489 //===----------------------------------------------------------------------===//
2490 // Decls
2491 
2492 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2493  // Unwrap reference expressions.
2494  if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2495  node = init->getDecl();
2496  return dyn_cast<ast::CallableDecl>(node);
2497 }
2498 
/// Create a PatternDecl named `name` (which may be null) with the given body.
/// The benefit and bounded-recursion flag come from the parsed pattern
/// `metadata`.
/// NOTE(review): the return-type line (2499, presumably
/// `FailureOr<ast::PatternDecl *>`) is missing from this listing.
2500 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2501  const ParsedPatternMetadata &metadata,
2502  ast::CompoundStmt *body) {
2503  return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2504  metadata.hasBoundedRecursion, body);
2505 }
2506 
/// Compute the result type of a user Constraint/Rewrite from its result
/// variables. A single result yields that result's type. Any other count,
/// including zero, yields a TupleType whose element types and names are taken
/// from the result variables.
/// NOTE(review): the parameter line (2508, presumably
/// `ArrayRef<ast::VariableDecl *> results) {`) is missing from this listing.
2507 ast::Type Parser::createUserConstraintRewriteResultType(
2509  // Single result decls use the type of the single result.
2510  if (results.size() == 1)
2511  return results[0]->getType();
2512 
2513  // Multiple results use a tuple type, with the types and names grabbed from
2514  // the result variable decls.
// With zero results this produces an empty tuple type.
2515  auto resultTypes = llvm::map_range(
2516  results, [&](const auto *result) { return result->getType(); });
2517  auto resultNames = llvm::map_range(
2518  results, [&](const auto *result) { return result->getName().getName(); });
2519  return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2520  llvm::to_vector(resultNames));
2521 }
2522 
2523 template <typename T>
2524 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2525  const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2526  ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2527  ast::CompoundStmt *body) {
2528  if (!body->getChildren().empty()) {
2529  if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2530  ast::Expr *resultExpr = retStmt->getResultExpr();
2531 
2532  // Process the result of the decl. If no explicit signature results
2533  // were provided, check for return type inference. Otherwise, check that
2534  // the return expression can be converted to the expected type.
2535  if (results.empty())
2536  resultType = resultExpr->getType();
2537  else if (failed(convertExpressionTo(resultExpr, resultType)))
2538  return failure();
2539  else
2540  retStmt->setResultExpr(resultExpr);
2541  }
2542  }
2543  return T::createPDLL(ctx, name, arguments, results, body, resultType);
2544 }
2545 
/// Create and define a variable `name` from an optional `initializer` and a
/// list of constraints.
/// The type first comes from validateVariableConstraints. If there is an
/// initializer, it then:
///   * provides the type when the constraints gave none, or
///   * refines the constrained type, or
///   * is converted to the constrained type.
/// Fails if no type can be inferred, if the type is a Constraint or Rewrite
/// type, or if defineVariableDecl fails (e.g. the name is already defined).
/// NOTE(review): the return-type line (2546) and the `varDecl` declaration
/// (2581, presumably `FailureOr<ast::VariableDecl *> varDecl =`) are missing
/// from this listing.
2547 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2548  ArrayRef<ast::ConstraintRef> constraints) {
2549  // The type of the variable, which is expected to be inferred by either a
2550  // constraint or an initializer expression.
2551  ast::Type type;
2552  if (failed(validateVariableConstraints(constraints, type)))
2553  return failure();
2554 
2555  if (initializer) {
2556  // Update the variable type based on the initializer, or try to convert the
2557  // initializer to the existing type.
2558  if (!type)
2559  type = initializer->getType();
2560  else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2561  type = mergedType;
2562  else if (failed(convertExpressionTo(initializer, type)))
2563  return failure();
2564 
2565  // Otherwise, if there is no initializer check that the type has already
2566  // been resolved from the constraint list.
2567  } else if (!type) {
2568  return emitErrorAndNote(
2569  loc, "unable to infer type for variable `" + name + "`", loc,
2570  "the type of a variable must be inferable from the constraint "
2571  "list or the initializer");
2572  }
2573 
2574  // Constraint and Rewrite types cannot be used when defining variables.
2575  if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
2576  return emitError(
2577  loc, llvm::formatv("unable to define variable of `{0}` type", type));
2578  }
2579 
2580  // Try to define a variable with the given name.
2582  defineVariableDecl(name, loc, type, initializer, constraints);
2583  if (failed(varDecl))
2584  return failure();
2585 
2586  return *varDecl;
2587 }
2588 
2590 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2591  const ast::ConstraintRef &constraint) {
2592  ast::Type argType;
2593  if (failed(validateVariableConstraint(constraint, argType)))
2594  return failure();
2595  return defineVariableDecl(name, loc, argType, constraint);
2596 }
2597 
2599 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2600  ast::Type &inferredType) {
2601  for (const ast::ConstraintRef &ref : constraints)
2602  if (failed(validateVariableConstraint(ref, inferredType)))
2603  return failure();
2604  return success();
2605 }
2606 
2607 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2608  ast::Type &inferredType) {
2609  ast::Type constraintType;
2610  if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
2611  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2612  if (failed(validateTypeConstraintExpr(typeExpr)))
2613  return failure();
2614  }
2615  constraintType = ast::AttributeType::get(ctx);
2616  } else if (const auto *cst =
2617  dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
2618  constraintType = ast::OperationType::get(
2619  ctx, cst->getName(), lookupODSOperation(cst->getName()));
2620  } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
2621  constraintType = typeTy;
2622  } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
2623  constraintType = typeRangeTy;
2624  } else if (const auto *cst =
2625  dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
2626  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2627  if (failed(validateTypeConstraintExpr(typeExpr)))
2628  return failure();
2629  }
2630  constraintType = valueTy;
2631  } else if (const auto *cst =
2632  dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
2633  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2634  if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2635  return failure();
2636  }
2637  constraintType = valueRangeTy;
2638  } else if (const auto *cst =
2639  dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
2640  ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2641  if (inputs.size() != 1) {
2642  return emitErrorAndNote(ref.referenceLoc,
2643  "`Constraint`s applied via a variable constraint "
2644  "list must take a single input, but got " +
2645  Twine(inputs.size()),
2646  cst->getLoc(),
2647  "see definition of constraint here");
2648  }
2649  constraintType = inputs.front()->getType();
2650  } else {
2651  llvm_unreachable("unknown constraint type");
2652  }
2653 
2654  // Check that the constraint type is compatible with the current inferred
2655  // type.
2656  if (!inferredType) {
2657  inferredType = constraintType;
2658  } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
2659  inferredType = mergedTy;
2660  } else {
2661  return emitError(ref.referenceLoc,
2662  llvm::formatv("constraint type `{0}` is incompatible "
2663  "with the previously inferred type `{1}`",
2664  constraintType, inferredType));
2665  }
2666  return success();
2667 }
2668 
2669 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2670  ast::Type typeExprType = typeExpr->getType();
2671  if (typeExprType != typeTy) {
2672  return emitError(typeExpr->getLoc(),
2673  "expected expression of `Type` in type constraint");
2674  }
2675  return success();
2676 }
2677 
2679 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2680  ast::Type typeExprType = typeExpr->getType();
2681  if (typeExprType != typeRangeTy) {
2682  return emitError(typeExpr->getLoc(),
2683  "expected expression of `TypeRange` in type constraint");
2684  }
2685  return success();
2686 }
2687 
2688 //===----------------------------------------------------------------------===//
2689 // Exprs
2690 
2692 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2693  MutableArrayRef<ast::Expr *> arguments, bool isNegated) {
2694  ast::Type parentType = parentExpr->getType();
2695 
2696  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2697  if (!callableDecl) {
2698  return emitError(loc,
2699  llvm::formatv("expected a reference to a callable "
2700  "`Constraint` or `Rewrite`, but got: `{0}`",
2701  parentType));
2702  }
2703  if (parserContext == ParserContext::Rewrite) {
2704  if (isa<ast::UserConstraintDecl>(callableDecl))
2705  return emitError(
2706  loc, "unable to invoke `Constraint` within a rewrite section");
2707  if (isNegated)
2708  return emitError(loc, "unable to negate a Rewrite");
2709  } else {
2710  if (isa<ast::UserRewriteDecl>(callableDecl))
2711  return emitError(loc,
2712  "unable to invoke `Rewrite` within a match section");
2713  if (isNegated && cast<ast::UserConstraintDecl>(callableDecl)->getBody())
2714  return emitError(loc, "unable to negate non native constraints");
2715  }
2716 
2717  // Verify the arguments of the call.
2718  /// Handle size mismatch.
2719  ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2720  if (callArgs.size() != arguments.size()) {
2721  return emitErrorAndNote(
2722  loc,
2723  llvm::formatv("invalid number of arguments for {0} call; expected "
2724  "{1}, but got {2}",
2725  callableDecl->getCallableType(), callArgs.size(),
2726  arguments.size()),
2727  callableDecl->getLoc(),
2728  llvm::formatv("see the definition of {0} here",
2729  callableDecl->getName()->getName()));
2730  }
2731 
2732  /// Handle argument type mismatch.
2733  auto attachDiagFn = [&](ast::Diagnostic &diag) {
2734  diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2735  callableDecl->getName()->getName()),
2736  callableDecl->getLoc());
2737  };
2738  for (auto it : llvm::zip(callArgs, arguments)) {
2739  if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2740  attachDiagFn)))
2741  return failure();
2742  }
2743 
2744  return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2745  callableDecl->getResultType(), isNegated);
2746 }
2747 
2748 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2749  ast::Decl *decl) {
2750  // Check the type of decl being referenced.
2751  ast::Type declType;
2752  if (isa<ast::ConstraintDecl>(decl))
2753  declType = ast::ConstraintType::get(ctx);
2754  else if (isa<ast::UserRewriteDecl>(decl))
2755  declType = ast::RewriteType::get(ctx);
2756  else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2757  declType = varDecl->getType();
2758  else
2759  return emitError(loc, "invalid reference to `" +
2760  decl->getName()->getName() + "`");
2761 
2762  return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2763 }
2764 
2766 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2767  ArrayRef<ast::ConstraintRef> constraints) {
2769  defineVariableDecl(name, loc, type, constraints);
2770  if (failed(decl))
2771  return failure();
2772  return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2773 }
2774 
2776 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2777  SMRange loc) {
2778  // Validate the member name for the given parent expression.
2779  FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2780  if (failed(memberType))
2781  return failure();
2782 
2783  return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2784 }
2785 
2786 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2787  StringRef name, SMRange loc) {
2788  ast::Type parentType = parentExpr->getType();
2789  if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
2791  return valueRangeTy;
2792 
2793  // Verify member access based on the operation type.
2794  if (const ods::Operation *odsOp = opType.getODSOperation()) {
2795  auto results = odsOp->getResults();
2796 
2797  // Handle indexed results.
2798  unsigned index = 0;
2799  if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2800  index < results.size()) {
2801  return results[index].isVariadic() ? valueRangeTy : valueTy;
2802  }
2803 
2804  // Handle named results.
2805  const auto *it = llvm::find_if(results, [&](const auto &result) {
2806  return result.getName() == name;
2807  });
2808  if (it != results.end())
2809  return it->isVariadic() ? valueRangeTy : valueTy;
2810  } else if (llvm::isDigit(name[0])) {
2811  // Allow unchecked numeric indexing of the results of unregistered
2812  // operations. It returns a single value.
2813  return valueTy;
2814  }
2815  } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
2816  // Handle indexed results.
2817  unsigned index = 0;
2818  if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2819  index < tupleType.size()) {
2820  return tupleType.getElementTypes()[index];
2821  }
2822 
2823  // Handle named results.
2824  auto elementNames = tupleType.getElementNames();
2825  const auto *it = llvm::find(elementNames, name);
2826  if (it != elementNames.end())
2827  return tupleType.getElementTypes()[it - elementNames.begin()];
2828  }
2829  return emitError(
2830  loc,
2831  llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2832  name, parentType));
2833 }
2834 
2835 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2836  SMRange loc, const ast::OpNameDecl *name,
2837  OpResultTypeContext resultTypeContext,
2838  SmallVectorImpl<ast::Expr *> &operands,
2840  SmallVectorImpl<ast::Expr *> &results) {
2841  std::optional<StringRef> opNameRef = name->getName();
2842  const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2843 
2844  // Verify the inputs operands.
2845  if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2846  return failure();
2847 
2848  // Verify the attribute list.
2849  for (ast::NamedAttributeDecl *attr : attributes) {
2850  // Check for an attribute type, or a type awaiting resolution.
2851  ast::Type attrType = attr->getValue()->getType();
2852  if (!attrType.isa<ast::AttributeType>()) {
2853  return emitError(
2854  attr->getValue()->getLoc(),
2855  llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2856  }
2857  }
2858 
2859  assert(
2860  (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2861  "unexpected inferrence when results were explicitly specified");
2862 
2863  // If we aren't relying on type inferrence, or explicit results were provided,
2864  // validate them.
2865  if (resultTypeContext == OpResultTypeContext::Explicit) {
2866  if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2867  return failure();
2868 
2869  // Validate the use of interface based type inferrence for this operation.
2870  } else if (resultTypeContext == OpResultTypeContext::Interface) {
2871  assert(opNameRef &&
2872  "expected valid operation name when inferring operation results");
2873  checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2874  }
2875 
2876  return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
2877  attributes);
2878 }
2879 
2881 Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2882  const ods::Operation *odsOp,
2883  SmallVectorImpl<ast::Expr *> &operands) {
2884  return validateOperationOperandsOrResults(
2885  "operand", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2886  operands, odsOp ? odsOp->getOperands() : std::nullopt, valueTy,
2887  valueRangeTy);
2888 }
2889 
2891 Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2892  const ods::Operation *odsOp,
2893  SmallVectorImpl<ast::Expr *> &results) {
2894  return validateOperationOperandsOrResults(
2895  "result", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2896  results, odsOp ? odsOp->getResults() : std::nullopt, typeTy, typeRangeTy);
2897 }
2898 
2899 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2900  const ods::Operation *odsOp) {
2901  // If the operation might not have inferrence support, emit a warning to the
2902  // user. We don't emit an error because the interface might be added to the
2903  // operation at runtime. It's rare, but it could still happen. We emit a
2904  // warning here instead.
2905 
2906  // Handle inferrence warnings for unknown operations.
2907  if (!odsOp) {
2908  ctx.getDiagEngine().emitWarning(
2909  loc, llvm::formatv(
2910  "operation result types are marked to be inferred, but "
2911  "`{0}` is unknown. Ensure that `{0}` supports zero "
2912  "results or implements `InferTypeOpInterface`. Include "
2913  "the ODS definition of this operation to remove this warning.",
2914  opName));
2915  return;
2916  }
2917 
2918  // Handle inferrence warnings for known operations that expected at least one
2919  // result, but don't have inference support. An elided results list can mean
2920  // "zero-results", and we don't want to warn when that is the expected
2921  // behavior.
2922  bool requiresInferrence =
2923  llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
2924  return !result.isVariableLength();
2925  });
2926  if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2927  ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning(
2928  loc,
2929  llvm::formatv("operation result types are marked to be inferred, but "
2930  "`{0}` does not provide an implementation of "
2931  "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2932  "`InferTypeOpInterface` at runtime, or add support to "
2933  "the ODS definition to remove this warning.",
2934  opName));
2935  diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
2936  odsOp->getLoc());
2937  return;
2938  }
2939 }
2940 
2941 LogicalResult Parser::validateOperationOperandsOrResults(
2942  StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2943  std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
2944  ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2945  ast::RangeType rangeTy) {
2946  // All operation types accept a single range parameter.
2947  if (values.size() == 1) {
2948  if (failed(convertExpressionTo(values[0], rangeTy)))
2949  return failure();
2950  return success();
2951  }
2952 
2953  /// If the operation has ODS information, we can more accurately verify the
2954  /// values.
2955  if (odsOpLoc) {
2956  auto emitSizeMismatchError = [&] {
2957  return emitErrorAndNote(
2958  loc,
2959  llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2960  "{2}, but got {3}",
2961  groupName, *name, odsValues.size(), values.size()),
2962  *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
2963  };
2964 
2965  // Handle the case where no values were provided.
2966  if (values.empty()) {
2967  // If we don't expect any on the ODS side, we are done.
2968  if (odsValues.empty())
2969  return success();
2970 
2971  // If we do, check if we actually need to provide values (i.e. if any of
2972  // the values are actually required).
2973  unsigned numVariadic = 0;
2974  for (const auto &odsValue : odsValues) {
2975  if (!odsValue.isVariableLength())
2976  return emitSizeMismatchError();
2977  ++numVariadic;
2978  }
2979 
2980  // If we are in a non-rewrite context, we don't need to do anything more.
2981  // Zero-values is a valid constraint on the operation.
2982  if (parserContext != ParserContext::Rewrite)
2983  return success();
2984 
2985  // Otherwise, when in a rewrite we may need to provide values to match the
2986  // ODS signature of the operation to create.
2987 
2988  // If we only have one variadic value, just use an empty list.
2989  if (numVariadic == 1)
2990  return success();
2991 
2992  // Otherwise, create dummy values for each of the entries so that we
2993  // adhere to the ODS signature.
2994  for (unsigned i = 0, e = odsValues.size(); i < e; ++i) {
2995  values.push_back(ast::RangeExpr::create(
2996  ctx, loc, /*elements=*/std::nullopt, rangeTy));
2997  }
2998  return success();
2999  }
3000 
3001  // Verify that the number of values provided matches the number of value
3002  // groups ODS expects.
3003  if (odsValues.size() != values.size())
3004  return emitSizeMismatchError();
3005 
3006  auto diagFn = [&](ast::Diagnostic &diag) {
3007  diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
3008  *odsOpLoc);
3009  };
3010  for (unsigned i = 0, e = values.size(); i < e; ++i) {
3011  ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
3012  if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
3013  return failure();
3014  }
3015  return success();
3016  }
3017 
3018  // Otherwise, accept the value groups as they have been defined and just
3019  // ensure they are one of the expected types.
3020  for (ast::Expr *&valueExpr : values) {
3021  ast::Type valueExprType = valueExpr->getType();
3022 
3023  // Check if this is one of the expected types.
3024  if (valueExprType == rangeTy || valueExprType == singleTy)
3025  continue;
3026 
3027  // If the operand is an Operation, allow converting to a Value or
3028  // ValueRange. This situations arises quite often with nested operation
3029  // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
3030  if (singleTy == valueTy) {
3031  if (valueExprType.isa<ast::OperationType>()) {
3032  valueExpr = convertOpToValue(valueExpr);
3033  continue;
3034  }
3035  }
3036 
3037  // Otherwise, try to convert the expression to a range.
3038  if (succeeded(convertExpressionTo(valueExpr, rangeTy)))
3039  continue;
3040 
3041  return emitError(
3042  valueExpr->getLoc(),
3043  llvm::formatv(
3044  "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3045  singleTy, rangeTy, valueExprType));
3046  }
3047  return success();
3048 }
3049 
3051 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
3052  ArrayRef<StringRef> elementNames) {
3053  for (const ast::Expr *element : elements) {
3054  ast::Type eleTy = element->getType();
3056  return emitError(
3057  element->getLoc(),
3058  llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
3059  }
3060  }
3061  return ast::TupleExpr::create(ctx, loc, elements, elementNames);
3062 }
3063 
3064 //===----------------------------------------------------------------------===//
3065 // Stmts
3066 
3067 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
3068  ast::Expr *rootOp) {
3069  // Check that root is an Operation.
3070  ast::Type rootType = rootOp->getType();
3071  if (!rootType.isa<ast::OperationType>())
3072  return emitError(rootOp->getLoc(), "expected `Op` expression");
3073 
3074  return ast::EraseStmt::create(ctx, loc, rootOp);
3075 }
3076 
3078 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
3079  MutableArrayRef<ast::Expr *> replValues) {
3080  // Check that root is an Operation.
3081  ast::Type rootType = rootOp->getType();
3082  if (!rootType.isa<ast::OperationType>()) {
3083  return emitError(
3084  rootOp->getLoc(),
3085  llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3086  }
3087 
3088  // If there are multiple replacement values, we implicitly convert any Op
3089  // expressions to the value form.
3090  bool shouldConvertOpToValues = replValues.size() > 1;
3091  for (ast::Expr *&replExpr : replValues) {
3092  ast::Type replType = replExpr->getType();
3093 
3094  // Check that replExpr is an Operation, Value, or ValueRange.
3095  if (replType.isa<ast::OperationType>()) {
3096  if (shouldConvertOpToValues)
3097  replExpr = convertOpToValue(replExpr);
3098  continue;
3099  }
3100 
3101  if (replType != valueTy && replType != valueRangeTy) {
3102  return emitError(replExpr->getLoc(),
3103  llvm::formatv("expected `Op`, `Value` or `ValueRange` "
3104  "expression, but got `{0}`",
3105  replType));
3106  }
3107  }
3108 
3109  return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
3110 }
3111 
3113 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
3114  ast::CompoundStmt *rewriteBody) {
3115  // Check that root is an Operation.
3116  ast::Type rootType = rootOp->getType();
3117  if (!rootType.isa<ast::OperationType>()) {
3118  return emitError(
3119  rootOp->getLoc(),
3120  llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3121  }
3122 
3123  return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
3124 }
3125 
3126 //===----------------------------------------------------------------------===//
3127 // Code Completion
3128 //===----------------------------------------------------------------------===//
3129 
3130 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3131  ast::Type parentType = parentExpr->getType();
3132  if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>())
3133  codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3134  else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>())
3135  codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3136  return failure();
3137 }
3138 
3140 Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3141  if (opName)
3142  codeCompleteContext->codeCompleteOperationAttributeName(*opName);
3143  return failure();
3144 }
3145 
3147 Parser::codeCompleteConstraintName(ast::Type inferredType,
3148  bool allowInlineTypeConstraints) {
3149  codeCompleteContext->codeCompleteConstraintName(
3150  inferredType, allowInlineTypeConstraints, curDeclScope);
3151  return failure();
3152 }
3153 
3154 LogicalResult Parser::codeCompleteDialectName() {
3155  codeCompleteContext->codeCompleteDialectName();
3156  return failure();
3157 }
3158 
3159 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3160  codeCompleteContext->codeCompleteOperationName(dialectName);
3161  return failure();
3162 }
3163 
3164 LogicalResult Parser::codeCompletePatternMetadata() {
3165  codeCompleteContext->codeCompletePatternMetadata();
3166  return failure();
3167 }
3168 
3169 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3170  codeCompleteContext->codeCompleteIncludeFilename(curPath);
3171  return failure();
3172 }
3173 
3174 void Parser::codeCompleteCallSignature(ast::Node *parent,
3175  unsigned currentNumArgs) {
3176  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent);
3177  if (!callableDecl)
3178  return;
3179 
3180  codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
3181 }
3182 
3183 void Parser::codeCompleteOperationOperandsSignature(
3184  std::optional<StringRef> opName, unsigned currentNumOperands) {
3185  codeCompleteContext->codeCompleteOperationOperandsSignature(
3186  opName, currentNumOperands);
3187 }
3188 
3189 void Parser::codeCompleteOperationResultsSignature(
3190  std::optional<StringRef> opName, unsigned currentNumResults) {
3191  codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3192  currentNumResults);
3193 }
3194 
3195 //===----------------------------------------------------------------------===//
3196 // Parser
3197 //===----------------------------------------------------------------------===//
3198 
3200 mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3201  bool enableDocumentation,
3202  CodeCompleteContext *codeCompleteContext) {
3203  Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3204  return parser.parseModule();
3205 }
static std::string diag(const llvm::Value &value)
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class breaks up the current file into a token stream.
Definition: Lexer.h:23
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
This represents a token in the MLIR syntax.
Definition: Token.h:20
bool isKeyword() const
Return true if this is one of the keyword token kinds (e.g. kw_if).
Definition: Token.cpp:192
SMLoc getLoc() const
Definition: Token.cpp:24
bool isNot(Kind k) const
Definition: Token.h:50
StringRef getSpelling() const
Definition: Token.h:34
This class provides an abstract interface into the parser for hooking in code completion events.
Definition: CodeComplete.h:30
@ code_complete_string
Token signifying a code completion location within a string.
Definition: Lexer.h:43
@ directive
Tokens.
Definition: Lexer.h:95
@ eof
Markers.
Definition: Lexer.h:38
@ code_complete
Token signifying a code completion location.
Definition: Lexer.h:41
@ arrow
Punctuation.
Definition: Lexer.h:76
@ less
Paired punctuation.
Definition: Lexer.h:84
@ kw_Attr
General keywords.
Definition: Lexer.h:56
static StringRef getMemberName()
Return the member name used for the "all-results" access.
Definition: Nodes.h:487
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
Definition: Nodes.h:489
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
Definition: Nodes.h:749
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:390
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:258
This class represents a PDLL type that corresponds to an mlir::Attribute.
Definition: Types.h:131
static AttributeType get(Context &context)
Return an instance of the Attribute type.
Definition: Types.cpp:57
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr * > arguments, Type resultType, bool isNegated=false)
Definition: Nodes.cpp:268
This decl represents a shared interface for all callable decls.
Definition: Nodes.h:1188
Type getResultType() const
Return the result type of this decl.
Definition: Nodes.h:1206
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
Definition: Nodes.h:1199
StringRef getCallableType() const
Return the callable type of this decl.
Definition: Nodes.h:1191
This statement represents a compound statement, which contains a collection of other statements.
Definition: Nodes.h:179
ArrayRef< Stmt * >::iterator begin() const
Definition: Nodes.h:191
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
Definition: Nodes.h:185
ArrayRef< Stmt * >::iterator end() const
Definition: Nodes.h:192
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt * > children)
Definition: Nodes.cpp:192
This class represents the base of all AST Constraint decls.
Definition: Nodes.h:704
This class represents a PDLL type that corresponds to a constraint.
Definition: Types.h:145
static ConstraintType get(Context &context)
Return an instance of the Constraint type.
Definition: Types.cpp:65
This class represents the main context of the PDLL AST.
Definition: Context.h:25
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
Definition: Nodes.cpp:285
This class represents a scope for named AST decls.
Definition: Nodes.h:64
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
Definition: Nodes.h:70
void add(Decl *decl)
Add a new decl to the scope.
Definition: Nodes.cpp:175
This class represents the base Decl node.
Definition: Nodes.h:669
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition: Nodes.h:672
This class provides a simple implementation of a PDLL diagnostic.
Definition: Diagnostic.h:30
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
Definition: Nodes.cpp:218
This class represents a base AST Expression node.
Definition: Nodes.h:345
Type getType() const
Return the type of this expression.
Definition: Nodes.h:348
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostic.h:84
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
Definition: Nodes.cpp:207
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
Definition: Nodes.cpp:295
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl * > children)
Definition: Nodes.cpp:576
This Decl represents a NamedAttribute, and contains a string name and attribute value.
Definition: Nodes.h:992
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
Definition: Nodes.cpp:499
This class represents a base AST node.
Definition: Nodes.h:108
SMRange getLoc() const
Return the location of this node.
Definition: Nodes.h:131
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
Definition: Nodes.cpp:400
This Decl represents an OperationName.
Definition: Nodes.h:1016
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
Definition: Nodes.h:1022
static OpNameDecl * create(Context &ctx, const Name &name)
Definition: Nodes.cpp:509
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr * > operands, ArrayRef< Expr * > resultTypes, ArrayRef< NamedAttributeDecl * > attributes)
Definition: Nodes.cpp:307
This class represents a PDLL type that corresponds to an mlir::Operation.
Definition: Types.h:158
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
Definition: Types.cpp:87
static OperationType get(Context &context, std::optional< StringRef > name=std::nullopt, const ods::Operation *odsOp=nullptr)
Return an instance of the Operation type with an optional operation name.
Definition: Types.cpp:73
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
Definition: Nodes.cpp:520
static RangeExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, RangeType type)
Definition: Nodes.cpp:338
This class represents a PDLL type that corresponds to a range of elements with a given element type.
Definition: Types.h:183
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr * > replExprs)
Definition: Nodes.cpp:225
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
Definition: Nodes.cpp:249
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
Definition: Nodes.cpp:239
This class represents a PDLL type that corresponds to a rewrite reference.
Definition: Types.h:230
static RewriteType get(Context &context)
Return an instance of the Rewrite type.
Definition: Types.cpp:133
This class represents a base AST Statement node.
Definition: Nodes.h:164
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, ArrayRef< StringRef > elementNames)
Definition: Nodes.cpp:353
This class represents a PDLL tuple type, i.e.
Definition: Types.h:244
size_t size() const
Return the number of elements within this tuple.
Definition: Types.h:261
ArrayRef< Type > getElementTypes() const
Return the element types of this tuple.
Definition: Types.cpp:152
static TupleType get(Context &context, ArrayRef< Type > elementTypes, ArrayRef< StringRef > elementNames)
Return an instance of the Tuple type.
Definition: Types.cpp:141
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:417
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:373
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:426
U dyn_cast() const
Definition: Types.h:76
Type refineWith(Type other) const
Try to refine this type with the one provided.
Definition: Types.cpp:33
bool isa() const
Provide type casting support.
Definition: Types.h:67
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
Definition: Nodes.h:886
The class represents a Value constraint, and constrains a variable to be a Value.
Definition: Nodes.h:825
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
Definition: Nodes.cpp:436
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
Definition: Nodes.h:847
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:447
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
Definition: Nodes.cpp:558
This class contains all of the registered ODS operation classes.
Definition: Context.h:32
std::pair< Operation *, bool > insertOperation(StringRef name, StringRef summary, StringRef desc, StringRef nativeClassName, bool supportsResultTypeInferrence, SMLoc loc)
Insert a new operation with the context.
Definition: Context.cpp:64
const TypeConstraint & insertTypeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new type constraint with the context.
Definition: Context.cpp:42
const AttributeConstraint & insertAttributeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new attribute constraint with the context.
Definition: Context.cpp:28
This class provides an ODS representation of a specific operation operand or result.
Definition: Operation.h:74
This class provides an ODS representation of a specific operation.
Definition: Operation.h:125
SMRange getLoc() const
Return the source location of this operation.
Definition: Operation.h:128
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
Definition: Operation.h:165
bool hasResultTypeInferrence() const
Return if the operation is known to support result type inferrence.
Definition: Operation.h:171
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
Definition: Operation.h:168
This class represents a generic ODS Type constraint.
Definition: Constraint.h:84
raw_ostream subclass that simplifies indention a sequence of code.
raw_indented_ostream & printReindented(StringRef str, StringRef extraPrefix="")
Prints a string re-indented to the current indent.
StringRef getSummary() const
Definition: Constraint.cpp:55
std::string getUniqueDefName() const
Returns a unique name for the TablGen def of this constraint.
Definition: Constraint.cpp:72
StringRef getDescription() const
Definition: Constraint.cpp:62
std::string getConditionTemplate() const
Definition: Constraint.cpp:51
Format context containing substitutions for special placeholders.
Definition: Format.h:40
FmtContext & withSelf(Twine subst)
Definition: Format.cpp:41
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
Definition: Operator.h:77
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
Definition: Parser.cpp:3200
auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals) -> FmtObject< decltype(std::make_tuple(llvm::detail::build_format_adapter(std::forward< Ts >(vals))...))>
Formats text by substituting placeholders in format string with replacement parameters.
Definition: Format.h:261
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This class represents a reference to a constraint, and contains a constraint and the location of the ...
Definition: Nodes.h:716
const ConstraintDecl * constraint
Definition: Nodes.h:722
This class provides a convenient API for interacting with source names.
Definition: Nodes.h:37
StringRef getName() const
Return the raw string name.
Definition: Nodes.h:41
SMRange getLoc() const
Get the location of this name.
Definition: Nodes.h:44
static const Name & create(Context &ctx, StringRef name, SMRange location)
Definition: Nodes.cpp:33