MLIR  21.0.0git
Parser.cpp
Go to the documentation of this file.
1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "Lexer.h"
12 #include "mlir/TableGen/Argument.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/Operator.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/ManagedStatic.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/ScopedPrinter.h"
31 #include "llvm/TableGen/Error.h"
32 #include "llvm/TableGen/Parser.h"
33 #include <optional>
34 #include <string>
35 
36 using namespace mlir;
37 using namespace mlir::pdll;
38 
39 //===----------------------------------------------------------------------===//
40 // Parser
41 //===----------------------------------------------------------------------===//
42 
43 namespace {
44 class Parser {
45 public:
46  Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
47  bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
48  : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
49  curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
50  typeTy(ast::TypeType::get(ctx)), valueTy(ast::ValueType::get(ctx)),
51  typeRangeTy(ast::TypeRangeType::get(ctx)),
52  valueRangeTy(ast::ValueRangeType::get(ctx)),
53  attrTy(ast::AttributeType::get(ctx)),
54  codeCompleteContext(codeCompleteContext) {}
55 
56  /// Try to parse a new module. Returns failure if the module, or any `.pdll`
  /// file it includes, fails to parse.
57  FailureOr<ast::Module *> parseModule();
58 
59 private:
60  /// The current context of the parser. It allows for the parser to know a bit
61  /// about the construct it is nested within during parsing. This is used
62  /// specifically to provide additional verification during parsing, e.g. to
63  /// prevent using rewrites within a match context, matcher constraints within
64  /// a rewrite section, etc. The active value is held in `parserContext`.
65  enum class ParserContext {
66  /// The parser is in the global context.
67  Global,
68  /// The parser is currently within a Constraint, which disallows all types
69  /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
70  Constraint,
71  /// The parser is currently within the matcher portion of a Pattern, which
72  /// allows a terminal operation rewrite statement but no other rewrite
73  /// transformations.
74  PatternMatch,
75  /// The parser is currently within a Rewrite, which disallows calls to
76  /// constraints, requires operation expressions to have names, etc.
77  Rewrite,
78  };
79 
80  /// The current specification context of an operation's result types. This
81  /// indicates how the result types of an operation may be inferred.
82  enum class OpResultTypeContext {
83  /// The result types of the operation are not known to be inferred.
84  Explicit,
85  /// The result types of the operation are inferred from the root input of a
86  /// `replace` statement.
87  Replacement,
88  /// The result types of the operation are inferred by using the
89  /// `InferTypeOpInterface` interface provided by the operation.
90  Interface,
91  };
92 
93  //===--------------------------------------------------------------------===//
94  // Parsing
95  //===--------------------------------------------------------------------===//
96 
97  /// Push a new decl scope onto the lexer.
98  ast::DeclScope *pushDeclScope() {
99  ast::DeclScope *newScope =
100  new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
101  return (curDeclScope = newScope);
102  }
103  void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
104 
105  /// Pop the last decl scope from the lexer.
106  void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
107 
108  /// Parse the body of an AST module: directives and top-level decls are
  /// parsed until end-of-file, and every parsed decl is appended to `decls`.
109  LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
110 
111  /// Try to convert the given expression to `type`. Returns failure and emits
112  /// an error if a conversion is not viable. On failure, `noteAttachFn` is
113  /// invoked to attach notes to the emitted error diagnostic. On success,
114  /// `expr` is updated to the expression used to convert to `type`.
115  LogicalResult convertExpressionTo(
116  ast::Expr *&expr, ast::Type type,
117  function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
  /// Helper of convertExpressionTo for an `expr` of operation type `exprType`.
  /// `emitErrorFn` emits the conversion error used on every failure path.
118  LogicalResult
119  convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType,
120  ast::Type type,
121  function_ref<ast::InFlightDiagnostic()> emitErrorFn);
  /// Helper of convertExpressionTo for an `expr` of tuple type `exprType`.
  /// `noteAttachFn` is forwarded to nested element conversions.
122  LogicalResult convertTupleExpressionTo(
123  ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
124  function_ref<ast::InFlightDiagnostic()> emitErrorFn,
125  function_ref<void(ast::Diagnostic &diag)> noteAttachFn);
126 
127  /// Given an operation expression, convert it to a ValueRange typed
128  /// expression that accesses all of the operation's results.
129  ast::Expr *convertOpToValue(const ast::Expr *opExpr);
130 
131  /// Lookup ODS information for the given operation, returns nullptr if no
132  /// information is found.
133  const ods::Operation *lookupODSOperation(std::optional<StringRef> opName) {
134  return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
135  }
136 
137  /// Process the given documentation string, or return an empty string if
138  /// documentation isn't enabled.
139  StringRef processDoc(StringRef doc) {
140  return enableDocumentation ? doc : StringRef();
141  }
142 
143  /// Process the given documentation string and format it, or return an empty
144  /// string if documentation isn't enabled.
145  std::string processAndFormatDoc(const Twine &doc) {
146  if (!enableDocumentation)
147  return "";
148  std::string docStr;
149  {
150  llvm::raw_string_ostream docOS(docStr);
152  StringRef(docStr).rtrim(" \t"));
153  }
154  return docStr;
155  }
156 
157  //===--------------------------------------------------------------------===//
158  // Directives
159 
160  LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
161  LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
162  LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
164 
165  /// Process the records of a parsed tablegen include file.
166  void processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
168 
169  /// Create a user defined native constraint for a constraint imported from
170  /// ODS.
171  template <typename ConstraintT>
172  ast::Decl *
173  createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
174  SMRange loc, ast::Type type,
175  StringRef nativeType, StringRef docString);
176  template <typename ConstraintT>
177  ast::Decl *
178  createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
179  SMRange loc, ast::Type type,
180  StringRef nativeType);
181 
182  //===--------------------------------------------------------------------===//
183  // Decls
184 
/// Metadata that may be parsed for a pattern declaration.
struct ParsedPatternMetadata {
  /// The pattern benefit, if one was specified.
  std::optional<uint16_t> benefit = std::nullopt;
  /// Whether bounded recursion was specified for the pattern.
  bool hasBoundedRecursion{false};
};
190 
  /// Parse a single top-level declaration, i.e. anything at module scope that
  /// is not a directive.
191  FailureOr<ast::Decl *> parseTopLevelDecl();
  /// Parse a named attribute. `parentOpName` is presumably the name of the
  /// operation the attribute belongs to, if known — confirm in the definition.
192  FailureOr<ast::NamedAttributeDecl *>
193  parseNamedAttributeDecl(std::optional<StringRef> parentOpName);
194 
195  /// Parse an argument variable as part of the signature of a
196  /// UserConstraintDecl or UserRewriteDecl.
197  FailureOr<ast::VariableDecl *> parseArgumentDecl();
198 
199  /// Parse a result variable as part of the signature of a UserConstraintDecl
200  /// or UserRewriteDecl.
201  FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
202 
203  /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
204  /// defined in a non-global context.
205  FailureOr<ast::UserConstraintDecl *>
206  parseUserConstraintDecl(bool isInline = false);
207 
208  /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
209  /// non-global context, such as within a Pattern/Constraint/etc.
210  FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
211 
212  /// Parse a PDLL (i.e. non-native) UserConstraintDecl whose body is defined
213  /// using PDLL constructs.
214  FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
215  const ast::Name &name, bool isInline,
216  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
217  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
218 
219  /// Parse a UserRewriteDecl. `isInline` signals if the rewrite is being
220  /// defined in a non-global context.
221  FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
222 
223  /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
224  /// non-global context, such as within a Pattern/Rewrite/etc.
225  FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
226 
227  /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
228  /// PDLL constructs.
229  FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
230  const ast::Name &name, bool isInline,
231  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
232  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
233 
234  /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
235  /// effectively the same syntax, and only differ on slight semantics (given
236  /// the different parsing contexts).
237  template <typename T, typename ParseUserPDLLDeclFnT>
238  FailureOr<T *> parseUserConstraintOrRewriteDecl(
239  ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
240  StringRef anonymousNamePrefix, bool isInline);
241 
242  /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
243  /// These decls have effectively the same syntax.
244  template <typename T>
245  FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
246  const ast::Name &name, bool isInline,
248  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
249 
250  /// Parse the functional signature (i.e. the arguments and results) of a
251  /// UserConstraintDecl or UserRewriteDecl.
252  LogicalResult parseUserConstraintOrRewriteSignature(
255  ast::DeclScope *&argumentScope, ast::Type &resultType);
256 
257  /// Validate the return (which if present is specified by bodyIt) of a
258  /// UserConstraintDecl or UserRewriteDecl.
259  LogicalResult validateUserConstraintOrRewriteReturn(
260  StringRef declType, ast::CompoundStmt *body,
263  ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
264 
  /// Parse a `{ ... }` style body; each parsed statement is handed to
  /// `processStatementFn` (presumably for context-specific validation —
  /// confirm in the definition).
265  FailureOr<ast::CompoundStmt *>
266  parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
267  bool expectTerminalSemicolon = true);
268  FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
269  FailureOr<ast::Decl *> parsePatternDecl();
270  LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
271 
272  /// Check to see if a decl has already been defined with the given name. If
273  /// one has, emit an error and return failure. Returns success otherwise.
274  LogicalResult checkDefineNamedDecl(const ast::Name &name);
275 
276  /// Try to define a variable decl with the given components, returns the
277  /// variable on success.
278  FailureOr<ast::VariableDecl *>
279  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
280  ast::Expr *initExpr,
281  ArrayRef<ast::ConstraintRef> constraints);
282  FailureOr<ast::VariableDecl *>
283  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
284  ArrayRef<ast::ConstraintRef> constraints);
285 
286  /// Parse the constraint reference list for a variable decl.
287  LogicalResult parseVariableDeclConstraintList(
289 
290  /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
291  FailureOr<ast::Expr *> parseTypeConstraintExpr();
292 
293  /// Try to parse a single reference to a constraint. `typeConstraint` is the
294  /// location of a previously parsed type constraint for the entity that will
295  /// be constrained by the parsed constraint. `existingConstraints` are any
296  /// existing constraints that have already been parsed for the same entity
297  /// that will be constrained by this constraint. `allowInlineTypeConstraints`
298  /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
299  FailureOr<ast::ConstraintRef>
300  parseConstraint(std::optional<SMRange> &typeConstraint,
301  ArrayRef<ast::ConstraintRef> existingConstraints,
302  bool allowInlineTypeConstraints);
303 
304  /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
305  /// argument or result variable. The constraints for these variables do not
306  /// allow inline type constraints, and only permit a single constraint.
307  FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
308 
309  //===--------------------------------------------------------------------===//
310  // Exprs
311 
  /// Parse a single expression.
312  FailureOr<ast::Expr *> parseExpr();
313 
314  /// Parsers for the individual expression forms (attributes, calls, decl
  /// references, identifiers, inline lambdas, member accesses, negation,
  /// operations, tuples, types, and `_`).
315  FailureOr<ast::Expr *> parseAttributeExpr();
316  FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr,
317  bool isNegated = false);
318  FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
319  FailureOr<ast::Expr *> parseIdentifierExpr();
320  FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
321  FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
322  FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
323  FailureOr<ast::Expr *> parseNegatedExpr();
  /// `allowEmptyName` presumably permits an operation without a name (see
  /// ParserContext::Rewrite, which requires names) — confirm in definition.
324  FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
325  FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
326  FailureOr<ast::Expr *>
327  parseOperationExpr(OpResultTypeContext inputResultTypeContext =
328  OpResultTypeContext::Explicit);
329  FailureOr<ast::Expr *> parseTupleExpr();
330  FailureOr<ast::Expr *> parseTypeExpr();
331  FailureOr<ast::Expr *> parseUnderscoreExpr();
332 
333  //===--------------------------------------------------------------------===//
334  // Stmts
335 
  /// Statement parsers. `expectTerminalSemicolon` presumably controls whether
  /// a trailing `;` is required after the statement — confirm in parseStmt.
336  FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
337  FailureOr<ast::CompoundStmt *> parseCompoundStmt();
338  FailureOr<ast::EraseStmt *> parseEraseStmt();
339  FailureOr<ast::LetStmt *> parseLetStmt();
340  FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
341  FailureOr<ast::ReturnStmt *> parseReturnStmt();
342  FailureOr<ast::RewriteStmt *> parseRewriteStmt();
343 
344  //===--------------------------------------------------------------------===//
345  // Creation+Analysis
346  //===--------------------------------------------------------------------===//
347 
348  //===--------------------------------------------------------------------===//
349  // Decls
350 
351  /// Try to extract a callable from the given AST node. Returns nullptr on
352  /// failure.
353  ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
354 
355  /// Try to create a pattern decl with the given components, returning the
356  /// Pattern on success.
357  FailureOr<ast::PatternDecl *>
358  createPatternDecl(SMRange loc, const ast::Name *name,
359  const ParsedPatternMetadata &metadata,
360  ast::CompoundStmt *body);
361 
362  /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
363  /// of results, defined as part of the signature.
364  ast::Type
365  createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
366 
367  /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
368  template <typename T>
369  FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
370  const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
371  ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
372  ast::CompoundStmt *body);
373 
374  /// Try to create a variable decl with the given components, returning the
375  /// Variable on success.
376  FailureOr<ast::VariableDecl *>
377  createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
378  ArrayRef<ast::ConstraintRef> constraints);
379 
380  /// Create a variable for an argument or result defined as part of the
381  /// signature of a UserConstraintDecl/UserRewriteDecl.
382  FailureOr<ast::VariableDecl *>
383  createArgOrResultVariableDecl(StringRef name, SMRange loc,
384  const ast::ConstraintRef &constraint);
385 
386  /// Validate the constraints used to constrain a variable decl.
387  /// `inferredType` is the type of the variable inferred by the constraints
388  /// within the list, and is updated to the most refined type as determined by
389  /// the constraints. Returns success if the constraint list is valid, failure
390  /// otherwise.
391  LogicalResult
392  validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
393  ast::Type &inferredType);
394  /// Validate a single reference to a constraint. `inferredType` contains the
395  /// currently inferred variable type and is refined within the type defined
396  /// by the constraint. Returns success if the constraint is valid, failure
397  /// otherwise.
398  LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
399  ast::Type &inferredType);
  /// Validate the expression used inside a Type / TypeRange constraint.
400  LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
401  LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
402 
403  //===--------------------------------------------------------------------===//
404  // Exprs
405 
  /// Create a call to `parentExpr` at `loc`. `isNegated` presumably marks a
  /// negated constraint call — confirm in the definition.
406  FailureOr<ast::CallExpr *>
407  createCallExpr(SMRange loc, ast::Expr *parentExpr,
409  bool isNegated = false);
410  FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
411  FailureOr<ast::DeclRefExpr *>
412  createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
413  ArrayRef<ast::ConstraintRef> constraints);
414  FailureOr<ast::MemberAccessExpr *>
415  createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
416 
417  /// Validate the member access `name` into the given parent expression. On
418  /// success, this also returns the type of the member accessed.
419  FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
420  StringRef name, SMRange loc);
  /// Create an operation expression named `name`; `resultTypeContext`
  /// describes how its result types may be inferred (see OpResultTypeContext).
421  FailureOr<ast::OperationExpr *>
422  createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
423  OpResultTypeContext resultTypeContext,
427  LogicalResult
428  validateOperationOperands(SMRange loc, std::optional<StringRef> name,
429  const ods::Operation *odsOp,
430  SmallVectorImpl<ast::Expr *> &operands);
431  LogicalResult validateOperationResults(SMRange loc,
432  std::optional<StringRef> name,
433  const ods::Operation *odsOp,
435  void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
436  const ods::Operation *odsOp);
  /// Shared validation for operand/result lists. `groupName` presumably names
  /// the group in diagnostics — confirm in the definition.
437  LogicalResult validateOperationOperandsOrResults(
438  StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
439  std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
440  ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
441  ast::RangeType rangeTy);
442  FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
443  ArrayRef<ast::Expr *> elements,
444  ArrayRef<StringRef> elementNames);
445 
446  //===--------------------------------------------------------------------===//
447  // Stmts
448 
  /// Create the rewrite statements (`erase`, `replace`, `rewrite`) whose root
  /// operation is `rootOp`.
449  FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
450  FailureOr<ast::ReplaceStmt *>
451  createReplaceStmt(SMRange loc, ast::Expr *rootOp,
452  MutableArrayRef<ast::Expr *> replValues);
453  FailureOr<ast::RewriteStmt *>
454  createRewriteStmt(SMRange loc, ast::Expr *rootOp,
455  ast::CompoundStmt *rewriteBody);
456 
457  //===--------------------------------------------------------------------===//
458  // Code Completion
459  //===--------------------------------------------------------------------===//
460 
461  /// The set of various code completion methods. Every completion method that
462  /// returns LogicalResult returns `failure` to stop the parsing process after
463  /// providing completion results; the `void` signature methods further below
  /// do not signal failure through their return value.
464 
465  LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
466  LogicalResult codeCompleteAttributeName(std::optional<StringRef> opName);
467  LogicalResult codeCompleteConstraintName(ast::Type inferredType,
468  bool allowInlineTypeConstraints);
469  LogicalResult codeCompleteDialectName();
470  LogicalResult codeCompleteOperationName(StringRef dialectName);
471  LogicalResult codeCompletePatternMetadata();
472  LogicalResult codeCompleteIncludeFilename(StringRef curPath);
473 
474  void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
475  void codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
476  unsigned currentNumOperands);
477  void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
478  unsigned currentNumResults);
479 
480  //===--------------------------------------------------------------------===//
481  // Lexer Utilities
482  //===--------------------------------------------------------------------===//
483 
484  /// If the current token has the specified kind, consume it and return true.
485  /// If not, return false.
486  bool consumeIf(Token::Kind kind) {
487  if (curToken.isNot(kind))
488  return false;
489  consumeToken(kind);
490  return true;
491  }
492 
493  /// Advance the current lexer onto the next token.
494  void consumeToken() {
495  assert(curToken.isNot(Token::eof, Token::error) &&
496  "shouldn't advance past EOF or errors");
497  curToken = lexer.lexToken();
498  }
499 
500  /// Advance the current lexer onto the next token, asserting what the expected
501  /// current token is. This is preferred to the above method because it leads
502  /// to more self-documenting code with better checking.
503  void consumeToken(Token::Kind kind) {
504  assert(curToken.is(kind) && "consumed an unexpected token");
505  consumeToken();
506  }
507 
508  /// Reset the lexer to the location at the given position.
509  void resetToken(SMRange tokLoc) {
510  lexer.resetPointer(tokLoc.Start.getPointer());
511  curToken = lexer.lexToken();
512  }
513 
514  /// Consume the specified token if present and return success. On failure,
515  /// output a diagnostic and return failure.
516  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
517  if (curToken.getKind() != kind)
518  return emitError(curToken.getLoc(), msg);
519  consumeToken();
520  return success();
521  }
522  LogicalResult emitError(SMRange loc, const Twine &msg) {
523  lexer.emitError(loc, msg);
524  return failure();
525  }
526  LogicalResult emitError(const Twine &msg) {
527  return emitError(curToken.getLoc(), msg);
528  }
529  LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
530  const Twine &note) {
531  lexer.emitErrorAndNote(loc, msg, noteLoc, note);
532  return failure();
533  }
534 
535  //===--------------------------------------------------------------------===//
536  // Fields
537  //===--------------------------------------------------------------------===//
538 
539  /// The owning AST context.
540  ast::Context &ctx;
541 
542  /// The lexer of this parser.
543  Lexer lexer;
544 
545  /// The current token within the lexer.
546  Token curToken;
547 
548  /// A flag indicating if the parser should add documentation to AST nodes when
549  /// viable.
550  bool enableDocumentation;
551 
552  /// The most recently defined decl scope.
553  ast::DeclScope *curDeclScope = nullptr;
554  llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
555 
556  /// The current context of the parser.
557  ParserContext parserContext = ParserContext::Global;
558 
559  /// Cached types to simplify verification and expression creation.
560  ast::Type typeTy, valueTy;
561  ast::RangeType typeRangeTy, valueRangeTy;
562  ast::Type attrTy;
563 
564  /// A counter used when naming anonymous constraints and rewrites.
565  unsigned anonymousDeclNameCounter = 0;
566 
567  /// The optional code completion context.
568  CodeCompleteContext *codeCompleteContext;
569 };
570 } // namespace
571 
/// Parse the whole input as a module. A module-level decl scope is pushed for
/// the duration of the parse and popped again on both the success and the
/// failure path.
572 FailureOr<ast::Module *> Parser::parseModule() {
  // The module is located at the start of its first token.
573  SMLoc moduleLoc = curToken.getStartLoc();
574  pushDeclScope();
575 
576  // Parse the top-level decls of the module.
  // NOTE(review): the declaration of the `decls` vector is not visible in this
  // listing. On failure, the comma expression pops the module scope before
  // returning failure.
578  if (failed(parseModuleBody(decls)))
579  return popDeclScope(), failure();
580 
581  popDeclScope();
582  return ast::Module::create(ctx, moduleLoc, decls);
583 }
584 
585 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
586  while (curToken.isNot(Token::eof)) {
587  if (curToken.is(Token::directive)) {
588  if (failed(parseDirective(decls)))
589  return failure();
590  continue;
591  }
592 
593  FailureOr<ast::Decl *> decl = parseTopLevelDecl();
594  if (failed(decl))
595  return failure();
596  decls.push_back(*decl);
597  }
598  return success();
599 }
600 
601 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
602  return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
603  valueRangeTy);
604 }
605 
// Try to convert `expr` to `type`. Operation- and tuple-typed expressions are
// delegated to convertOpExpressionTo / convertTupleExpressionTo, which may
// replace `expr` with a conversion expression. Value<->ValueRange and
// Type<->TypeRange pairs are accepted without rewriting `expr`. Any other
// mismatch emits an "unable to convert" error (with notes from `noteAttachFn`)
// and returns failure.
606 LogicalResult Parser::convertExpressionTo(
607  ast::Expr *&expr, ast::Type type,
608  function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
609  ast::Type exprType = expr->getType();
  // Nothing to do if the expression already has the requested type.
610  if (exprType == type)
611  return success();
612 
  // Lazily builds the conversion error; only invoked on failure paths, and
  // always attaches the caller-provided notes.
  // NOTE(review): the statement that creates `diag` is not visible in this
  // listing.
613  auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
615  expr->getLoc(), llvm::formatv("unable to convert expression of type "
616  "`{0}` to the expected type of "
617  "`{1}`",
618  exprType, type));
619  if (noteAttachFn)
620  noteAttachFn(*diag);
621  return diag;
622  };
623 
624  if (auto exprOpType = dyn_cast<ast::OperationType>(exprType))
625  return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);
626 
627  // FIXME: Decide how to allow/support converting a single result to multiple,
628  // and multiple to a single result. For now, we allow both Single->Range and
629  // Range->Single without rewriting `expr`, but this isn't something really
630  // supported in the PDL dialect. We should figure out some way to support both.
631  if ((exprType == valueTy || exprType == valueRangeTy) &&
632  (type == valueTy || type == valueRangeTy))
633  return success();
634  if ((exprType == typeTy || exprType == typeRangeTy) &&
635  (type == typeTy || type == typeRangeTy))
636  return success();
637 
638  // Handle tuple types.
639  if (auto exprTupleType = dyn_cast<ast::TupleType>(exprType))
640  return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
641  noteAttachFn);
642 
643  return emitConvertError();
644 }
645 
646 LogicalResult Parser::convertOpExpressionTo(
647  ast::Expr *&expr, ast::OperationType exprType, ast::Type type,
648  function_ref<ast::InFlightDiagnostic()> emitErrorFn) {
649  // Two operation types are compatible if they have the same name, or if the
650  // expected type is more general.
651  if (auto opType = dyn_cast<ast::OperationType>(type)) {
652  if (opType.getName())
653  return emitErrorFn();
654  return success();
655  }
656 
657  // An operation can always convert to a ValueRange.
658  if (type == valueRangeTy) {
659  expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
660  valueRangeTy);
661  return success();
662  }
663 
664  // Allow conversion to a single value by constraining the result range.
665  if (type == valueTy) {
666  // If the operation is registered, we can verify if it can ever have a
667  // single result.
668  if (const ods::Operation *odsOp = exprType.getODSOperation()) {
669  if (odsOp->getResults().empty()) {
670  return emitErrorFn()->attachNote(
671  llvm::formatv("see the definition of `{0}`, which was defined "
672  "with zero results",
673  odsOp->getName()),
674  odsOp->getLoc());
675  }
676 
677  unsigned numSingleResults = llvm::count_if(
678  odsOp->getResults(), [](const ods::OperandOrResult &result) {
679  return result.getVariableLengthKind() ==
680  ods::VariableLengthKind::Single;
681  });
682  if (numSingleResults > 1) {
683  return emitErrorFn()->attachNote(
684  llvm::formatv("see the definition of `{0}`, which was defined "
685  "with at least {1} results",
686  odsOp->getName(), numSingleResults),
687  odsOp->getLoc());
688  }
689  }
690 
691  expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
692  valueTy);
693  return success();
694  }
695  return emitErrorFn();
696 }
697 
698 LogicalResult Parser::convertTupleExpressionTo(
699  ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
700  function_ref<ast::InFlightDiagnostic()> emitErrorFn,
701  function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
702  // Handle conversions between tuples.
703  if (auto tupleType = dyn_cast<ast::TupleType>(type)) {
704  if (tupleType.size() != exprType.size())
705  return emitErrorFn();
706 
707  // Build a new tuple expression using each of the elements of the current
708  // tuple.
709  SmallVector<ast::Expr *> newExprs;
710  for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
711  newExprs.push_back(ast::MemberAccessExpr::create(
712  ctx, expr->getLoc(), expr, llvm::to_string(i),
713  exprType.getElementTypes()[i]));
714 
715  auto diagFn = [&](ast::Diagnostic &diag) {
716  diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
717  i, exprType));
718  if (noteAttachFn)
719  noteAttachFn(diag);
720  };
721  if (failed(convertExpressionTo(newExprs.back(),
722  tupleType.getElementTypes()[i], diagFn)))
723  return failure();
724  }
725  expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
726  tupleType.getElementNames());
727  return success();
728  }
729 
730  // Handle conversion to a range.
731  auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes,
732  ast::RangeType resultTy) -> LogicalResult {
733  // TODO: We currently only allow range conversion within a rewrite context.
734  if (parserContext != ParserContext::Rewrite) {
735  return emitErrorFn()->attachNote("Tuple to Range conversion is currently "
736  "only allowed within a rewrite context");
737  }
738 
739  // All of the tuple elements must be allowed types.
740  for (ast::Type elementType : exprType.getElementTypes())
741  if (!llvm::is_contained(allowedElementTypes, elementType))
742  return emitErrorFn();
743 
744  // Build a new tuple expression using each of the elements of the current
745  // tuple.
746  SmallVector<ast::Expr *> newExprs;
747  for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
748  newExprs.push_back(ast::MemberAccessExpr::create(
749  ctx, expr->getLoc(), expr, llvm::to_string(i),
750  exprType.getElementTypes()[i]));
751  }
752  expr = ast::RangeExpr::create(ctx, expr->getLoc(), newExprs, resultTy);
753  return success();
754  };
755  if (type == valueRangeTy)
756  return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
757  if (type == typeRangeTy)
758  return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
759 
760  return emitErrorFn();
761 }
762 
763 //===----------------------------------------------------------------------===//
764 // Directives
765 //===----------------------------------------------------------------------===//
766 
767 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
768  StringRef directive = curToken.getSpelling();
769  if (directive == "#include")
770  return parseInclude(decls);
771 
772  return emitError("unknown directive `" + directive + "`");
773 }
774 
775 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
776  SMRange loc = curToken.getLoc();
777  consumeToken(Token::directive);
778 
779  // Handle code completion of the include file path.
780  if (curToken.is(Token::code_complete_string))
781  return codeCompleteIncludeFilename(curToken.getStringValue());
782 
783  // Parse the file being included.
784  if (!curToken.isString())
785  return emitError(loc,
786  "expected string file name after `include` directive");
787  SMRange fileLoc = curToken.getLoc();
788  std::string filenameStr = curToken.getStringValue();
789  StringRef filename = filenameStr;
790  consumeToken();
791 
792  // Check the type of include. If ending with `.pdll`, this is another pdl file
793  // to be parsed along with the current module.
794  if (filename.ends_with(".pdll")) {
795  if (failed(lexer.pushInclude(filename, fileLoc)))
796  return emitError(fileLoc,
797  "unable to open include file `" + filename + "`");
798 
799  // If we added the include successfully, parse it into the current module.
800  // Make sure to update to the next token after we finish parsing the nested
801  // file.
802  curToken = lexer.lexToken();
803  LogicalResult result = parseModuleBody(decls);
804  curToken = lexer.lexToken();
805  return result;
806  }
807 
808  // Otherwise, this must be a `.td` include.
809  if (filename.ends_with(".td"))
810  return parseTdInclude(filename, fileLoc, decls);
811 
812  return emitError(fileLoc,
813  "expected include filename to end with `.pdll` or `.td`");
814 }
815 
/// Parse a TableGen (`.td`) include and turn its records into PDLL decls.
///
/// The file is located through the parser's SourceMgr include search and
/// parsed with a separate, local SourceMgr. The resulting records are handed
/// to processTdIncludeRecords together with `decls`. Every TableGen diagnostic
/// is re-emitted as a parser error anchored at `fileLoc`, the location of the
/// include string. Returns failure if the file cannot be opened or TableGen
/// parsing reports an error.
///
/// NOTE(review): `decls` is used in the body, but its parameter line is not
/// present in this listing; confirm against the declaration.
816 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
818  llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
819 
820  // Use the source manager to open the file, but don't yet add it.
  // `includedFile` receives the resolved path; it is not used further here.
821  std::string includedFile;
822  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
823  parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
824  if (!includeBuffer)
825  return emitError(fileLoc, "unable to open include file `" + filename + "`");
826 
827  // Setup the source manager for parsing the tablegen file.
  // It inherits the parser's include directories, so includes inside the .td
  // file are searched in the same directories.
828  llvm::SourceMgr tdSrcMgr;
829  tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
830  tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
831 
832  // This class provides a context argument for the llvm::SourceMgr diagnostic
833  // handler.
  // It lives on this stack frame. That is fine because `tdSrcMgr`, the only
  // holder of the pointer, is also local to this function.
834  struct DiagHandlerContext {
835  Parser &parser;
836  StringRef filename;
837  llvm::SMRange loc;
838  } handlerContext{*this, filename, fileLoc};
839 
840  // Set the diagnostic handler for the tablegen source manager.
841  tdSrcMgr.setDiagHandler(
842  [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
843  auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
844  (void)ctx->parser.emitError(
845  ctx->loc,
846  llvm::formatv("error while processing include file `{0}`: {1}",
847  ctx->filename, diag.getMessage()));
848  },
849  &handlerContext);
850 
851  // Parse the tablegen file.
  // A true result is treated as failure. The specific errors have presumably
  // already been reported through the handler above, so none is added here.
852  llvm::RecordKeeper tdRecords;
853  if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
854  return failure();
855 
856  // Process the parsed records.
857  processTdIncludeRecords(tdRecords, decls);
858 
859  // After we are done processing, move all of the tablegen source buffers to
860  // the main parser source mgr. This allows for directly using source locations
861  // from the .td files without needing to remap them.
862  parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
863  return success();
864 }
865 
/// Convert the records of a parsed TableGen file into ODS information and
/// PDLL decls.
///
/// - Every `Op` record is registered with the ODS context, including its
///   attributes, operands and results. Ops the context reports as already
///   inserted are left untouched.
/// - `Attr`, `Type` and `OpInterface` records become native constraint decls
///   that are appended to `decls`. A record is skipped if it is anonymous, if
///   its name is already found by `curDeclScope->lookup`, or if it derives
///   from `DeclareInterfaceMethods`.
///
/// NOTE(review): this listing omits several lines of the definition: the
/// `decls` parameter, the return value for optional values in
/// `getLengthKind`, the final branch of its conditional, and the definition
/// of `opTy` used for OpInterfaces. Confirm against the full source.
866 void Parser::processTdIncludeRecords(const llvm::RecordKeeper &tdRecords,
868  // Return the length kind of the given value.
869  auto getLengthKind = [](const auto &value) {
870  if (value.isOptional())
872  return value.isVariadic() ? ods::VariableLengthKind::Variadic
874  };
875 
876  // Insert a type constraint into the ODS context.
877  ods::Context &odsContext = ctx.getODSContext();
878  auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
879  -> const ods::TypeConstraint & {
880  return odsContext.insertTypeConstraint(
881  cst.constraint.getUniqueDefName(),
882  processDoc(cst.constraint.getSummary()), cst.constraint.getCppType());
883  };
  // Widen a record's point location (an SMLoc) into a one-character SMRange.
884  auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
885  return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
886  };
887 
888  // Process the parsed tablegen records to build ODS information.
889  /// Operations.
890  for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
891  tblgen::Operator op(def);
892 
893  // Check to see if this operation is known to support type inferrence.
  // Only the presence of the InferTypeOpInterface trait is consulted.
894  bool supportsResultTypeInferrence =
895  op.getTrait("::mlir::InferTypeOpInterface::Trait");
896 
897  auto [odsOp, inserted] = odsContext.insertOperation(
898  op.getOperationName(), processDoc(op.getSummary()),
899  processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
900  supportsResultTypeInferrence, op.getLoc().front());
901 
902  // Ignore operations that have already been added.
903  if (!inserted)
904  continue;
905 
906  for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
907  odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
908  odsContext.insertAttributeConstraint(
909  attr.attr.getUniqueDefName(),
910  processDoc(attr.attr.getSummary()),
911  attr.attr.getStorageType()));
912  }
913  for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
914  odsOp->appendOperand(operand.name, getLengthKind(operand),
915  addTypeConstraint(operand));
916  }
917  for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
918  odsOp->appendResult(result.name, getLengthKind(result),
919  addTypeConstraint(result));
920  }
921  }
922 
  // Records that must not produce a constraint decl: anonymous defs, names
  // already visible in the current decl scope (so no redefinition error is
  // triggered), and `DeclareInterfaceMethods` helper records.
923  auto shouldBeSkipped = [this](const llvm::Record *def) {
924  return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
925  def->isSubClassOf("DeclareInterfaceMethods");
926  };
927 
928  /// Attr constraints.
929  for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
930  if (shouldBeSkipped(def))
931  continue;
932 
933  tblgen::Attribute constraint(def);
934  decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
935  constraint, convertLocToRange(def->getLoc().front()), attrTy,
936  constraint.getStorageType()));
937  }
938  /// Type constraints.
939  for (const llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
940  if (shouldBeSkipped(def))
941  continue;
942 
943  tblgen::TypeConstraint constraint(def);
944  decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
945  constraint, convertLocToRange(def->getLoc().front()), typeTy,
946  constraint.getCppType()));
947  }
948  /// OpInterfaces.
950  for (const llvm::Record *def :
951  tdRecords.getAllDerivedDefinitions("OpInterface")) {
952  if (shouldBeSkipped(def))
953  continue;
954 
955  SMRange loc = convertLocToRange(def->getLoc().front());
956 
  // Fully qualified C++ interface class: `<cppNamespace>::<cppInterfaceName>`.
  // The generated native body is an `llvm::isa<>` check against that class.
957  std::string cppClassName =
958  llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
959  def->getValueAsString("cppInterfaceName"))
960  .str();
961  std::string codeBlock =
962  llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
963  cppClassName)
964  .str();
965 
  // NOTE(review): unlike the tblgen::Constraint-based path, this description
  // is processed even when documentation is disabled.
966  std::string desc =
967  processAndFormatDoc(def->getValueAsString("description"));
968  decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
969  def->getName(), codeBlock, loc, opTy, cppClassName, desc));
970  }
971 }
972 
973 template <typename ConstraintT>
974 ast::Decl *Parser::createODSNativePDLLConstraintDecl(
975  StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
976  StringRef nativeType, StringRef docString) {
977  // Build the single input parameter.
978  ast::DeclScope *argScope = pushDeclScope();
979  auto *paramVar = ast::VariableDecl::create(
980  ctx, ast::Name::create(ctx, "self", loc), type,
981  /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
982  argScope->add(paramVar);
983  popDeclScope();
984 
985  // Build the native constraint.
986  auto *constraintDecl = ast::UserConstraintDecl::createNative(
987  ctx, ast::Name::create(ctx, name, loc), paramVar,
988  /*results=*/std::nullopt, codeBlock, ast::TupleType::get(ctx),
989  nativeType);
990  constraintDecl->setDocComment(ctx, docString);
991  curDeclScope->add(constraintDecl);
992  return constraintDecl;
993 }
994 
995 template <typename ConstraintT>
996 ast::Decl *
997 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
998  SMRange loc, ast::Type type,
999  StringRef nativeType) {
1000  // Format the condition template.
1001  tblgen::FmtContext fmtContext;
1002  fmtContext.withSelf("self");
1003  std::string codeBlock = tblgen::tgfmt(
1004  "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
1005  &fmtContext);
1006 
1007  // If documentation was enabled, build the doc string for the generated
1008  // constraint. It would be nice to do this lazily, but TableGen information is
1009  // destroyed after we finish parsing the file.
1010  std::string docString;
1011  if (enableDocumentation) {
1012  StringRef desc = constraint.getDescription();
1013  docString = processAndFormatDoc(
1014  constraint.getSummary() +
1015  (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
1016  }
1017 
1018  return createODSNativePDLLConstraintDecl<ConstraintT>(
1019  constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
1020  docString);
1021 }
1022 
1023 //===----------------------------------------------------------------------===//
1024 // Decls
1025 //===----------------------------------------------------------------------===//
1026 
1027 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
1028  FailureOr<ast::Decl *> decl;
1029  switch (curToken.getKind()) {
1030  case Token::kw_Constraint:
1031  decl = parseUserConstraintDecl();
1032  break;
1033  case Token::kw_Pattern:
1034  decl = parsePatternDecl();
1035  break;
1036  case Token::kw_Rewrite:
1037  decl = parseUserRewriteDecl();
1038  break;
1039  default:
1040  return emitError("expected top-level declaration, such as a `Pattern`");
1041  }
1042  if (failed(decl))
1043  return failure();
1044 
1045  // If the decl has a name, add it to the current scope.
1046  if (const ast::Name *name = (*decl)->getName()) {
1047  if (failed(checkDefineNamedDecl(*name)))
1048  return failure();
1049  curDeclScope->add(*decl);
1050  }
1051  return decl;
1052 }
1053 
1054 FailureOr<ast::NamedAttributeDecl *>
1055 Parser::parseNamedAttributeDecl(std::optional<StringRef> parentOpName) {
1056  // Check for name code completion.
1057  if (curToken.is(Token::code_complete))
1058  return codeCompleteAttributeName(parentOpName);
1059 
1060  std::string attrNameStr;
1061  if (curToken.isString())
1062  attrNameStr = curToken.getStringValue();
1063  else if (curToken.is(Token::identifier) || curToken.isKeyword())
1064  attrNameStr = curToken.getSpelling().str();
1065  else
1066  return emitError("expected identifier or string attribute name");
1067  const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
1068  consumeToken();
1069 
1070  // Check for a value of the attribute.
1071  ast::Expr *attrValue = nullptr;
1072  if (consumeIf(Token::equal)) {
1073  FailureOr<ast::Expr *> attrExpr = parseExpr();
1074  if (failed(attrExpr))
1075  return failure();
1076  attrValue = *attrExpr;
1077  } else {
1078  // If there isn't a concrete value, create an expression representing a
1079  // UnitAttr.
1080  attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
1081  }
1082 
1083  return ast::NamedAttributeDecl::create(ctx, name, attrValue);
1084 }
1085 
1086 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1087  function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1088  bool expectTerminalSemicolon) {
1089  consumeToken(Token::equal_arrow);
1090 
1091  // Parse the single statement of the lambda body.
1092  SMLoc bodyStartLoc = curToken.getStartLoc();
1093  pushDeclScope();
1094  FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1095  bool failedToParse =
1096  failed(singleStatement) || failed(processStatementFn(*singleStatement));
1097  popDeclScope();
1098  if (failedToParse)
1099  return failure();
1100 
1101  SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1102  return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
1103 }
1104 
1105 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1106  // Ensure that the argument is named.
1107  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
1108  return emitError("expected identifier argument name");
1109 
1110  // Parse the argument similarly to a normal variable.
1111  StringRef name = curToken.getSpelling();
1112  SMRange nameLoc = curToken.getLoc();
1113  consumeToken();
1114 
1115  if (failed(
1116  parseToken(Token::colon, "expected `:` before argument constraint")))
1117  return failure();
1118 
1119  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1120  if (failed(cst))
1121  return failure();
1122 
1123  return createArgOrResultVariableDecl(name, nameLoc, *cst);
1124 }
1125 
1126 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1127  // Check to see if this result is named.
1128  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1129  // Check to see if this name actually refers to a Constraint.
1130  if (!curDeclScope->lookup<ast::ConstraintDecl>(curToken.getSpelling())) {
1131  // If it wasn't a constraint, parse the result similarly to a variable. If
1132  // there is already an existing decl, we will emit an error when defining
1133  // this variable later.
1134  StringRef name = curToken.getSpelling();
1135  SMRange nameLoc = curToken.getLoc();
1136  consumeToken();
1137 
1138  if (failed(parseToken(Token::colon,
1139  "expected `:` before result constraint")))
1140  return failure();
1141 
1142  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1143  if (failed(cst))
1144  return failure();
1145 
1146  return createArgOrResultVariableDecl(name, nameLoc, *cst);
1147  }
1148  }
1149 
1150  // If it isn't named, we parse the constraint directly and create an unnamed
1151  // result variable.
1152  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1153  if (failed(cst))
1154  return failure();
1155 
1156  return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1157 }
1158 
1159 FailureOr<ast::UserConstraintDecl *>
1160 Parser::parseUserConstraintDecl(bool isInline) {
1161  // Constraints and rewrites have very similar formats, dispatch to a shared
1162  // interface for parsing.
1163  return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1164  [&](auto &&...args) {
1165  return this->parseUserPDLLConstraintDecl(args...);
1166  },
1167  ParserContext::Constraint, "constraint", isInline);
1168 }
1169 
1170 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1171  FailureOr<ast::UserConstraintDecl *> decl =
1172  parseUserConstraintDecl(/*isInline=*/true);
1173  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1174  return failure();
1175 
1176  curDeclScope->add(*decl);
1177  return decl;
1178 }
1179 
1180 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1181  const ast::Name &name, bool isInline,
1182  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1183  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1184  // Push the argument scope back onto the list, so that the body can
1185  // reference arguments.
1186  pushDeclScope(argumentScope);
1187 
1188  // Parse the body of the constraint. The body is either defined as a compound
1189  // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1190  ast::CompoundStmt *body;
1191  if (curToken.is(Token::equal_arrow)) {
1192  FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1193  [&](ast::Stmt *&stmt) -> LogicalResult {
1194  ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1195  if (!stmtExpr) {
1196  return emitError(stmt->getLoc(),
1197  "expected `Constraint` lambda body to contain a "
1198  "single expression");
1199  }
1200  stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1201  return success();
1202  },
1203  /*expectTerminalSemicolon=*/!isInline);
1204  if (failed(bodyResult))
1205  return failure();
1206  body = *bodyResult;
1207  } else {
1208  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1209  if (failed(bodyResult))
1210  return failure();
1211  body = *bodyResult;
1212 
1213  // Verify the structure of the body.
1214  auto bodyIt = body->begin(), bodyE = body->end();
1215  for (; bodyIt != bodyE; ++bodyIt)
1216  if (isa<ast::ReturnStmt>(*bodyIt))
1217  break;
1218  if (failed(validateUserConstraintOrRewriteReturn(
1219  "Constraint", body, bodyIt, bodyE, results, resultType)))
1220  return failure();
1221  }
1222  popDeclScope();
1223 
1224  return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1225  name, arguments, results, resultType, body);
1226 }
1227 
1228 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1229  // Constraints and rewrites have very similar formats, dispatch to a shared
1230  // interface for parsing.
1231  return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1232  [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1233  ParserContext::Rewrite, "rewrite", isInline);
1234 }
1235 
1236 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1237  FailureOr<ast::UserRewriteDecl *> decl =
1238  parseUserRewriteDecl(/*isInline=*/true);
1239  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1240  return failure();
1241 
1242  curDeclScope->add(*decl);
1243  return decl;
1244 }
1245 
1246 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1247  const ast::Name &name, bool isInline,
1248  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1249  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1250  // Push the argument scope back onto the list, so that the body can
1251  // reference arguments.
1252  curDeclScope = argumentScope;
1253  ast::CompoundStmt *body;
1254  if (curToken.is(Token::equal_arrow)) {
1255  FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1256  [&](ast::Stmt *&statement) -> LogicalResult {
1257  if (isa<ast::OpRewriteStmt>(statement))
1258  return success();
1259 
1260  ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1261  if (!statementExpr) {
1262  return emitError(
1263  statement->getLoc(),
1264  "expected `Rewrite` lambda body to contain a single expression "
1265  "or an operation rewrite statement; such as `erase`, "
1266  "`replace`, or `rewrite`");
1267  }
1268  statement =
1269  ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1270  return success();
1271  },
1272  /*expectTerminalSemicolon=*/!isInline);
1273  if (failed(bodyResult))
1274  return failure();
1275  body = *bodyResult;
1276  } else {
1277  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1278  if (failed(bodyResult))
1279  return failure();
1280  body = *bodyResult;
1281  }
1282  popDeclScope();
1283 
1284  // Verify the structure of the body.
1285  auto bodyIt = body->begin(), bodyE = body->end();
1286  for (; bodyIt != bodyE; ++bodyIt)
1287  if (isa<ast::ReturnStmt>(*bodyIt))
1288  break;
1289  if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1290  bodyE, results, resultType)))
1291  return failure();
1292  return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1293  name, arguments, results, resultType, body);
1294 }
1295 
1296 template <typename T, typename ParseUserPDLLDeclFnT>
1297 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1298  ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1299  StringRef anonymousNamePrefix, bool isInline) {
1300  SMRange loc = curToken.getLoc();
1301  consumeToken();
1302  llvm::SaveAndRestore saveCtx(parserContext, declContext);
1303 
1304  // Parse the name of the decl.
1305  const ast::Name *name = nullptr;
1306  if (curToken.isNot(Token::identifier)) {
1307  // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1308  // in C++, so being unnamed is fine.
1309  if (!isInline)
1310  return emitError("expected identifier name");
1311 
1312  // Create a unique anonymous name to use, as the name for this decl is not
1313  // important.
1314  std::string anonName =
1315  llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1316  anonymousDeclNameCounter++)
1317  .str();
1318  name = &ast::Name::create(ctx, anonName, loc);
1319  } else {
1320  // If a name was provided, we can use it directly.
1321  name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1322  consumeToken(Token::identifier);
1323  }
1324 
1325  // Parse the functional signature of the decl.
1326  SmallVector<ast::VariableDecl *> arguments, results;
1327  ast::DeclScope *argumentScope;
1328  ast::Type resultType;
1329  if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1330  argumentScope, resultType)))
1331  return failure();
1332 
1333  // Check to see which type of constraint this is. If the constraint contains a
1334  // compound body, this is a PDLL decl.
1335  if (curToken.isAny(Token::l_brace, Token::equal_arrow))
1336  return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1337  resultType);
1338 
1339  // Otherwise, this is a native decl.
1340  return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1341  results, resultType);
1342 }
1343 
/// Parse the rest of a native (externally implemented) constraint or rewrite
/// decl, i.e. everything after the signature: an optional string holding the
/// C++ body, followed by a terminating `;`.
///
/// Inline decls must supply the code string. Otherwise an error is emitted at
/// the decl's name.
///
/// NOTE(review): `arguments` is used in the body, but its parameter line is
/// not present in this listing; confirm against the declaration.
1344 template <typename T>
1345 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1346  const ast::Name &name, bool isInline,
1348  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1349  // If followed by a string, the native code body has also been specified.
  // `codeStrStorage` owns the unescaped string; `optCodeStr` is a view into it.
1350  std::string codeStrStorage;
1351  std::optional<StringRef> optCodeStr;
1352  if (curToken.isString()) {
1353  codeStrStorage = curToken.getStringValue();
1354  optCodeStr = codeStrStorage;
1355  consumeToken();
1356  } else if (isInline) {
1357  return emitError(name.getLoc(),
1358  "external declarations must be declared in global scope");
1359  } else if (curToken.is(Token::error)) {
  // Presumably the lexer already reported the error token, so fail without
  // emitting a second diagnostic.
1360  return failure();
1361  }
1362  if (failed(parseToken(Token::semicolon,
1363  "expected `;` after native declaration")))
1364  return failure();
1365  return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1366 }
1367 
/// Parse the signature of a constraint or rewrite decl:
///   `(` (arg (`,` arg)*)? `)` (`->` (result | `(` result (`,` result)* `)`))?
///
/// `argumentScope` is set to the scope the arguments were declared in. That
/// scope has already been popped on return, so callers can re-enter it to
/// parse a body. Results are declared in a separate temporary scope, which is
/// also popped. `resultType` is computed from the parsed results.
///
/// NOTE(review): the `arguments` and `results` out-parameter lines are not
/// present in this listing; confirm against the declaration.
1368 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1371  ast::DeclScope *&argumentScope, ast::Type &resultType) {
1372  // Parse the argument list of the decl.
1373  if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
1374  return failure();
1375 
  // Arguments go into a fresh scope. It is popped below, but handed back to
  // the caller through `argumentScope`.
1376  argumentScope = pushDeclScope();
1377  if (curToken.isNot(Token::r_paren)) {
1378  do {
1379  FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1380  if (failed(argument))
1381  return failure();
1382  arguments.emplace_back(*argument);
1383  } while (consumeIf(Token::comma));
1384  }
1385  popDeclScope();
1386  if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
1387  return failure();
1388 
1389  // Parse the results of the decl.
  // Result decls are created inside a throwaway scope that is popped once the
  // list has been parsed.
1390  pushDeclScope();
1391  if (consumeIf(Token::arrow)) {
1392  auto parseResultFn = [&]() -> LogicalResult {
1393  FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1394  if (failed(result))
1395  return failure();
1396  results.emplace_back(*result);
1397  return success();
1398  };
1399 
1400  // Check for a list of results.
1401  if (consumeIf(Token::l_paren)) {
1402  do {
1403  if (failed(parseResultFn()))
1404  return failure();
1405  } while (consumeIf(Token::comma));
1406  if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1407  return failure();
1408 
1409  // Otherwise, there is only one result.
1410  } else if (failed(parseResultFn())) {
1411  return failure();
1412  }
1413  }
1414  popDeclScope();
1415 
1416  // Compute the result type of the decl.
1417  resultType = createUserConstraintRewriteResultType(results);
1418 
1419  // Verify that results are only named if there are more than one.
  // Result names act as tuple element labels, which a lone result cannot carry.
1420  if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1421  return emitError(
1422  results.front()->getLoc(),
1423  "cannot create a single-element tuple with an element label");
1424  }
1425  return success();
1426 }
1427 
/// Check how a `Constraint`/`Rewrite` compound body uses `return`.
///
/// `bodyIt` should point at the first `ReturnStmt` in `body`, or equal `bodyE`
/// if there is none; this is how the callers in this file compute it. The
/// check fails in two cases:
/// - statements follow the `return`;
/// - there is no `return` while `results` is non-empty.
///
/// `declType` and `resultType` are only used in diagnostics. `resultType` is
/// taken by non-const reference but only read here. Whether the returned
/// value matches `resultType` is not checked by this function.
///
/// NOTE(review): the `bodyIt`/`bodyE` parameter lines are not present in this
/// listing; confirm against the declaration.
1428 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1429  StringRef declType, ast::CompoundStmt *body,
1432  ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1433  // Handle if a `return` was provided.
1434  if (bodyIt != bodyE) {
1435  // Emit an error if we have trailing statements after the return.
1436  if (std::next(bodyIt) != bodyE) {
1437  return emitError(
1438  (*std::next(bodyIt))->getLoc(),
1439  llvm::formatv("`return` terminated the `{0}` body, but found "
1440  "trailing statements afterwards",
1441  declType));
1442  }
1443 
1444  // Otherwise if a return wasn't provided, check that no results are
1445  // expected.
1446  } else if (!results.empty()) {
  // Anchor the diagnostic at a zero-width range at the end of the body.
1447  return emitError(
1448  {body->getLoc().End, body->getLoc().End},
1449  llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1450  declType, resultType));
1451  }
1452  return success();
1453 }
1454 
1455 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1456  return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1457  if (isa<ast::OpRewriteStmt>(statement))
1458  return success();
1459  return emitError(
1460  statement->getLoc(),
1461  "expected Pattern lambda body to contain a single operation "
1462  "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1463  });
1464 }
1465 
1466 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1467  SMRange loc = curToken.getLoc();
1468  consumeToken(Token::kw_Pattern);
1469  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::PatternMatch);
1470 
1471  // Check for an optional identifier for the pattern name.
1472  const ast::Name *name = nullptr;
1473  if (curToken.is(Token::identifier)) {
1474  name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1475  consumeToken(Token::identifier);
1476  }
1477 
1478  // Parse any pattern metadata.
1479  ParsedPatternMetadata metadata;
1480  if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1481  return failure();
1482 
1483  // Parse the pattern body.
1484  ast::CompoundStmt *body;
1485 
1486  // Handle a lambda body.
1487  if (curToken.is(Token::equal_arrow)) {
1488  FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1489  if (failed(bodyResult))
1490  return failure();
1491  body = *bodyResult;
1492  } else {
1493  if (curToken.isNot(Token::l_brace))
1494  return emitError("expected `{` or `=>` to start pattern body");
1495  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1496  if (failed(bodyResult))
1497  return failure();
1498  body = *bodyResult;
1499 
1500  // Verify the body of the pattern.
1501  auto bodyIt = body->begin(), bodyE = body->end();
1502  for (; bodyIt != bodyE; ++bodyIt) {
1503  if (isa<ast::ReturnStmt>(*bodyIt)) {
1504  return emitError((*bodyIt)->getLoc(),
1505  "`return` statements are only permitted within a "
1506  "`Constraint` or `Rewrite` body");
1507  }
1508  // Break when we've found the rewrite statement.
1509  if (isa<ast::OpRewriteStmt>(*bodyIt))
1510  break;
1511  }
1512  if (bodyIt == bodyE) {
1513  return emitError(loc,
1514  "expected Pattern body to terminate with an operation "
1515  "rewrite statement, such as `erase`");
1516  }
1517  if (std::next(bodyIt) != bodyE) {
1518  return emitError((*std::next(bodyIt))->getLoc(),
1519  "Pattern body was terminated by an operation "
1520  "rewrite statement, but found trailing statements");
1521  }
1522  }
1523 
1524  return createPatternDecl(loc, name, metadata, body);
1525 }
1526 
1527 LogicalResult
1528 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1529  std::optional<SMRange> benefitLoc;
1530  std::optional<SMRange> hasBoundedRecursionLoc;
1531 
1532  do {
1533  // Handle metadata code completion.
1534  if (curToken.is(Token::code_complete))
1535  return codeCompletePatternMetadata();
1536 
1537  if (curToken.isNot(Token::identifier))
1538  return emitError("expected pattern metadata identifier");
1539  StringRef metadataStr = curToken.getSpelling();
1540  SMRange metadataLoc = curToken.getLoc();
1541  consumeToken(Token::identifier);
1542 
1543  // Parse the benefit metadata: benefit(<integer-value>)
1544  if (metadataStr == "benefit") {
1545  if (benefitLoc) {
1546  return emitErrorAndNote(metadataLoc,
1547  "pattern benefit has already been specified",
1548  *benefitLoc, "see previous definition here");
1549  }
1550  if (failed(parseToken(Token::l_paren,
1551  "expected `(` before pattern benefit")))
1552  return failure();
1553 
1554  uint16_t benefitValue = 0;
1555  if (curToken.isNot(Token::integer))
1556  return emitError("expected integral pattern benefit");
1557  if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1558  return emitError(
1559  "expected pattern benefit to fit within a 16-bit integer");
1560  consumeToken(Token::integer);
1561 
1562  metadata.benefit = benefitValue;
1563  benefitLoc = metadataLoc;
1564 
1565  if (failed(
1566  parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1567  return failure();
1568  continue;
1569  }
1570 
1571  // Parse the bounded recursion metadata: recursion
1572  if (metadataStr == "recursion") {
1573  if (hasBoundedRecursionLoc) {
1574  return emitErrorAndNote(
1575  metadataLoc,
1576  "pattern recursion metadata has already been specified",
1577  *hasBoundedRecursionLoc, "see previous definition here");
1578  }
1579  metadata.hasBoundedRecursion = true;
1580  hasBoundedRecursionLoc = metadataLoc;
1581  continue;
1582  }
1583 
1584  return emitError(metadataLoc, "unknown pattern metadata");
1585  } while (consumeIf(Token::comma));
1586 
1587  return success();
1588 }
1589 
1590 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1591  consumeToken(Token::less);
1592 
1593  FailureOr<ast::Expr *> typeExpr = parseExpr();
1594  if (failed(typeExpr) ||
1595  failed(parseToken(Token::greater,
1596  "expected `>` after variable type constraint")))
1597  return failure();
1598  return typeExpr;
1599 }
1600 
1601 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1602  assert(curDeclScope && "defining decl outside of a decl scope");
1603  if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1604  return emitErrorAndNote(
1605  name.getLoc(), "`" + name.getName() + "` has already been defined",
1606  lastDecl->getName()->getLoc(), "see previous definition here");
1607  }
1608  return success();
1609 }
1610 
1611 FailureOr<ast::VariableDecl *>
1612 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1613  ast::Expr *initExpr,
1614  ArrayRef<ast::ConstraintRef> constraints) {
1615  assert(curDeclScope && "defining variable outside of decl scope");
1616  const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1617 
1618  // If the name of the variable indicates a special variable, we don't add it
1619  // to the scope. This variable is local to the definition point.
1620  if (name.empty() || name == "_") {
1621  return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1622  constraints);
1623  }
1624  if (failed(checkDefineNamedDecl(nameDecl)))
1625  return failure();
1626 
1627  auto *varDecl =
1628  ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1629  curDeclScope->add(varDecl);
1630  return varDecl;
1631 }
1632 
1633 FailureOr<ast::VariableDecl *>
1634 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1635  ArrayRef<ast::ConstraintRef> constraints) {
1636  return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1637  constraints);
1638 }
1639 
1640 LogicalResult Parser::parseVariableDeclConstraintList(
1641  SmallVectorImpl<ast::ConstraintRef> &constraints) {
1642  std::optional<SMRange> typeConstraint;
1643  auto parseSingleConstraint = [&] {
1644  FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1645  typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
1646  if (failed(constraint))
1647  return failure();
1648  constraints.push_back(*constraint);
1649  return success();
1650  };
1651 
1652  // Check to see if this is a single constraint, or a list.
1653  if (!consumeIf(Token::l_square))
1654  return parseSingleConstraint();
1655 
1656  do {
1657  if (failed(parseSingleConstraint()))
1658  return failure();
1659  } while (consumeIf(Token::comma));
1660  return parseToken(Token::r_square, "expected `]` after constraint list");
1661 }
1662 
1663 FailureOr<ast::ConstraintRef>
1664 Parser::parseConstraint(std::optional<SMRange> &typeConstraint,
1665  ArrayRef<ast::ConstraintRef> existingConstraints,
1666  bool allowInlineTypeConstraints) {
1667  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1668  if (!allowInlineTypeConstraints) {
1669  return emitError(
1670  curToken.getLoc(),
1671  "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1672  "permitted on arguments or results");
1673  }
1674  if (typeConstraint)
1675  return emitErrorAndNote(
1676  curToken.getLoc(),
1677  "the type of this variable has already been constrained",
1678  *typeConstraint, "see previous constraint location here");
1679  FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1680  if (failed(constraintExpr))
1681  return failure();
1682  typeExpr = *constraintExpr;
1683  typeConstraint = typeExpr->getLoc();
1684  return success();
1685  };
1686 
1687  SMRange loc = curToken.getLoc();
1688  switch (curToken.getKind()) {
1689  case Token::kw_Attr: {
1690  consumeToken(Token::kw_Attr);
1691 
1692  // Check for a type constraint.
1693  ast::Expr *typeExpr = nullptr;
1694  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1695  return failure();
1696  return ast::ConstraintRef(
1697  ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1698  }
1699  case Token::kw_Op: {
1700  consumeToken(Token::kw_Op);
1701 
1702  // Parse an optional operation name. If the name isn't provided, this refers
1703  // to "any" operation.
1704  FailureOr<ast::OpNameDecl *> opName =
1705  parseWrappedOperationName(/*allowEmptyName=*/true);
1706  if (failed(opName))
1707  return failure();
1708 
1709  return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
1710  loc);
1711  }
1712  case Token::kw_Type:
1713  consumeToken(Token::kw_Type);
1714  return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1715  case Token::kw_TypeRange:
1716  consumeToken(Token::kw_TypeRange);
1718  loc);
1719  case Token::kw_Value: {
1720  consumeToken(Token::kw_Value);
1721 
1722  // Check for a type constraint.
1723  ast::Expr *typeExpr = nullptr;
1724  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1725  return failure();
1726 
1727  return ast::ConstraintRef(
1728  ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1729  }
1730  case Token::kw_ValueRange: {
1731  consumeToken(Token::kw_ValueRange);
1732 
1733  // Check for a type constraint.
1734  ast::Expr *typeExpr = nullptr;
1735  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1736  return failure();
1737 
1738  return ast::ConstraintRef(
1739  ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1740  }
1741 
1742  case Token::kw_Constraint: {
1743  // Handle an inline constraint.
1744  FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1745  if (failed(decl))
1746  return failure();
1747  return ast::ConstraintRef(*decl, loc);
1748  }
1749  case Token::identifier: {
1750  StringRef constraintName = curToken.getSpelling();
1751  consumeToken(Token::identifier);
1752 
1753  // Lookup the referenced constraint.
1754  ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
1755  if (!cstDecl) {
1756  return emitError(loc, "unknown reference to constraint `" +
1757  constraintName + "`");
1758  }
1759 
1760  // Handle a reference to a proper constraint.
1761  if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1762  return ast::ConstraintRef(cst, loc);
1763 
1764  return emitErrorAndNote(
1765  loc, "invalid reference to non-constraint", cstDecl->getLoc(),
1766  "see the definition of `" + constraintName + "` here");
1767  }
1768  // Handle single entity constraint code completion.
1769  case Token::code_complete: {
1770  // Try to infer the current type for use by code completion.
1771  ast::Type inferredType;
1772  if (failed(validateVariableConstraints(existingConstraints, inferredType)))
1773  return failure();
1774 
1775  return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
1776  }
1777  default:
1778  break;
1779  }
1780  return emitError(loc, "expected identifier constraint");
1781 }
1782 
1783 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1784  std::optional<SMRange> typeConstraint;
1785  return parseConstraint(typeConstraint, /*existingConstraints=*/std::nullopt,
1786  /*allowInlineTypeConstraints=*/false);
1787 }
1788 
1789 //===----------------------------------------------------------------------===//
1790 // Exprs
1791 //===----------------------------------------------------------------------===//
1792 
1793 FailureOr<ast::Expr *> Parser::parseExpr() {
1794  if (curToken.is(Token::underscore))
1795  return parseUnderscoreExpr();
1796 
1797  // Parse the LHS expression.
1798  FailureOr<ast::Expr *> lhsExpr;
1799  switch (curToken.getKind()) {
1800  case Token::kw_attr:
1801  lhsExpr = parseAttributeExpr();
1802  break;
1803  case Token::kw_Constraint:
1804  lhsExpr = parseInlineConstraintLambdaExpr();
1805  break;
1806  case Token::kw_not:
1807  lhsExpr = parseNegatedExpr();
1808  break;
1809  case Token::identifier:
1810  lhsExpr = parseIdentifierExpr();
1811  break;
1812  case Token::kw_op:
1813  lhsExpr = parseOperationExpr();
1814  break;
1815  case Token::kw_Rewrite:
1816  lhsExpr = parseInlineRewriteLambdaExpr();
1817  break;
1818  case Token::kw_type:
1819  lhsExpr = parseTypeExpr();
1820  break;
1821  case Token::l_paren:
1822  lhsExpr = parseTupleExpr();
1823  break;
1824  default:
1825  return emitError("expected expression");
1826  }
1827  if (failed(lhsExpr))
1828  return failure();
1829 
1830  // Check for an operator expression.
1831  while (true) {
1832  switch (curToken.getKind()) {
1833  case Token::dot:
1834  lhsExpr = parseMemberAccessExpr(*lhsExpr);
1835  break;
1836  case Token::l_paren:
1837  lhsExpr = parseCallExpr(*lhsExpr);
1838  break;
1839  default:
1840  return lhsExpr;
1841  }
1842  if (failed(lhsExpr))
1843  return failure();
1844  }
1845 }
1846 
1847 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1848  SMRange loc = curToken.getLoc();
1849  consumeToken(Token::kw_attr);
1850 
1851  // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1852  // identifier.
1853  if (!consumeIf(Token::less)) {
1854  resetToken(loc);
1855  return parseIdentifierExpr();
1856  }
1857 
1858  if (!curToken.isString())
1859  return emitError("expected string literal containing MLIR attribute");
1860  std::string attrExpr = curToken.getStringValue();
1861  consumeToken();
1862 
1863  loc.End = curToken.getEndLoc();
1864  if (failed(
1865  parseToken(Token::greater, "expected `>` after attribute literal")))
1866  return failure();
1867  return ast::AttributeExpr::create(ctx, loc, attrExpr);
1868 }
1869 
1870 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr,
1871  bool isNegated) {
1872  consumeToken(Token::l_paren);
1873 
1874  // Parse the arguments of the call.
1875  SmallVector<ast::Expr *> arguments;
1876  if (curToken.isNot(Token::r_paren)) {
1877  do {
1878  // Handle code completion for the call arguments.
1879  if (curToken.is(Token::code_complete)) {
1880  codeCompleteCallSignature(parentExpr, arguments.size());
1881  return failure();
1882  }
1883 
1884  FailureOr<ast::Expr *> argument = parseExpr();
1885  if (failed(argument))
1886  return failure();
1887  arguments.push_back(*argument);
1888  } while (consumeIf(Token::comma));
1889  }
1890 
1891  SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1892  if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1893  return failure();
1894 
1895  return createCallExpr(loc, parentExpr, arguments, isNegated);
1896 }
1897 
1898 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1899  ast::Decl *decl = curDeclScope->lookup(name);
1900  if (!decl)
1901  return emitError(loc, "undefined reference to `" + name + "`");
1902 
1903  return createDeclRefExpr(loc, decl);
1904 }
1905 
1906 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1907  StringRef name = curToken.getSpelling();
1908  SMRange nameLoc = curToken.getLoc();
1909  consumeToken();
1910 
1911  // Check to see if this is a decl ref expression that defines a variable
1912  // inline.
1913  if (consumeIf(Token::colon)) {
1914  SmallVector<ast::ConstraintRef> constraints;
1915  if (failed(parseVariableDeclConstraintList(constraints)))
1916  return failure();
1917  ast::Type type;
1918  if (failed(validateVariableConstraints(constraints, type)))
1919  return failure();
1920  return createInlineVariableExpr(type, name, nameLoc, constraints);
1921  }
1922 
1923  return parseDeclRefExpr(name, nameLoc);
1924 }
1925 
1926 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1927  FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1928  if (failed(decl))
1929  return failure();
1930 
1931  return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1933 }
1934 
1935 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1936  FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1937  if (failed(decl))
1938  return failure();
1939 
1940  return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1941  ast::RewriteType::get(ctx));
1942 }
1943 
1944 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1945  SMRange dotLoc = curToken.getLoc();
1946  consumeToken(Token::dot);
1947 
1948  // Check for code completion of the member name.
1949  if (curToken.is(Token::code_complete))
1950  return codeCompleteMemberAccess(parentExpr);
1951 
1952  // Parse the member name.
1953  Token memberNameTok = curToken;
1954  if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1955  !memberNameTok.isKeyword())
1956  return emitError(dotLoc, "expected identifier or numeric member name");
1957  StringRef memberName = memberNameTok.getSpelling();
1958  SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1959  consumeToken();
1960 
1961  return createMemberAccessExpr(parentExpr, memberName, loc);
1962 }
1963 
1964 FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
1965  consumeToken(Token::kw_not);
1966  // Only native constraints are supported after negation
1967  if (!curToken.is(Token::identifier))
1968  return emitError("expected native constraint");
1969  FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
1970  if (failed(identifierExpr))
1971  return failure();
1972  if (!curToken.is(Token::l_paren))
1973  return emitError("expected `(` after function name");
1974  return parseCallExpr(*identifierExpr, /*isNegated = */ true);
1975 }
1976 
1977 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
1978  SMRange loc = curToken.getLoc();
1979 
1980  // Check for code completion for the dialect name.
1981  if (curToken.is(Token::code_complete))
1982  return codeCompleteDialectName();
1983 
1984  // Handle the case of an no operation name.
1985  if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
1986  if (allowEmptyName)
1987  return ast::OpNameDecl::create(ctx, SMRange());
1988  return emitError("expected dialect namespace");
1989  }
1990  StringRef name = curToken.getSpelling();
1991  consumeToken();
1992 
1993  // Otherwise, this is a literal operation name.
1994  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
1995  return failure();
1996 
1997  // Check for code completion for the operation name.
1998  if (curToken.is(Token::code_complete))
1999  return codeCompleteOperationName(name);
2000 
2001  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
2002  return emitError("expected operation name after dialect namespace");
2003 
2004  name = StringRef(name.data(), name.size() + 1);
2005  do {
2006  name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
2007  loc.End = curToken.getEndLoc();
2008  consumeToken();
2009  } while (curToken.isAny(Token::identifier, Token::dot) ||
2010  curToken.isKeyword());
2011  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
2012 }
2013 
2014 FailureOr<ast::OpNameDecl *>
2015 Parser::parseWrappedOperationName(bool allowEmptyName) {
2016  if (!consumeIf(Token::less))
2017  return ast::OpNameDecl::create(ctx, SMRange());
2018 
2019  FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
2020  if (failed(opNameDecl))
2021  return failure();
2022 
2023  if (failed(parseToken(Token::greater, "expected `>` after operation name")))
2024  return failure();
2025  return opNameDecl;
2026 }
2027 
2028 FailureOr<ast::Expr *>
2029 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
2030  SMRange loc = curToken.getLoc();
2031  consumeToken(Token::kw_op);
2032 
2033  // If it isn't followed by a `<`, the `op` keyword is treated as a normal
2034  // identifier.
2035  if (curToken.isNot(Token::less)) {
2036  resetToken(loc);
2037  return parseIdentifierExpr();
2038  }
2039 
2040  // Parse the operation name. The name may be elided, in which case the
2041  // operation refers to "any" operation(i.e. a difference between `MyOp` and
2042  // `Operation*`). Operation names within a rewrite context must be named.
2043  bool allowEmptyName = parserContext != ParserContext::Rewrite;
2044  FailureOr<ast::OpNameDecl *> opNameDecl =
2045  parseWrappedOperationName(allowEmptyName);
2046  if (failed(opNameDecl))
2047  return failure();
2048  std::optional<StringRef> opName = (*opNameDecl)->getName();
2049 
2050  // Functor used to create an implicit range variable, used for implicit "all"
2051  // operand or results variables.
2052  auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
2053  FailureOr<ast::VariableDecl *> rangeVar =
2054  defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
2055  assert(succeeded(rangeVar) && "expected range variable to be valid");
2056  return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
2057  };
2058 
2059  // Check for the optional list of operands.
2060  SmallVector<ast::Expr *> operands;
2061  if (!consumeIf(Token::l_paren)) {
2062  // If the operand list isn't specified and we are in a match context, define
2063  // an inplace unconstrained operand range corresponding to all of the
2064  // operands of the operation. This avoids treating zero operands the same
2065  // way as "unconstrained operands".
2066  if (parserContext != ParserContext::Rewrite) {
2067  operands.push_back(createImplicitRangeVar(
2068  ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
2069  }
2070  } else if (!consumeIf(Token::r_paren)) {
2071  // If the operand list was specified and non-empty, parse the operands.
2072  do {
2073  // Check for operand signature code completion.
2074  if (curToken.is(Token::code_complete)) {
2075  codeCompleteOperationOperandsSignature(opName, operands.size());
2076  return failure();
2077  }
2078 
2079  FailureOr<ast::Expr *> operand = parseExpr();
2080  if (failed(operand))
2081  return failure();
2082  operands.push_back(*operand);
2083  } while (consumeIf(Token::comma));
2084 
2085  if (failed(parseToken(Token::r_paren,
2086  "expected `)` after operation operand list")))
2087  return failure();
2088  }
2089 
2090  // Check for the optional list of attributes.
2092  if (consumeIf(Token::l_brace)) {
2093  do {
2094  FailureOr<ast::NamedAttributeDecl *> decl =
2095  parseNamedAttributeDecl(opName);
2096  if (failed(decl))
2097  return failure();
2098  attributes.emplace_back(*decl);
2099  } while (consumeIf(Token::comma));
2100 
2101  if (failed(parseToken(Token::r_brace,
2102  "expected `}` after operation attribute list")))
2103  return failure();
2104  }
2105 
2106  // Handle the result types of the operation.
2107  SmallVector<ast::Expr *> resultTypes;
2108  OpResultTypeContext resultTypeContext = inputResultTypeContext;
2109 
2110  // Check for an explicit list of result types.
2111  if (consumeIf(Token::arrow)) {
2112  if (failed(parseToken(Token::l_paren,
2113  "expected `(` before operation result type list")))
2114  return failure();
2115 
2116  // If result types are provided, initially assume that the operation does
2117  // not rely on type inferrence. We don't assert that it isn't, because we
2118  // may be inferring the value of some type/type range variables, but given
2119  // that these variables may be defined in calls we can't always discern when
2120  // this is the case.
2121  resultTypeContext = OpResultTypeContext::Explicit;
2122 
2123  // Handle the case of an empty result list.
2124  if (!consumeIf(Token::r_paren)) {
2125  do {
2126  // Check for result signature code completion.
2127  if (curToken.is(Token::code_complete)) {
2128  codeCompleteOperationResultsSignature(opName, resultTypes.size());
2129  return failure();
2130  }
2131 
2132  FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2133  if (failed(resultTypeExpr))
2134  return failure();
2135  resultTypes.push_back(*resultTypeExpr);
2136  } while (consumeIf(Token::comma));
2137 
2138  if (failed(parseToken(Token::r_paren,
2139  "expected `)` after operation result type list")))
2140  return failure();
2141  }
2142  } else if (parserContext != ParserContext::Rewrite) {
2143  // If the result list isn't specified and we are in a match context, define
2144  // an inplace unconstrained result range corresponding to all of the results
2145  // of the operation. This avoids treating zero results the same way as
2146  // "unconstrained results".
2147  resultTypes.push_back(createImplicitRangeVar(
2148  ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
2149  } else if (resultTypeContext == OpResultTypeContext::Explicit) {
2150  // If the result list isn't specified and we are in a rewrite, try to infer
2151  // them at runtime instead.
2152  resultTypeContext = OpResultTypeContext::Interface;
2153  }
2154 
2155  return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2156  attributes, resultTypes);
2157 }
2158 
2159 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2160  SMRange loc = curToken.getLoc();
2161  consumeToken(Token::l_paren);
2162 
2163  DenseMap<StringRef, SMRange> usedNames;
2164  SmallVector<StringRef> elementNames;
2165  SmallVector<ast::Expr *> elements;
2166  if (curToken.isNot(Token::r_paren)) {
2167  do {
2168  // Check for the optional element name assignment before the value.
2169  StringRef elementName;
2170  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
2171  Token elementNameTok = curToken;
2172  consumeToken();
2173 
2174  // The element name is only present if followed by an `=`.
2175  if (consumeIf(Token::equal)) {
2176  elementName = elementNameTok.getSpelling();
2177 
2178  // Check to see if this name is already used.
2179  auto elementNameIt =
2180  usedNames.try_emplace(elementName, elementNameTok.getLoc());
2181  if (!elementNameIt.second) {
2182  return emitErrorAndNote(
2183  elementNameTok.getLoc(),
2184  llvm::formatv("duplicate tuple element label `{0}`",
2185  elementName),
2186  elementNameIt.first->getSecond(),
2187  "see previous label use here");
2188  }
2189  } else {
2190  // Otherwise, we treat this as part of an expression so reset the
2191  // lexer.
2192  resetToken(elementNameTok.getLoc());
2193  }
2194  }
2195  elementNames.push_back(elementName);
2196 
2197  // Parse the tuple element value.
2198  FailureOr<ast::Expr *> element = parseExpr();
2199  if (failed(element))
2200  return failure();
2201  elements.push_back(*element);
2202  } while (consumeIf(Token::comma));
2203  }
2204  loc.End = curToken.getEndLoc();
2205  if (failed(
2206  parseToken(Token::r_paren, "expected `)` after tuple element list")))
2207  return failure();
2208  return createTupleExpr(loc, elements, elementNames);
2209 }
2210 
2211 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2212  SMRange loc = curToken.getLoc();
2213  consumeToken(Token::kw_type);
2214 
2215  // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2216  // identifier.
2217  if (!consumeIf(Token::less)) {
2218  resetToken(loc);
2219  return parseIdentifierExpr();
2220  }
2221 
2222  if (!curToken.isString())
2223  return emitError("expected string literal containing MLIR type");
2224  std::string attrExpr = curToken.getStringValue();
2225  consumeToken();
2226 
2227  loc.End = curToken.getEndLoc();
2228  if (failed(parseToken(Token::greater, "expected `>` after type literal")))
2229  return failure();
2230  return ast::TypeExpr::create(ctx, loc, attrExpr);
2231 }
2232 
2233 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2234  StringRef name = curToken.getSpelling();
2235  SMRange nameLoc = curToken.getLoc();
2236  consumeToken(Token::underscore);
2237 
2238  // Underscore expressions require a constraint list.
2239  if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2240  return failure();
2241 
2242  // Parse the constraints for the expression.
2243  SmallVector<ast::ConstraintRef> constraints;
2244  if (failed(parseVariableDeclConstraintList(constraints)))
2245  return failure();
2246 
2247  ast::Type type;
2248  if (failed(validateVariableConstraints(constraints, type)))
2249  return failure();
2250  return createInlineVariableExpr(type, name, nameLoc, constraints);
2251 }
2252 
2253 //===----------------------------------------------------------------------===//
2254 // Stmts
2255 //===----------------------------------------------------------------------===//
2256 
2257 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2258  FailureOr<ast::Stmt *> stmt;
2259  switch (curToken.getKind()) {
2260  case Token::kw_erase:
2261  stmt = parseEraseStmt();
2262  break;
2263  case Token::kw_let:
2264  stmt = parseLetStmt();
2265  break;
2266  case Token::kw_replace:
2267  stmt = parseReplaceStmt();
2268  break;
2269  case Token::kw_return:
2270  stmt = parseReturnStmt();
2271  break;
2272  case Token::kw_rewrite:
2273  stmt = parseRewriteStmt();
2274  break;
2275  default:
2276  stmt = parseExpr();
2277  break;
2278  }
2279  if (failed(stmt) ||
2280  (expectTerminalSemicolon &&
2281  failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2282  return failure();
2283  return stmt;
2284 }
2285 
2286 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2287  SMLoc startLoc = curToken.getStartLoc();
2288  consumeToken(Token::l_brace);
2289 
2290  // Push a new block scope and parse any nested statements.
2291  pushDeclScope();
2292  SmallVector<ast::Stmt *> statements;
2293  while (curToken.isNot(Token::r_brace)) {
2294  FailureOr<ast::Stmt *> statement = parseStmt();
2295  if (failed(statement))
2296  return popDeclScope(), failure();
2297  statements.push_back(*statement);
2298  }
2299  popDeclScope();
2300 
2301  // Consume the end brace.
2302  SMRange location(startLoc, curToken.getEndLoc());
2303  consumeToken(Token::r_brace);
2304 
2305  return ast::CompoundStmt::create(ctx, location, statements);
2306 }
2307 
2308 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2309  if (parserContext == ParserContext::Constraint)
2310  return emitError("`erase` cannot be used within a Constraint");
2311  SMRange loc = curToken.getLoc();
2312  consumeToken(Token::kw_erase);
2313 
2314  // Parse the root operation expression.
2315  FailureOr<ast::Expr *> rootOp = parseExpr();
2316  if (failed(rootOp))
2317  return failure();
2318 
2319  return createEraseStmt(loc, *rootOp);
2320 }
2321 
2322 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2323  SMRange loc = curToken.getLoc();
2324  consumeToken(Token::kw_let);
2325 
2326  // Parse the name of the new variable.
2327  SMRange varLoc = curToken.getLoc();
2328  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
2329  // `_` is a reserved variable name.
2330  if (curToken.is(Token::underscore)) {
2331  return emitError(varLoc,
2332  "`_` may only be used to define \"inline\" variables");
2333  }
2334  return emitError(varLoc,
2335  "expected identifier after `let` to name a new variable");
2336  }
2337  StringRef varName = curToken.getSpelling();
2338  consumeToken();
2339 
2340  // Parse the optional set of constraints.
2341  SmallVector<ast::ConstraintRef> constraints;
2342  if (consumeIf(Token::colon) &&
2343  failed(parseVariableDeclConstraintList(constraints)))
2344  return failure();
2345 
2346  // Parse the optional initializer expression.
2347  ast::Expr *initializer = nullptr;
2348  if (consumeIf(Token::equal)) {
2349  FailureOr<ast::Expr *> initOrFailure = parseExpr();
2350  if (failed(initOrFailure))
2351  return failure();
2352  initializer = *initOrFailure;
2353 
2354  // Check that the constraints are compatible with having an initializer,
2355  // e.g. type constraints cannot be used with initializers.
2356  for (ast::ConstraintRef constraint : constraints) {
2357  LogicalResult result =
2358  TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
2360  ast::ValueRangeConstraintDecl>([&](const auto *cst) {
2361  if (cst->getTypeExpr()) {
2362  return this->emitError(
2363  constraint.referenceLoc,
2364  "type constraints are not permitted on variables with "
2365  "initializers");
2366  }
2367  return success();
2368  })
2369  .Default(success());
2370  if (failed(result))
2371  return failure();
2372  }
2373  }
2374 
2375  FailureOr<ast::VariableDecl *> varDecl =
2376  createVariableDecl(varName, varLoc, initializer, constraints);
2377  if (failed(varDecl))
2378  return failure();
2379  return ast::LetStmt::create(ctx, loc, *varDecl);
2380 }
2381 
2382 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2383  if (parserContext == ParserContext::Constraint)
2384  return emitError("`replace` cannot be used within a Constraint");
2385  SMRange loc = curToken.getLoc();
2386  consumeToken(Token::kw_replace);
2387 
2388  // Parse the root operation expression.
2389  FailureOr<ast::Expr *> rootOp = parseExpr();
2390  if (failed(rootOp))
2391  return failure();
2392 
2393  if (failed(
2394  parseToken(Token::kw_with, "expected `with` after root operation")))
2395  return failure();
2396 
2397  // The replacement portion of this statement is within a rewrite context.
2398  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2399 
2400  // Parse the replacement values.
2401  SmallVector<ast::Expr *> replValues;
2402  if (consumeIf(Token::l_paren)) {
2403  if (consumeIf(Token::r_paren)) {
2404  return emitError(
2405  loc, "expected at least one replacement value, consider using "
2406  "`erase` if no replacement values are desired");
2407  }
2408 
2409  do {
2410  FailureOr<ast::Expr *> replExpr = parseExpr();
2411  if (failed(replExpr))
2412  return failure();
2413  replValues.emplace_back(*replExpr);
2414  } while (consumeIf(Token::comma));
2415 
2416  if (failed(parseToken(Token::r_paren,
2417  "expected `)` after replacement values")))
2418  return failure();
2419  } else {
2420  // Handle replacement with an operation uniquely, as the replacement
2421  // operation supports type inferrence from the root operation.
2422  FailureOr<ast::Expr *> replExpr;
2423  if (curToken.is(Token::kw_op))
2424  replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2425  else
2426  replExpr = parseExpr();
2427  if (failed(replExpr))
2428  return failure();
2429  replValues.emplace_back(*replExpr);
2430  }
2431 
2432  return createReplaceStmt(loc, *rootOp, replValues);
2433 }
2434 
2435 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2436  SMRange loc = curToken.getLoc();
2437  consumeToken(Token::kw_return);
2438 
2439  // Parse the result value.
2440  FailureOr<ast::Expr *> resultExpr = parseExpr();
2441  if (failed(resultExpr))
2442  return failure();
2443 
2444  return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2445 }
2446 
2447 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2448  if (parserContext == ParserContext::Constraint)
2449  return emitError("`rewrite` cannot be used within a Constraint");
2450  SMRange loc = curToken.getLoc();
2451  consumeToken(Token::kw_rewrite);
2452 
2453  // Parse the root operation.
2454  FailureOr<ast::Expr *> rootOp = parseExpr();
2455  if (failed(rootOp))
2456  return failure();
2457 
2458  if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2459  return failure();
2460 
2461  if (curToken.isNot(Token::l_brace))
2462  return emitError("expected `{` to start rewrite body");
2463 
2464  // The rewrite body of this statement is within a rewrite context.
2465  llvm::SaveAndRestore saveCtx(parserContext, ParserContext::Rewrite);
2466 
2467  FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2468  if (failed(rewriteBody))
2469  return failure();
2470 
2471  // Verify the rewrite body.
2472  for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2473  if (isa<ast::ReturnStmt>(stmt)) {
2474  return emitError(stmt->getLoc(),
2475  "`return` statements are only permitted within a "
2476  "`Constraint` or `Rewrite` body");
2477  }
2478  }
2479 
2480  return createRewriteStmt(loc, *rootOp, *rewriteBody);
2481 }
2482 
2483 //===----------------------------------------------------------------------===//
2484 // Creation+Analysis
2485 //===----------------------------------------------------------------------===//
2486 
2487 //===----------------------------------------------------------------------===//
2488 // Decls
2489 //===----------------------------------------------------------------------===//
2490 
2491 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2492  // Unwrap reference expressions.
2493  if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2494  node = init->getDecl();
2495  return dyn_cast<ast::CallableDecl>(node);
2496 }
2497 
2498 FailureOr<ast::PatternDecl *>
2499 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2500  const ParsedPatternMetadata &metadata,
2501  ast::CompoundStmt *body) {
2502  return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2503  metadata.hasBoundedRecursion, body);
2504 }
2505 
2506 ast::Type Parser::createUserConstraintRewriteResultType(
2508  // Single result decls use the type of the single result.
2509  if (results.size() == 1)
2510  return results[0]->getType();
2511 
2512  // Multiple results use a tuple type, with the types and names grabbed from
2513  // the result variable decls.
2514  auto resultTypes = llvm::map_range(
2515  results, [&](const auto *result) { return result->getType(); });
2516  auto resultNames = llvm::map_range(
2517  results, [&](const auto *result) { return result->getName().getName(); });
2518  return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2519  llvm::to_vector(resultNames));
2520 }
2521 
2522 template <typename T>
2523 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2524  const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2525  ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2526  ast::CompoundStmt *body) {
2527  if (!body->getChildren().empty()) {
2528  if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2529  ast::Expr *resultExpr = retStmt->getResultExpr();
2530 
2531  // Process the result of the decl. If no explicit signature results
2532  // were provided, check for return type inference. Otherwise, check that
2533  // the return expression can be converted to the expected type.
2534  if (results.empty())
2535  resultType = resultExpr->getType();
2536  else if (failed(convertExpressionTo(resultExpr, resultType)))
2537  return failure();
2538  else
2539  retStmt->setResultExpr(resultExpr);
2540  }
2541  }
2542  return T::createPDLL(ctx, name, arguments, results, body, resultType);
2543 }
2544 
2545 FailureOr<ast::VariableDecl *>
2546 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2547  ArrayRef<ast::ConstraintRef> constraints) {
2548  // The type of the variable, which is expected to be inferred by either a
2549  // constraint or an initializer expression.
2550  ast::Type type;
2551  if (failed(validateVariableConstraints(constraints, type)))
2552  return failure();
2553 
2554  if (initializer) {
2555  // Update the variable type based on the initializer, or try to convert the
2556  // initializer to the existing type.
2557  if (!type)
2558  type = initializer->getType();
2559  else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2560  type = mergedType;
2561  else if (failed(convertExpressionTo(initializer, type)))
2562  return failure();
2563 
2564  // Otherwise, if there is no initializer check that the type has already
2565  // been resolved from the constraint list.
2566  } else if (!type) {
2567  return emitErrorAndNote(
2568  loc, "unable to infer type for variable `" + name + "`", loc,
2569  "the type of a variable must be inferable from the constraint "
2570  "list or the initializer");
2571  }
2572 
2573  // Constraint types cannot be used when defining variables.
2574  if (isa<ast::ConstraintType, ast::RewriteType>(type)) {
2575  return emitError(
2576  loc, llvm::formatv("unable to define variable of `{0}` type", type));
2577  }
2578 
2579  // Try to define a variable with the given name.
2580  FailureOr<ast::VariableDecl *> varDecl =
2581  defineVariableDecl(name, loc, type, initializer, constraints);
2582  if (failed(varDecl))
2583  return failure();
2584 
2585  return *varDecl;
2586 }
2587 
2588 FailureOr<ast::VariableDecl *>
2589 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2590  const ast::ConstraintRef &constraint) {
2591  ast::Type argType;
2592  if (failed(validateVariableConstraint(constraint, argType)))
2593  return failure();
2594  return defineVariableDecl(name, loc, argType, constraint);
2595 }
2596 
2597 LogicalResult
2598 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2599  ast::Type &inferredType) {
2600  for (const ast::ConstraintRef &ref : constraints)
2601  if (failed(validateVariableConstraint(ref, inferredType)))
2602  return failure();
2603  return success();
2604 }
2605 
2606 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2607  ast::Type &inferredType) {
2608  ast::Type constraintType;
2609  if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
2610  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2611  if (failed(validateTypeConstraintExpr(typeExpr)))
2612  return failure();
2613  }
2614  constraintType = ast::AttributeType::get(ctx);
2615  } else if (const auto *cst =
2616  dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
2617  constraintType = ast::OperationType::get(
2618  ctx, cst->getName(), lookupODSOperation(cst->getName()));
2619  } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
2620  constraintType = typeTy;
2621  } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
2622  constraintType = typeRangeTy;
2623  } else if (const auto *cst =
2624  dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
2625  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2626  if (failed(validateTypeConstraintExpr(typeExpr)))
2627  return failure();
2628  }
2629  constraintType = valueTy;
2630  } else if (const auto *cst =
2631  dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
2632  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2633  if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2634  return failure();
2635  }
2636  constraintType = valueRangeTy;
2637  } else if (const auto *cst =
2638  dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
2639  ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2640  if (inputs.size() != 1) {
2641  return emitErrorAndNote(ref.referenceLoc,
2642  "`Constraint`s applied via a variable constraint "
2643  "list must take a single input, but got " +
2644  Twine(inputs.size()),
2645  cst->getLoc(),
2646  "see definition of constraint here");
2647  }
2648  constraintType = inputs.front()->getType();
2649  } else {
2650  llvm_unreachable("unknown constraint type");
2651  }
2652 
2653  // Check that the constraint type is compatible with the current inferred
2654  // type.
2655  if (!inferredType) {
2656  inferredType = constraintType;
2657  } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
2658  inferredType = mergedTy;
2659  } else {
2660  return emitError(ref.referenceLoc,
2661  llvm::formatv("constraint type `{0}` is incompatible "
2662  "with the previously inferred type `{1}`",
2663  constraintType, inferredType));
2664  }
2665  return success();
2666 }
2667 
2668 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2669  ast::Type typeExprType = typeExpr->getType();
2670  if (typeExprType != typeTy) {
2671  return emitError(typeExpr->getLoc(),
2672  "expected expression of `Type` in type constraint");
2673  }
2674  return success();
2675 }
2676 
2677 LogicalResult
2678 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2679  ast::Type typeExprType = typeExpr->getType();
2680  if (typeExprType != typeRangeTy) {
2681  return emitError(typeExpr->getLoc(),
2682  "expected expression of `TypeRange` in type constraint");
2683  }
2684  return success();
2685 }
2686 
2687 //===----------------------------------------------------------------------===//
2688 // Exprs
2689 //===----------------------------------------------------------------------===//
2690 
2691 FailureOr<ast::CallExpr *>
2692 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2693  MutableArrayRef<ast::Expr *> arguments, bool isNegated) {
2694  ast::Type parentType = parentExpr->getType();
2695 
2696  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2697  if (!callableDecl) {
2698  return emitError(loc,
2699  llvm::formatv("expected a reference to a callable "
2700  "`Constraint` or `Rewrite`, but got: `{0}`",
2701  parentType));
2702  }
2703  if (parserContext == ParserContext::Rewrite) {
2704  if (isa<ast::UserConstraintDecl>(callableDecl))
2705  return emitError(
2706  loc, "unable to invoke `Constraint` within a rewrite section");
2707  if (isNegated)
2708  return emitError(loc, "unable to negate a Rewrite");
2709  } else {
2710  if (isa<ast::UserRewriteDecl>(callableDecl))
2711  return emitError(loc,
2712  "unable to invoke `Rewrite` within a match section");
2713  if (isNegated && cast<ast::UserConstraintDecl>(callableDecl)->getBody())
2714  return emitError(loc, "unable to negate non native constraints");
2715  }
2716 
2717  // Verify the arguments of the call.
2718  /// Handle size mismatch.
2719  ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2720  if (callArgs.size() != arguments.size()) {
2721  return emitErrorAndNote(
2722  loc,
2723  llvm::formatv("invalid number of arguments for {0} call; expected "
2724  "{1}, but got {2}",
2725  callableDecl->getCallableType(), callArgs.size(),
2726  arguments.size()),
2727  callableDecl->getLoc(),
2728  llvm::formatv("see the definition of {0} here",
2729  callableDecl->getName()->getName()));
2730  }
2731 
2732  /// Handle argument type mismatch.
2733  auto attachDiagFn = [&](ast::Diagnostic &diag) {
2734  diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2735  callableDecl->getName()->getName()),
2736  callableDecl->getLoc());
2737  };
2738  for (auto it : llvm::zip(callArgs, arguments)) {
2739  if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2740  attachDiagFn)))
2741  return failure();
2742  }
2743 
2744  return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2745  callableDecl->getResultType(), isNegated);
2746 }
2747 
2748 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2749  ast::Decl *decl) {
2750  // Check the type of decl being referenced.
2751  ast::Type declType;
2752  if (isa<ast::ConstraintDecl>(decl))
2753  declType = ast::ConstraintType::get(ctx);
2754  else if (isa<ast::UserRewriteDecl>(decl))
2755  declType = ast::RewriteType::get(ctx);
2756  else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2757  declType = varDecl->getType();
2758  else
2759  return emitError(loc, "invalid reference to `" +
2760  decl->getName()->getName() + "`");
2761 
2762  return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2763 }
2764 
2765 FailureOr<ast::DeclRefExpr *>
2766 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2767  ArrayRef<ast::ConstraintRef> constraints) {
2768  FailureOr<ast::VariableDecl *> decl =
2769  defineVariableDecl(name, loc, type, constraints);
2770  if (failed(decl))
2771  return failure();
2772  return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2773 }
2774 
2775 FailureOr<ast::MemberAccessExpr *>
2776 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2777  SMRange loc) {
2778  // Validate the member name for the given parent expression.
2779  FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2780  if (failed(memberType))
2781  return failure();
2782 
2783  return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2784 }
2785 
2786 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2787  StringRef name, SMRange loc) {
2788  ast::Type parentType = parentExpr->getType();
2789  if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) {
2791  return valueRangeTy;
2792 
2793  // Verify member access based on the operation type.
2794  if (const ods::Operation *odsOp = opType.getODSOperation()) {
2795  auto results = odsOp->getResults();
2796 
2797  // Handle indexed results.
2798  unsigned index = 0;
2799  if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2800  index < results.size()) {
2801  return results[index].isVariadic() ? valueRangeTy : valueTy;
2802  }
2803 
2804  // Handle named results.
2805  const auto *it = llvm::find_if(results, [&](const auto &result) {
2806  return result.getName() == name;
2807  });
2808  if (it != results.end())
2809  return it->isVariadic() ? valueRangeTy : valueTy;
2810  } else if (llvm::isDigit(name[0])) {
2811  // Allow unchecked numeric indexing of the results of unregistered
2812  // operations. It returns a single value.
2813  return valueTy;
2814  }
2815  } else if (auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
2816  // Handle indexed results.
2817  unsigned index = 0;
2818  if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2819  index < tupleType.size()) {
2820  return tupleType.getElementTypes()[index];
2821  }
2822 
2823  // Handle named results.
2824  auto elementNames = tupleType.getElementNames();
2825  const auto *it = llvm::find(elementNames, name);
2826  if (it != elementNames.end())
2827  return tupleType.getElementTypes()[it - elementNames.begin()];
2828  }
2829  return emitError(
2830  loc,
2831  llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2832  name, parentType));
2833 }
2834 
2835 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2836  SMRange loc, const ast::OpNameDecl *name,
2837  OpResultTypeContext resultTypeContext,
2838  SmallVectorImpl<ast::Expr *> &operands,
2840  SmallVectorImpl<ast::Expr *> &results) {
2841  std::optional<StringRef> opNameRef = name->getName();
2842  const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2843 
2844  // Verify the inputs operands.
2845  if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2846  return failure();
2847 
2848  // Verify the attribute list.
2849  for (ast::NamedAttributeDecl *attr : attributes) {
2850  // Check for an attribute type, or a type awaiting resolution.
2851  ast::Type attrType = attr->getValue()->getType();
2852  if (!isa<ast::AttributeType>(attrType)) {
2853  return emitError(
2854  attr->getValue()->getLoc(),
2855  llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2856  }
2857  }
2858 
2859  assert(
2860  (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2861  "unexpected inferrence when results were explicitly specified");
2862 
2863  // If we aren't relying on type inferrence, or explicit results were provided,
2864  // validate them.
2865  if (resultTypeContext == OpResultTypeContext::Explicit) {
2866  if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2867  return failure();
2868 
2869  // Validate the use of interface based type inferrence for this operation.
2870  } else if (resultTypeContext == OpResultTypeContext::Interface) {
2871  assert(opNameRef &&
2872  "expected valid operation name when inferring operation results");
2873  checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2874  }
2875 
2876  return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
2877  attributes);
2878 }
2879 
2880 LogicalResult
2881 Parser::validateOperationOperands(SMRange loc, std::optional<StringRef> name,
2882  const ods::Operation *odsOp,
2883  SmallVectorImpl<ast::Expr *> &operands) {
2884  return validateOperationOperandsOrResults(
2885  "operand", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2886  operands, odsOp ? odsOp->getOperands() : std::nullopt, valueTy,
2887  valueRangeTy);
2888 }
2889 
2890 LogicalResult
2891 Parser::validateOperationResults(SMRange loc, std::optional<StringRef> name,
2892  const ods::Operation *odsOp,
2893  SmallVectorImpl<ast::Expr *> &results) {
2894  return validateOperationOperandsOrResults(
2895  "result", loc, odsOp ? odsOp->getLoc() : std::optional<SMRange>(), name,
2896  results, odsOp ? odsOp->getResults() : std::nullopt, typeTy, typeRangeTy);
2897 }
2898 
2899 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2900  const ods::Operation *odsOp) {
2901  // If the operation might not have inferrence support, emit a warning to the
2902  // user. We don't emit an error because the interface might be added to the
2903  // operation at runtime. It's rare, but it could still happen. We emit a
2904  // warning here instead.
2905 
2906  // Handle inferrence warnings for unknown operations.
2907  if (!odsOp) {
2908  ctx.getDiagEngine().emitWarning(
2909  loc, llvm::formatv(
2910  "operation result types are marked to be inferred, but "
2911  "`{0}` is unknown. Ensure that `{0}` supports zero "
2912  "results or implements `InferTypeOpInterface`. Include "
2913  "the ODS definition of this operation to remove this warning.",
2914  opName));
2915  return;
2916  }
2917 
2918  // Handle inferrence warnings for known operations that expected at least one
2919  // result, but don't have inference support. An elided results list can mean
2920  // "zero-results", and we don't want to warn when that is the expected
2921  // behavior.
2922  bool requiresInferrence =
2923  llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
2924  return !result.isVariableLength();
2925  });
2926  if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2928  loc,
2929  llvm::formatv("operation result types are marked to be inferred, but "
2930  "`{0}` does not provide an implementation of "
2931  "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2932  "`InferTypeOpInterface` at runtime, or add support to "
2933  "the ODS definition to remove this warning.",
2934  opName));
2935  diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
2936  odsOp->getLoc());
2937  return;
2938  }
2939 }
2940 
2941 LogicalResult Parser::validateOperationOperandsOrResults(
2942  StringRef groupName, SMRange loc, std::optional<SMRange> odsOpLoc,
2943  std::optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
2944  ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2945  ast::RangeType rangeTy) {
2946  // All operation types accept a single range parameter.
2947  if (values.size() == 1) {
2948  if (failed(convertExpressionTo(values[0], rangeTy)))
2949  return failure();
2950  return success();
2951  }
2952 
2953  /// If the operation has ODS information, we can more accurately verify the
2954  /// values.
2955  if (odsOpLoc) {
2956  auto emitSizeMismatchError = [&] {
2957  return emitErrorAndNote(
2958  loc,
2959  llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2960  "{2}, but got {3}",
2961  groupName, *name, odsValues.size(), values.size()),
2962  *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
2963  };
2964 
2965  // Handle the case where no values were provided.
2966  if (values.empty()) {
2967  // If we don't expect any on the ODS side, we are done.
2968  if (odsValues.empty())
2969  return success();
2970 
2971  // If we do, check if we actually need to provide values (i.e. if any of
2972  // the values are actually required).
2973  unsigned numVariadic = 0;
2974  for (const auto &odsValue : odsValues) {
2975  if (!odsValue.isVariableLength())
2976  return emitSizeMismatchError();
2977  ++numVariadic;
2978  }
2979 
2980  // If we are in a non-rewrite context, we don't need to do anything more.
2981  // Zero-values is a valid constraint on the operation.
2982  if (parserContext != ParserContext::Rewrite)
2983  return success();
2984 
2985  // Otherwise, when in a rewrite we may need to provide values to match the
2986  // ODS signature of the operation to create.
2987 
2988  // If we only have one variadic value, just use an empty list.
2989  if (numVariadic == 1)
2990  return success();
2991 
2992  // Otherwise, create dummy values for each of the entries so that we
2993  // adhere to the ODS signature.
2994  for (unsigned i = 0, e = odsValues.size(); i < e; ++i) {
2995  values.push_back(ast::RangeExpr::create(
2996  ctx, loc, /*elements=*/std::nullopt, rangeTy));
2997  }
2998  return success();
2999  }
3000 
3001  // Verify that the number of values provided matches the number of value
3002  // groups ODS expects.
3003  if (odsValues.size() != values.size())
3004  return emitSizeMismatchError();
3005 
3006  auto diagFn = [&](ast::Diagnostic &diag) {
3007  diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
3008  *odsOpLoc);
3009  };
3010  for (unsigned i = 0, e = values.size(); i < e; ++i) {
3011  ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
3012  if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
3013  return failure();
3014  }
3015  return success();
3016  }
3017 
3018  // Otherwise, accept the value groups as they have been defined and just
3019  // ensure they are one of the expected types.
3020  for (ast::Expr *&valueExpr : values) {
3021  ast::Type valueExprType = valueExpr->getType();
3022 
3023  // Check if this is one of the expected types.
3024  if (valueExprType == rangeTy || valueExprType == singleTy)
3025  continue;
3026 
3027  // If the operand is an Operation, allow converting to a Value or
3028  // ValueRange. This situations arises quite often with nested operation
3029  // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
3030  if (singleTy == valueTy) {
3031  if (isa<ast::OperationType>(valueExprType)) {
3032  valueExpr = convertOpToValue(valueExpr);
3033  continue;
3034  }
3035  }
3036 
3037  // Otherwise, try to convert the expression to a range.
3038  if (succeeded(convertExpressionTo(valueExpr, rangeTy)))
3039  continue;
3040 
3041  return emitError(
3042  valueExpr->getLoc(),
3043  llvm::formatv(
3044  "expected `{0}` or `{1}` convertible expression, but got `{2}`",
3045  singleTy, rangeTy, valueExprType));
3046  }
3047  return success();
3048 }
3049 
3050 FailureOr<ast::TupleExpr *>
3051 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
3052  ArrayRef<StringRef> elementNames) {
3053  for (const ast::Expr *element : elements) {
3054  ast::Type eleTy = element->getType();
3055  if (isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>(eleTy)) {
3056  return emitError(
3057  element->getLoc(),
3058  llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
3059  }
3060  }
3061  return ast::TupleExpr::create(ctx, loc, elements, elementNames);
3062 }
3063 
3064 //===----------------------------------------------------------------------===//
3065 // Stmts
3066 //===----------------------------------------------------------------------===//
3067 
3068 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
3069  ast::Expr *rootOp) {
3070  // Check that root is an Operation.
3071  ast::Type rootType = rootOp->getType();
3072  if (!isa<ast::OperationType>(rootType))
3073  return emitError(rootOp->getLoc(), "expected `Op` expression");
3074 
3075  return ast::EraseStmt::create(ctx, loc, rootOp);
3076 }
3077 
3078 FailureOr<ast::ReplaceStmt *>
3079 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
3080  MutableArrayRef<ast::Expr *> replValues) {
3081  // Check that root is an Operation.
3082  ast::Type rootType = rootOp->getType();
3083  if (!isa<ast::OperationType>(rootType)) {
3084  return emitError(
3085  rootOp->getLoc(),
3086  llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3087  }
3088 
3089  // If there are multiple replacement values, we implicitly convert any Op
3090  // expressions to the value form.
3091  bool shouldConvertOpToValues = replValues.size() > 1;
3092  for (ast::Expr *&replExpr : replValues) {
3093  ast::Type replType = replExpr->getType();
3094 
3095  // Check that replExpr is an Operation, Value, or ValueRange.
3096  if (isa<ast::OperationType>(replType)) {
3097  if (shouldConvertOpToValues)
3098  replExpr = convertOpToValue(replExpr);
3099  continue;
3100  }
3101 
3102  if (replType != valueTy && replType != valueRangeTy) {
3103  return emitError(replExpr->getLoc(),
3104  llvm::formatv("expected `Op`, `Value` or `ValueRange` "
3105  "expression, but got `{0}`",
3106  replType));
3107  }
3108  }
3109 
3110  return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
3111 }
3112 
3113 FailureOr<ast::RewriteStmt *>
3114 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
3115  ast::CompoundStmt *rewriteBody) {
3116  // Check that root is an Operation.
3117  ast::Type rootType = rootOp->getType();
3118  if (!isa<ast::OperationType>(rootType)) {
3119  return emitError(
3120  rootOp->getLoc(),
3121  llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3122  }
3123 
3124  return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
3125 }
3126 
3127 //===----------------------------------------------------------------------===//
3128 // Code Completion
3129 //===----------------------------------------------------------------------===//
3130 
3131 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3132  ast::Type parentType = parentExpr->getType();
3133  if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType))
3134  codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3135  else if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(parentType))
3136  codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3137  return failure();
3138 }
3139 
3140 LogicalResult
3141 Parser::codeCompleteAttributeName(std::optional<StringRef> opName) {
3142  if (opName)
3143  codeCompleteContext->codeCompleteOperationAttributeName(*opName);
3144  return failure();
3145 }
3146 
3147 LogicalResult
3148 Parser::codeCompleteConstraintName(ast::Type inferredType,
3149  bool allowInlineTypeConstraints) {
3150  codeCompleteContext->codeCompleteConstraintName(
3151  inferredType, allowInlineTypeConstraints, curDeclScope);
3152  return failure();
3153 }
3154 
3155 LogicalResult Parser::codeCompleteDialectName() {
3156  codeCompleteContext->codeCompleteDialectName();
3157  return failure();
3158 }
3159 
3160 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3161  codeCompleteContext->codeCompleteOperationName(dialectName);
3162  return failure();
3163 }
3164 
3165 LogicalResult Parser::codeCompletePatternMetadata() {
3166  codeCompleteContext->codeCompletePatternMetadata();
3167  return failure();
3168 }
3169 
3170 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3171  codeCompleteContext->codeCompleteIncludeFilename(curPath);
3172  return failure();
3173 }
3174 
3175 void Parser::codeCompleteCallSignature(ast::Node *parent,
3176  unsigned currentNumArgs) {
3177  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent);
3178  if (!callableDecl)
3179  return;
3180 
3181  codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
3182 }
3183 
3184 void Parser::codeCompleteOperationOperandsSignature(
3185  std::optional<StringRef> opName, unsigned currentNumOperands) {
3186  codeCompleteContext->codeCompleteOperationOperandsSignature(
3187  opName, currentNumOperands);
3188 }
3189 
3190 void Parser::codeCompleteOperationResultsSignature(
3191  std::optional<StringRef> opName, unsigned currentNumResults) {
3192  codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3193  currentNumResults);
3194 }
3195 
3196 //===----------------------------------------------------------------------===//
3197 // Parser
3198 //===----------------------------------------------------------------------===//
3199 
3200 FailureOr<ast::Module *>
3201 mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3202  bool enableDocumentation,
3203  CodeCompleteContext *codeCompleteContext) {
3204  Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3205  return parser.parseModule();
3206 }
union mlir::linalg::@1196::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
This class breaks up the current file into a token stream.
Definition: Lexer.h:23
This represents a token in the MLIR syntax.
Definition: Token.h:20
bool isKeyword() const
Return true if this is one of the keyword token kinds (e.g. kw_if).
Definition: Token.cpp:192
SMLoc getLoc() const
Definition: Token.cpp:24
bool isNot(Kind k) const
Definition: Token.h:50
StringRef getSpelling() const
Definition: Token.h:34
This class provides an abstract interface into the parser for hooking in code completion events.
Definition: CodeComplete.h:30
virtual void codeCompleteConstraintName(ast::Type currentType, bool allowInlineTypeConstraints, const ast::DeclScope *scope)
Signal code completion for a constraint name with an optional decl scope.
virtual void codeCompleteOperationAttributeName(StringRef opName)
Signal code completion for an attribute name of the given operation.
Definition: CodeComplete.h:48
virtual void codeCompleteOperationOperandsSignature(std::optional< StringRef > opName, unsigned currentNumOperands)
Signal code completion for the signature of an operation's operands.
Definition: CodeComplete.h:80
virtual void codeCompleteOperationName(StringRef dialectName)
Signal code completion for an operation name in the given dialect.
Definition: CodeComplete.h:62
virtual void codeCompleteOperationResultsSignature(std::optional< StringRef > opName, unsigned currentNumResults)
Signal code completion for the signature of an operation's results.
Definition: CodeComplete.h:85
virtual void codeCompleteDialectName()
Signal code completion for a dialect name.
Definition: CodeComplete.h:59
virtual void codeCompleteOperationMemberAccess(ast::OperationType opType)
Signal code completion for a member access into the given operation type.
virtual void codeCompleteTupleMemberAccess(ast::TupleType tupleType)
Signal code completion for a member access into the given tuple type.
virtual void codeCompletePatternMetadata()
Signal code completion for Pattern metadata.
Definition: CodeComplete.h:65
virtual void codeCompleteCallSignature(const ast::CallableDecl *callable, unsigned currentNumArgs)
Signal code completion for the signature of a callable.
Definition: CodeComplete.h:75
virtual void codeCompleteIncludeFilename(StringRef curPath)
Signal code completion for an include filename.
Definition: CodeComplete.h:68
@ code_complete_string
Token signifying a code completion location within a string.
Definition: Lexer.h:41
@ directive
Tokens.
Definition: Lexer.h:93
@ eof
Markers.
Definition: Lexer.h:36
@ code_complete
Token signifying a code completion location.
Definition: Lexer.h:39
@ arrow
Punctuation.
Definition: Lexer.h:74
@ less
Paired punctuation.
Definition: Lexer.h:82
@ kw_Attr
General keywords.
Definition: Lexer.h:54
static StringRef getMemberName()
Return the member name used for the "all-results" access.
Definition: Nodes.h:491
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
Definition: Nodes.h:493
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
Definition: Nodes.h:754
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:393
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:261
static AttributeType get(Context &context)
Return an instance of the Attribute type.
Definition: Types.cpp:57
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr * > arguments, Type resultType, bool isNegated=false)
Definition: Nodes.cpp:271
This decl represents a shared interface for all callable decls.
Definition: Nodes.h:1198
Type getResultType() const
Return the result type of this decl.
Definition: Nodes.h:1216
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
Definition: Nodes.h:1209
StringRef getCallableType() const
Return the callable type of this decl.
Definition: Nodes.h:1201
This statement represents a compound statement, which contains a collection of other statements.
Definition: Nodes.h:179
ArrayRef< Stmt * >::iterator begin() const
Definition: Nodes.h:191
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
Definition: Nodes.h:185
ArrayRef< Stmt * >::iterator end() const
Definition: Nodes.h:192
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt * > children)
Definition: Nodes.cpp:192
This class represents the base of all AST Constraint decls.
Definition: Nodes.h:708
static ConstraintType get(Context &context)
Return an instance of the Constraint type.
Definition: Types.cpp:65
This class represents the main context of the PDLL AST.
Definition: Context.h:25
DiagnosticEngine & getDiagEngine()
Return the diagnostic engine of this context.
Definition: Context.h:42
ods::Context & getODSContext()
Return the ODS context used by the AST.
Definition: Context.h:38
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
Definition: Nodes.cpp:288
This class represents a scope for named AST decls.
Definition: Nodes.h:64
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
Definition: Nodes.h:70
void add(Decl *decl)
Add a new decl to the scope.
Definition: Nodes.cpp:175
This class represents the base Decl node.
Definition: Nodes.h:673
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition: Nodes.h:676
InFlightDiagnostic emitWarning(SMRange loc, const Twine &msg)
Definition: Diagnostic.h:149
InFlightDiagnostic emitError(SMRange loc, const Twine &msg)
Emit an error to the diagnostic engine.
Definition: Diagnostic.h:145
This class provides a simple implementation of a PDLL diagnostic.
Definition: Diagnostic.h:29
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
Definition: Nodes.cpp:219
This class represents a base AST Expression node.
Definition: Nodes.h:348
Type getType() const
Return the type of this expression.
Definition: Nodes.h:351
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostic.h:83
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
Definition: Nodes.cpp:207
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
Definition: Nodes.cpp:298
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl * > children)
Definition: Nodes.cpp:579
This Decl represents a NamedAttribute, and contains a string name and attribute value.
Definition: Nodes.h:1002
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
Definition: Nodes.cpp:502
This class represents a base AST node.
Definition: Nodes.h:108
SMRange getLoc() const
Return the location of this node.
Definition: Nodes.h:131
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
Definition: Nodes.cpp:403
This Decl represents an OperationName.
Definition: Nodes.h:1026
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
Definition: Nodes.h:1032
static OpNameDecl * create(Context &ctx, const Name &name)
Definition: Nodes.cpp:512
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr * > operands, ArrayRef< Expr * > resultTypes, ArrayRef< NamedAttributeDecl * > attributes)
Definition: Nodes.cpp:310
This class represents a PDLL type that corresponds to an mlir::Operation.
Definition: Types.h:134
const ods::Operation * getODSOperation() const
Return the ODS operation that this type refers to, or nullptr if the ODS operation is unknown.
Definition: Types.cpp:87
static OperationType get(Context &context, std::optional< StringRef > name=std::nullopt, const ods::Operation *odsOp=nullptr)
Return an instance of the Operation type with an optional operation name.
Definition: Types.cpp:73
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
Definition: Nodes.cpp:523
static RangeExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, RangeType type)
Definition: Nodes.cpp:341
This class represents a PDLL type that corresponds to a range of elements with a given element type.
Definition: Types.h:159
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr * > replExprs)
Definition: Nodes.cpp:227
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
Definition: Nodes.cpp:252
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
Definition: Nodes.cpp:242
static RewriteType get(Context &context)
Return an instance of the Rewrite type.
Definition: Types.cpp:136
This class represents a base AST Statement node.
Definition: Nodes.h:164
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, ArrayRef< StringRef > elementNames)
Definition: Nodes.cpp:356
This class represents a PDLL tuple type, i.e.
Definition: Types.h:222
size_t size() const
Return the number of elements within this tuple.
Definition: Types.h:239
ArrayRef< Type > getElementTypes() const
Return the element types of this tuple.
Definition: Types.cpp:155
static TupleType get(Context &context, ArrayRef< Type > elementTypes, ArrayRef< StringRef > elementNames)
Return an instance of the Tuple type.
Definition: Types.cpp:144
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:420
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:376
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:429
static TypeRangeType get(Context &context)
Return an instance of the TypeRange type.
Definition: Types.cpp:113
static TypeType get(Context &context)
Return an instance of the Type type.
Definition: Types.cpp:167
Type refineWith(Type other) const
Try to refine this type with the one provided.
Definition: Types.cpp:33
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
Definition: Nodes.h:896
The class represents a Value constraint, and constrains a variable to be a Value.
Definition: Nodes.h:834
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
Definition: Nodes.cpp:439
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
Definition: Nodes.h:857
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:450
static ValueRangeType get(Context &context)
Return an instance of the ValueRange type.
Definition: Types.cpp:127
static ValueType get(Context &context)
Return an instance of the Value type.
Definition: Types.cpp:175
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
Definition: Nodes.cpp:561
This class contains all of the registered ODS operation classes.
Definition: Context.h:32
std::pair< Operation *, bool > insertOperation(StringRef name, StringRef summary, StringRef desc, StringRef nativeClassName, bool supportsResultTypeInferrence, SMLoc loc)
Insert a new operation with the context.
Definition: Context.cpp:64
const TypeConstraint & insertTypeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new type constraint with the context.
Definition: Context.cpp:42
const AttributeConstraint & insertAttributeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new attribute constraint with the context.
Definition: Context.cpp:28
const Operation * lookupOperation(StringRef name) const
Lookup an operation registered with the given name, or null if no operation with that name is registe...
Definition: Context.cpp:73
This class provides an ODS representation of a specific operation operand or result.
Definition: Operation.h:74
This class provides an ODS representation of a specific operation.
Definition: Operation.h:125
SMRange getLoc() const
Return the source location of this operation.
Definition: Operation.h:128
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
Definition: Operation.h:165
bool hasResultTypeInferrence() const
Return if the operation is known to support result type inferrence.
Definition: Operation.h:171
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
Definition: Operation.h:168
This class represents a generic ODS Type constraint.
Definition: Constraint.h:84
raw_ostream subclass that simplifies indention a sequence of code.
raw_indented_ostream & printReindented(StringRef str, StringRef extraPrefix="")
Prints a string re-indented to the current indent.
StringRef getSummary() const
Definition: Constraint.cpp:55
std::string getUniqueDefName() const
Returns a unique name for the TablGen def of this constraint.
Definition: Constraint.cpp:72
StringRef getDescription() const
Definition: Constraint.cpp:62
std::string getConditionTemplate() const
Definition: Constraint.cpp:51
Format context containing substitutions for special placeholders.
Definition: Format.h:40
FmtContext & withSelf(Twine subst)
Definition: Format.cpp:41
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
Definition: Operator.h:77
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
Definition: Parser.cpp:3201
auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals) -> FmtObject< decltype(std::make_tuple(llvm::support::detail::build_format_adapter(std::forward< Ts >(vals))...))>
Formats text by substituting placeholders in format string with replacement parameters.
Definition: Format.h:262
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This class represents a reference to a constraint, and contains a constraint and the location of the ...
Definition: Nodes.h:720
const ConstraintDecl * constraint
Definition: Nodes.h:726
This class provides a convenient API for interacting with source names.
Definition: Nodes.h:37
StringRef getName() const
Return the raw string name.
Definition: Nodes.h:41
SMRange getLoc() const
Get the location of this name.
Definition: Nodes.h:44
static const Name & create(Context &ctx, StringRef name, SMRange location)
Definition: Nodes.cpp:33