MLIR  16.0.0git
Parser.cpp
Go to the documentation of this file.
1 //===- Parser.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "Lexer.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Parser.h"
#include <string>
35 
36 using namespace mlir;
37 using namespace mlir::pdll;
38 
39 //===----------------------------------------------------------------------===//
40 // Parser
41 //===----------------------------------------------------------------------===//
42 
43 namespace {
44 class Parser {
45 public:
46  Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
47  bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
48  : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
49  curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
50  valueTy(ast::ValueType::get(ctx)),
51  valueRangeTy(ast::ValueRangeType::get(ctx)),
52  typeTy(ast::TypeType::get(ctx)),
53  typeRangeTy(ast::TypeRangeType::get(ctx)),
54  attrTy(ast::AttributeType::get(ctx)),
55  codeCompleteContext(codeCompleteContext) {}
56 
57  /// Try to parse a new module. Returns nullptr in the case of failure.
58  FailureOr<ast::Module *> parseModule();
59 
60 private:
61  /// The current context of the parser. It allows for the parser to know a bit
62  /// about the construct it is nested within during parsing. This is used
63  /// specifically to provide additional verification during parsing, e.g. to
64  /// prevent using rewrites within a match context, matcher constraints within
65  /// a rewrite section, etc.
66  enum class ParserContext {
67  /// The parser is in the global context.
68  Global,
69  /// The parser is currently within a Constraint, which disallows all types
70  /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
71  Constraint,
72  /// The parser is currently within the matcher portion of a Pattern, which
73  /// is allows a terminal operation rewrite statement but no other rewrite
74  /// transformations.
75  PatternMatch,
76  /// The parser is currently within a Rewrite, which disallows calls to
77  /// constraints, requires operation expressions to have names, etc.
78  Rewrite,
79  };
80 
81  /// The current specification context of an operations result type. This
82  /// indicates how the result types of an operation may be inferred.
83  enum class OpResultTypeContext {
84  /// The result types of the operation are not known to be inferred.
85  Explicit,
86  /// The result types of the operation are inferred from the root input of a
87  /// `replace` statement.
88  Replacement,
89  /// The result types of the operation are inferred by using the
90  /// `InferTypeOpInterface` interface provided by the operation.
91  Interface,
92  };
93 
94  //===--------------------------------------------------------------------===//
95  // Parsing
96  //===--------------------------------------------------------------------===//
97 
98  /// Push a new decl scope onto the lexer.
99  ast::DeclScope *pushDeclScope() {
100  ast::DeclScope *newScope =
101  new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
102  return (curDeclScope = newScope);
103  }
104  void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
105 
106  /// Pop the last decl scope from the lexer.
107  void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
108 
109  /// Parse the body of an AST module.
110  LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
111 
112  /// Try to convert the given expression to `type`. Returns failure and emits
113  /// an error if a conversion is not viable. On failure, `noteAttachFn` is
114  /// invoked to attach notes to the emitted error diagnostic. On success,
115  /// `expr` is updated to the expression used to convert to `type`.
116  LogicalResult convertExpressionTo(
117  ast::Expr *&expr, ast::Type type,
118  function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
119 
120  /// Given an operation expression, convert it to a Value or ValueRange
121  /// typed expression.
122  ast::Expr *convertOpToValue(const ast::Expr *opExpr);
123 
124  /// Lookup ODS information for the given operation, returns nullptr if no
125  /// information is found.
126  const ods::Operation *lookupODSOperation(Optional<StringRef> opName) {
127  return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
128  }
129 
130  /// Process the given documentation string, or return an empty string if
131  /// documentation isn't enabled.
132  StringRef processDoc(StringRef doc) {
133  return enableDocumentation ? doc : StringRef();
134  }
135 
136  /// Process the given documentation string and format it, or return an empty
137  /// string if documentation isn't enabled.
138  std::string processAndFormatDoc(const Twine &doc) {
139  if (!enableDocumentation)
140  return "";
141  std::string docStr;
142  {
143  llvm::raw_string_ostream docOS(docStr);
144  std::string tmpDocStr = doc.str();
146  StringRef(tmpDocStr).rtrim(" \t"));
147  }
148  return docStr;
149  }
150 
151  //===--------------------------------------------------------------------===//
152  // Directives
153 
154  LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
155  LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
156  LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
158 
159  /// Process the records of a parsed tablegen include file.
160  void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
162 
163  /// Create a user defined native constraint for a constraint imported from
164  /// ODS.
165  template <typename ConstraintT>
166  ast::Decl *
167  createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
168  SMRange loc, ast::Type type,
169  StringRef nativeType, StringRef docString);
170  template <typename ConstraintT>
171  ast::Decl *
172  createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
173  SMRange loc, ast::Type type,
174  StringRef nativeType);
175 
176  //===--------------------------------------------------------------------===//
177  // Decls
178 
179  /// This structure contains the set of pattern metadata that may be parsed.
180  struct ParsedPatternMetadata {
181  Optional<uint16_t> benefit;
182  bool hasBoundedRecursion = false;
183  };
184 
185  FailureOr<ast::Decl *> parseTopLevelDecl();
187  parseNamedAttributeDecl(Optional<StringRef> parentOpName);
188 
189  /// Parse an argument variable as part of the signature of a
190  /// UserConstraintDecl or UserRewriteDecl.
191  FailureOr<ast::VariableDecl *> parseArgumentDecl();
192 
193  /// Parse a result variable as part of the signature of a UserConstraintDecl
194  /// or UserRewriteDecl.
195  FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
196 
197  /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
198  /// defined in a non-global context.
200  parseUserConstraintDecl(bool isInline = false);
201 
202  /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
203  /// non-global context, such as within a Pattern/Constraint/etc.
204  FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
205 
206  /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
207  /// PDLL constructs.
208  FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
209  const ast::Name &name, bool isInline,
210  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
211  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
212 
213  /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being
214  /// defined in a non-global context.
215  FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
216 
217  /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
218  /// non-global context, such as within a Pattern/Rewrite/etc.
219  FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
220 
221  /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
222  /// PDLL constructs.
223  FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
224  const ast::Name &name, bool isInline,
225  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
226  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
227 
228  /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
229  /// effectively the same syntax, and only differ on slight semantics (given
230  /// the different parsing contexts).
231  template <typename T, typename ParseUserPDLLDeclFnT>
232  FailureOr<T *> parseUserConstraintOrRewriteDecl(
233  ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
234  StringRef anonymousNamePrefix, bool isInline);
235 
236  /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
237  /// These decls have effectively the same syntax.
238  template <typename T>
239  FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
240  const ast::Name &name, bool isInline,
242  ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
243 
244  /// Parse the functional signature (i.e. the arguments and results) of a
245  /// UserConstraintDecl or UserRewriteDecl.
246  LogicalResult parseUserConstraintOrRewriteSignature(
249  ast::DeclScope *&argumentScope, ast::Type &resultType);
250 
251  /// Validate the return (which if present is specified by bodyIt) of a
252  /// UserConstraintDecl or UserRewriteDecl.
253  LogicalResult validateUserConstraintOrRewriteReturn(
254  StringRef declType, ast::CompoundStmt *body,
257  ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
258 
260  parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
261  bool expectTerminalSemicolon = true);
262  FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
263  FailureOr<ast::Decl *> parsePatternDecl();
264  LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
265 
266  /// Check to see if a decl has already been defined with the given name, if
267  /// one has emit and error and return failure. Returns success otherwise.
268  LogicalResult checkDefineNamedDecl(const ast::Name &name);
269 
270  /// Try to define a variable decl with the given components, returns the
271  /// variable on success.
273  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
274  ast::Expr *initExpr,
275  ArrayRef<ast::ConstraintRef> constraints);
277  defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
278  ArrayRef<ast::ConstraintRef> constraints);
279 
280  /// Parse the constraint reference list for a variable decl.
281  LogicalResult parseVariableDeclConstraintList(
283 
284  /// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
285  FailureOr<ast::Expr *> parseTypeConstraintExpr();
286 
287  /// Try to parse a single reference to a constraint. `typeConstraint` is the
288  /// location of a previously parsed type constraint for the entity that will
289  /// be constrained by the parsed constraint. `existingConstraints` are any
290  /// existing constraints that have already been parsed for the same entity
291  /// that will be constrained by this constraint. `allowInlineTypeConstraints`
292  /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
293  /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined
294  /// constraints) may be used with the variable.
296  parseConstraint(Optional<SMRange> &typeConstraint,
297  ArrayRef<ast::ConstraintRef> existingConstraints,
298  bool allowInlineTypeConstraints,
299  bool allowNonCoreConstraints);
300 
301  /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
302  /// argument or result variable. The constraints for these variables do not
303  /// allow inline type constraints, and only permit a single constraint.
304  FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
305 
306  //===--------------------------------------------------------------------===//
307  // Exprs
308 
309  FailureOr<ast::Expr *> parseExpr();
310 
311  /// Identifier expressions.
312  FailureOr<ast::Expr *> parseAttributeExpr();
313  FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
314  FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
315  FailureOr<ast::Expr *> parseIdentifierExpr();
316  FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
317  FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
318  FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
319  FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
320  FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
322  parseOperationExpr(OpResultTypeContext inputResultTypeContext =
323  OpResultTypeContext::Explicit);
324  FailureOr<ast::Expr *> parseTupleExpr();
325  FailureOr<ast::Expr *> parseTypeExpr();
326  FailureOr<ast::Expr *> parseUnderscoreExpr();
327 
328  //===--------------------------------------------------------------------===//
329  // Stmts
330 
331  FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
332  FailureOr<ast::CompoundStmt *> parseCompoundStmt();
333  FailureOr<ast::EraseStmt *> parseEraseStmt();
334  FailureOr<ast::LetStmt *> parseLetStmt();
335  FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
336  FailureOr<ast::ReturnStmt *> parseReturnStmt();
337  FailureOr<ast::RewriteStmt *> parseRewriteStmt();
338 
339  //===--------------------------------------------------------------------===//
340  // Creation+Analysis
341  //===--------------------------------------------------------------------===//
342 
343  //===--------------------------------------------------------------------===//
344  // Decls
345 
346  /// Try to extract a callable from the given AST node. Returns nullptr on
347  /// failure.
348  ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
349 
350  /// Try to create a pattern decl with the given components, returning the
351  /// Pattern on success.
353  createPatternDecl(SMRange loc, const ast::Name *name,
354  const ParsedPatternMetadata &metadata,
355  ast::CompoundStmt *body);
356 
357  /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
358  /// of results, defined as part of the signature.
359  ast::Type
360  createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
361 
362  /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
363  template <typename T>
364  FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
365  const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
366  ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
367  ast::CompoundStmt *body);
368 
369  /// Try to create a variable decl with the given components, returning the
370  /// Variable on success.
372  createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
373  ArrayRef<ast::ConstraintRef> constraints);
374 
375  /// Create a variable for an argument or result defined as part of the
376  /// signature of a UserConstraintDecl/UserRewriteDecl.
378  createArgOrResultVariableDecl(StringRef name, SMRange loc,
379  const ast::ConstraintRef &constraint);
380 
381  /// Validate the constraints used to constraint a variable decl.
382  /// `inferredType` is the type of the variable inferred by the constraints
383  /// within the list, and is updated to the most refined type as determined by
384  /// the constraints. Returns success if the constraint list is valid, failure
385  /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
386  /// defined constraints) may be used with the variable.
388  validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
389  ast::Type &inferredType,
390  bool allowNonCoreConstraints = true);
391  /// Validate a single reference to a constraint. `inferredType` contains the
392  /// currently inferred variabled type and is refined within the type defined
393  /// by the constraint. Returns success if the constraint is valid, failure
394  /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
395  /// defined constraints) may be used with the variable.
396  LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
397  ast::Type &inferredType,
398  bool allowNonCoreConstraints = true);
399  LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
400  LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
401 
402  //===--------------------------------------------------------------------===//
403  // Exprs
404 
406  createCallExpr(SMRange loc, ast::Expr *parentExpr,
407  MutableArrayRef<ast::Expr *> arguments);
408  FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
410  createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
411  ArrayRef<ast::ConstraintRef> constraints);
413  createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
414 
415  /// Validate the member access `name` into the given parent expression. On
416  /// success, this also returns the type of the member accessed.
417  FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
418  StringRef name, SMRange loc);
420  createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
421  OpResultTypeContext resultTypeContext,
426  validateOperationOperands(SMRange loc, Optional<StringRef> name,
427  const ods::Operation *odsOp,
429  LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
430  const ods::Operation *odsOp,
432  void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
433  const ods::Operation *odsOp);
434  LogicalResult validateOperationOperandsOrResults(
435  StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
437  ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
438  ast::Type rangeTy);
439  FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
440  ArrayRef<ast::Expr *> elements,
441  ArrayRef<StringRef> elementNames);
442 
443  //===--------------------------------------------------------------------===//
444  // Stmts
445 
446  FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
448  createReplaceStmt(SMRange loc, ast::Expr *rootOp,
449  MutableArrayRef<ast::Expr *> replValues);
451  createRewriteStmt(SMRange loc, ast::Expr *rootOp,
452  ast::CompoundStmt *rewriteBody);
453 
454  //===--------------------------------------------------------------------===//
455  // Code Completion
456  //===--------------------------------------------------------------------===//
457 
458  /// The set of various code completion methods. Every completion method
459  /// returns `failure` to stop the parsing process after providing completion
460  /// results.
461 
462  LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
463  LogicalResult codeCompleteAttributeName(Optional<StringRef> opName);
464  LogicalResult codeCompleteConstraintName(ast::Type inferredType,
465  bool allowNonCoreConstraints,
466  bool allowInlineTypeConstraints);
467  LogicalResult codeCompleteDialectName();
468  LogicalResult codeCompleteOperationName(StringRef dialectName);
469  LogicalResult codeCompletePatternMetadata();
470  LogicalResult codeCompleteIncludeFilename(StringRef curPath);
471 
472  void codeCompleteCallSignature(ast::Node *parent, unsigned currentNumArgs);
473  void codeCompleteOperationOperandsSignature(Optional<StringRef> opName,
474  unsigned currentNumOperands);
475  void codeCompleteOperationResultsSignature(Optional<StringRef> opName,
476  unsigned currentNumResults);
477 
478  //===--------------------------------------------------------------------===//
479  // Lexer Utilities
480  //===--------------------------------------------------------------------===//
481 
482  /// If the current token has the specified kind, consume it and return true.
483  /// If not, return false.
484  bool consumeIf(Token::Kind kind) {
485  if (curToken.isNot(kind))
486  return false;
487  consumeToken(kind);
488  return true;
489  }
490 
491  /// Advance the current lexer onto the next token.
492  void consumeToken() {
493  assert(curToken.isNot(Token::eof, Token::error) &&
494  "shouldn't advance past EOF or errors");
495  curToken = lexer.lexToken();
496  }
497 
498  /// Advance the current lexer onto the next token, asserting what the expected
499  /// current token is. This is preferred to the above method because it leads
500  /// to more self-documenting code with better checking.
501  void consumeToken(Token::Kind kind) {
502  assert(curToken.is(kind) && "consumed an unexpected token");
503  consumeToken();
504  }
505 
506  /// Reset the lexer to the location at the given position.
507  void resetToken(SMRange tokLoc) {
508  lexer.resetPointer(tokLoc.Start.getPointer());
509  curToken = lexer.lexToken();
510  }
511 
512  /// Consume the specified token if present and return success. On failure,
513  /// output a diagnostic and return failure.
514  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
515  if (curToken.getKind() != kind)
516  return emitError(curToken.getLoc(), msg);
517  consumeToken();
518  return success();
519  }
520  LogicalResult emitError(SMRange loc, const Twine &msg) {
521  lexer.emitError(loc, msg);
522  return failure();
523  }
524  LogicalResult emitError(const Twine &msg) {
525  return emitError(curToken.getLoc(), msg);
526  }
527  LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
528  const Twine &note) {
529  lexer.emitErrorAndNote(loc, msg, noteLoc, note);
530  return failure();
531  }
532 
533  //===--------------------------------------------------------------------===//
534  // Fields
535  //===--------------------------------------------------------------------===//
536 
537  /// The owning AST context.
538  ast::Context &ctx;
539 
540  /// The lexer of this parser.
541  Lexer lexer;
542 
543  /// The current token within the lexer.
544  Token curToken;
545 
546  /// A flag indicating if the parser should add documentation to AST nodes when
547  /// viable.
548  bool enableDocumentation;
549 
550  /// The most recently defined decl scope.
551  ast::DeclScope *curDeclScope = nullptr;
552  llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
553 
554  /// The current context of the parser.
555  ParserContext parserContext = ParserContext::Global;
556 
557  /// Cached types to simplify verification and expression creation.
558  ast::Type valueTy, valueRangeTy;
559  ast::Type typeTy, typeRangeTy;
560  ast::Type attrTy;
561 
562  /// A counter used when naming anonymous constraints and rewrites.
563  unsigned anonymousDeclNameCounter = 0;
564 
565  /// The optional code completion context.
566  CodeCompleteContext *codeCompleteContext;
567 };
568 } // namespace
569 
570 FailureOr<ast::Module *> Parser::parseModule() {
571  SMLoc moduleLoc = curToken.getStartLoc();
572  pushDeclScope();
573 
574  // Parse the top-level decls of the module.
576  if (failed(parseModuleBody(decls)))
577  return popDeclScope(), failure();
578 
579  popDeclScope();
580  return ast::Module::create(ctx, moduleLoc, decls);
581 }
582 
583 LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
584  while (curToken.isNot(Token::eof)) {
585  if (curToken.is(Token::directive)) {
586  if (failed(parseDirective(decls)))
587  return failure();
588  continue;
589  }
590 
591  FailureOr<ast::Decl *> decl = parseTopLevelDecl();
592  if (failed(decl))
593  return failure();
594  decls.push_back(*decl);
595  }
596  return success();
597 }
598 
599 ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
600  return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
601  valueRangeTy);
602 }
603 
604 LogicalResult Parser::convertExpressionTo(
605  ast::Expr *&expr, ast::Type type,
606  function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
607  ast::Type exprType = expr->getType();
608  if (exprType == type)
609  return success();
610 
611  auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
613  expr->getLoc(), llvm::formatv("unable to convert expression of type "
614  "`{0}` to the expected type of "
615  "`{1}`",
616  exprType, type));
617  if (noteAttachFn)
618  noteAttachFn(*diag);
619  return diag;
620  };
621 
622  if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) {
623  // Two operation types are compatible if they have the same name, or if the
624  // expected type is more general.
625  if (auto opType = type.dyn_cast<ast::OperationType>()) {
626  if (opType.getName())
627  return emitConvertError();
628  return success();
629  }
630 
631  // An operation can always convert to a ValueRange.
632  if (type == valueRangeTy) {
633  expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
634  valueRangeTy);
635  return success();
636  }
637 
638  // Allow conversion to a single value by constraining the result range.
639  if (type == valueTy) {
640  // If the operation is registered, we can verify if it can ever have a
641  // single result.
642  if (const ods::Operation *odsOp = exprOpType.getODSOperation()) {
643  if (odsOp->getResults().empty()) {
644  return emitConvertError()->attachNote(
645  llvm::formatv("see the definition of `{0}`, which was defined "
646  "with zero results",
647  odsOp->getName()),
648  odsOp->getLoc());
649  }
650 
651  unsigned numSingleResults = llvm::count_if(
652  odsOp->getResults(), [](const ods::OperandOrResult &result) {
653  return result.getVariableLengthKind() ==
655  });
656  if (numSingleResults > 1) {
657  return emitConvertError()->attachNote(
658  llvm::formatv("see the definition of `{0}`, which was defined "
659  "with at least {1} results",
660  odsOp->getName(), numSingleResults),
661  odsOp->getLoc());
662  }
663  }
664 
665  expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
666  valueTy);
667  return success();
668  }
669  return emitConvertError();
670  }
671 
672  // FIXME: Decide how to allow/support converting a single result to multiple,
673  // and multiple to a single result. For now, we just allow Single->Range,
674  // but this isn't something really supported in the PDL dialect. We should
675  // figure out some way to support both.
676  if ((exprType == valueTy || exprType == valueRangeTy) &&
677  (type == valueTy || type == valueRangeTy))
678  return success();
679  if ((exprType == typeTy || exprType == typeRangeTy) &&
680  (type == typeTy || type == typeRangeTy))
681  return success();
682 
683  // Handle tuple types.
684  if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
685  auto tupleType = type.dyn_cast<ast::TupleType>();
686  if (!tupleType || tupleType.size() != exprTupleType.size())
687  return emitConvertError();
688 
689  // Build a new tuple expression using each of the elements of the current
690  // tuple.
691  SmallVector<ast::Expr *> newExprs;
692  for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
693  newExprs.push_back(ast::MemberAccessExpr::create(
694  ctx, expr->getLoc(), expr, llvm::to_string(i),
695  exprTupleType.getElementTypes()[i]));
696 
697  auto diagFn = [&](ast::Diagnostic &diag) {
698  diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
699  i, exprTupleType));
700  if (noteAttachFn)
701  noteAttachFn(diag);
702  };
703  if (failed(convertExpressionTo(newExprs.back(),
704  tupleType.getElementTypes()[i], diagFn)))
705  return failure();
706  }
707  expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
708  tupleType.getElementNames());
709  return success();
710  }
711 
712  return emitConvertError();
713 }
714 
715 //===----------------------------------------------------------------------===//
716 // Directives
717 
718 LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
719  StringRef directive = curToken.getSpelling();
720  if (directive == "#include")
721  return parseInclude(decls);
722 
723  return emitError("unknown directive `" + directive + "`");
724 }
725 
726 LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
727  SMRange loc = curToken.getLoc();
728  consumeToken(Token::directive);
729 
730  // Handle code completion of the include file path.
731  if (curToken.is(Token::code_complete_string))
732  return codeCompleteIncludeFilename(curToken.getStringValue());
733 
734  // Parse the file being included.
735  if (!curToken.isString())
736  return emitError(loc,
737  "expected string file name after `include` directive");
738  SMRange fileLoc = curToken.getLoc();
739  std::string filenameStr = curToken.getStringValue();
740  StringRef filename = filenameStr;
741  consumeToken();
742 
743  // Check the type of include. If ending with `.pdll`, this is another pdl file
744  // to be parsed along with the current module.
745  if (filename.endswith(".pdll")) {
746  if (failed(lexer.pushInclude(filename, fileLoc)))
747  return emitError(fileLoc,
748  "unable to open include file `" + filename + "`");
749 
750  // If we added the include successfully, parse it into the current module.
751  // Make sure to update to the next token after we finish parsing the nested
752  // file.
753  curToken = lexer.lexToken();
754  LogicalResult result = parseModuleBody(decls);
755  curToken = lexer.lexToken();
756  return result;
757  }
758 
759  // Otherwise, this must be a `.td` include.
760  if (filename.endswith(".td"))
761  return parseTdInclude(filename, fileLoc, decls);
762 
763  return emitError(fileLoc,
764  "expected include filename to end with `.pdll` or `.td`");
765 }
766 
767 LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
769  llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
770 
771  // Use the source manager to open the file, but don't yet add it.
772  std::string includedFile;
773  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
774  parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
775  if (!includeBuffer)
776  return emitError(fileLoc, "unable to open include file `" + filename + "`");
777 
778  // Setup the source manager for parsing the tablegen file.
779  llvm::SourceMgr tdSrcMgr;
780  tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
781  tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
782 
783  // This class provides a context argument for the llvm::SourceMgr diagnostic
784  // handler.
785  struct DiagHandlerContext {
786  Parser &parser;
787  StringRef filename;
788  llvm::SMRange loc;
789  } handlerContext{*this, filename, fileLoc};
790 
791  // Set the diagnostic handler for the tablegen source manager.
792  tdSrcMgr.setDiagHandler(
793  [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
794  auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
795  (void)ctx->parser.emitError(
796  ctx->loc,
797  llvm::formatv("error while processing include file `{0}`: {1}",
798  ctx->filename, diag.getMessage()));
799  },
800  &handlerContext);
801 
802  // Parse the tablegen file.
803  llvm::RecordKeeper tdRecords;
804  if (llvm::TableGenParseFile(tdSrcMgr, tdRecords))
805  return failure();
806 
807  // Process the parsed records.
808  processTdIncludeRecords(tdRecords, decls);
809 
810  // After we are done processing, move all of the tablegen source buffers to
811  // the main parser source mgr. This allows for directly using source locations
812  // from the .td files without needing to remap them.
813  parserSrcMgr.takeSourceBuffersFrom(tdSrcMgr, fileLoc.End);
814  return success();
815 }
816 
817 void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
819  // Return the length kind of the given value.
820  auto getLengthKind = [](const auto &value) {
821  if (value.isOptional())
823  return value.isVariadic() ? ods::VariableLengthKind::Variadic
825  };
826 
827  // Insert a type constraint into the ODS context.
828  ods::Context &odsContext = ctx.getODSContext();
829  auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
830  -> const ods::TypeConstraint & {
831  return odsContext.insertTypeConstraint(
832  cst.constraint.getUniqueDefName(),
833  processDoc(cst.constraint.getSummary()),
834  cst.constraint.getCPPClassName());
835  };
836  auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
837  return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
838  };
839 
840  // Process the parsed tablegen records to build ODS information.
841  /// Operations.
842  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
843  tblgen::Operator op(def);
844 
845  // Check to see if this operation is known to support type inferrence.
846  bool supportsResultTypeInferrence =
847  op.getTrait("::mlir::InferTypeOpInterface::Trait");
848 
849  bool inserted = false;
850  ods::Operation *odsOp = nullptr;
851  std::tie(odsOp, inserted) = odsContext.insertOperation(
852  op.getOperationName(), processDoc(op.getSummary()),
853  processAndFormatDoc(op.getDescription()), op.getQualCppClassName(),
854  supportsResultTypeInferrence, op.getLoc().front());
855 
856  // Ignore operations that have already been added.
857  if (!inserted)
858  continue;
859 
860  for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
861  odsOp->appendAttribute(attr.name, attr.attr.isOptional(),
862  odsContext.insertAttributeConstraint(
863  attr.attr.getUniqueDefName(),
864  processDoc(attr.attr.getSummary()),
865  attr.attr.getStorageType()));
866  }
867  for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
868  odsOp->appendOperand(operand.name, getLengthKind(operand),
869  addTypeConstraint(operand));
870  }
871  for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
872  odsOp->appendResult(result.name, getLengthKind(result),
873  addTypeConstraint(result));
874  }
875  }
876 
877  auto shouldBeSkipped = [this](llvm::Record *def) {
878  return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
879  def->isSubClassOf("DeclareInterfaceMethods");
880  };
881 
882  /// Attr constraints.
883  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
884  if (shouldBeSkipped(def))
885  continue;
886 
887  tblgen::Attribute constraint(def);
888  decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
889  constraint, convertLocToRange(def->getLoc().front()), attrTy,
890  constraint.getStorageType()));
891  }
892  /// Type constraints.
893  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
894  if (shouldBeSkipped(def))
895  continue;
896 
897  tblgen::TypeConstraint constraint(def);
898  decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
899  constraint, convertLocToRange(def->getLoc().front()), typeTy,
900  constraint.getCPPClassName()));
901  }
902  /// OpInterfaces.
904  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("OpInterface")) {
905  if (shouldBeSkipped(def))
906  continue;
907 
908  SMRange loc = convertLocToRange(def->getLoc().front());
909 
910  std::string cppClassName =
911  llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
912  def->getValueAsString("cppInterfaceName"))
913  .str();
914  std::string codeBlock =
915  llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
916  cppClassName)
917  .str();
918 
919  std::string desc =
920  processAndFormatDoc(def->getValueAsString("description"));
921  decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
922  def->getName(), codeBlock, loc, opTy, cppClassName, desc));
923  }
924 }
925 
926 template <typename ConstraintT>
927 ast::Decl *Parser::createODSNativePDLLConstraintDecl(
928  StringRef name, StringRef codeBlock, SMRange loc, ast::Type type,
929  StringRef nativeType, StringRef docString) {
930  // Build the single input parameter.
931  ast::DeclScope *argScope = pushDeclScope();
932  auto *paramVar = ast::VariableDecl::create(
933  ctx, ast::Name::create(ctx, "self", loc), type,
934  /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
935  argScope->add(paramVar);
936  popDeclScope();
937 
938  // Build the native constraint.
939  auto *constraintDecl = ast::UserConstraintDecl::createNative(
940  ctx, ast::Name::create(ctx, name, loc), paramVar,
941  /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx), nativeType);
942  constraintDecl->setDocComment(ctx, docString);
943  curDeclScope->add(constraintDecl);
944  return constraintDecl;
945 }
946 
947 template <typename ConstraintT>
948 ast::Decl *
949 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
950  SMRange loc, ast::Type type,
951  StringRef nativeType) {
952  // Format the condition template.
953  tblgen::FmtContext fmtContext;
954  fmtContext.withSelf("self");
955  std::string codeBlock = tblgen::tgfmt(
956  "return ::mlir::success(" + constraint.getConditionTemplate() + ");",
957  &fmtContext);
958 
959  // If documentation was enabled, build the doc string for the generated
960  // constraint. It would be nice to do this lazily, but TableGen information is
961  // destroyed after we finish parsing the file.
962  std::string docString;
963  if (enableDocumentation) {
964  StringRef desc = constraint.getDescription();
965  docString = processAndFormatDoc(
966  constraint.getSummary() +
967  (desc.empty() ? "" : ("\n\n" + constraint.getDescription())));
968  }
969 
970  return createODSNativePDLLConstraintDecl<ConstraintT>(
971  constraint.getUniqueDefName(), codeBlock, loc, type, nativeType,
972  docString);
973 }
974 
975 //===----------------------------------------------------------------------===//
976 // Decls
977 
978 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
980  switch (curToken.getKind()) {
982  decl = parseUserConstraintDecl();
983  break;
984  case Token::kw_Pattern:
985  decl = parsePatternDecl();
986  break;
987  case Token::kw_Rewrite:
988  decl = parseUserRewriteDecl();
989  break;
990  default:
991  return emitError("expected top-level declaration, such as a `Pattern`");
992  }
993  if (failed(decl))
994  return failure();
995 
996  // If the decl has a name, add it to the current scope.
997  if (const ast::Name *name = (*decl)->getName()) {
998  if (failed(checkDefineNamedDecl(*name)))
999  return failure();
1000  curDeclScope->add(*decl);
1001  }
1002  return decl;
1003 }
1004 
1006 Parser::parseNamedAttributeDecl(Optional<StringRef> parentOpName) {
1007  // Check for name code completion.
1008  if (curToken.is(Token::code_complete))
1009  return codeCompleteAttributeName(parentOpName);
1010 
1011  std::string attrNameStr;
1012  if (curToken.isString())
1013  attrNameStr = curToken.getStringValue();
1014  else if (curToken.is(Token::identifier) || curToken.isKeyword())
1015  attrNameStr = curToken.getSpelling().str();
1016  else
1017  return emitError("expected identifier or string attribute name");
1018  const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
1019  consumeToken();
1020 
1021  // Check for a value of the attribute.
1022  ast::Expr *attrValue = nullptr;
1023  if (consumeIf(Token::equal)) {
1024  FailureOr<ast::Expr *> attrExpr = parseExpr();
1025  if (failed(attrExpr))
1026  return failure();
1027  attrValue = *attrExpr;
1028  } else {
1029  // If there isn't a concrete value, create an expression representing a
1030  // UnitAttr.
1031  attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
1032  }
1033 
1034  return ast::NamedAttributeDecl::create(ctx, name, attrValue);
1035 }
1036 
1037 FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
1038  function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
1039  bool expectTerminalSemicolon) {
1040  consumeToken(Token::equal_arrow);
1041 
1042  // Parse the single statement of the lambda body.
1043  SMLoc bodyStartLoc = curToken.getStartLoc();
1044  pushDeclScope();
1045  FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
1046  bool failedToParse =
1047  failed(singleStatement) || failed(processStatementFn(*singleStatement));
1048  popDeclScope();
1049  if (failedToParse)
1050  return failure();
1051 
1052  SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
1053  return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
1054 }
1055 
1056 FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
1057  // Ensure that the argument is named.
1058  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
1059  return emitError("expected identifier argument name");
1060 
1061  // Parse the argument similarly to a normal variable.
1062  StringRef name = curToken.getSpelling();
1063  SMRange nameLoc = curToken.getLoc();
1064  consumeToken();
1065 
1066  if (failed(
1067  parseToken(Token::colon, "expected `:` before argument constraint")))
1068  return failure();
1069 
1070  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1071  if (failed(cst))
1072  return failure();
1073 
1074  return createArgOrResultVariableDecl(name, nameLoc, *cst);
1075 }
1076 
1077 FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
1078  // Check to see if this result is named.
1079  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
1080  // Check to see if this name actually refers to a Constraint.
1081  ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
1082  if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
1083  // If yes, and this is a Rewrite, give a nice error message as non-Core
1084  // constraints are not supported on Rewrite results.
1085  if (parserContext == ParserContext::Rewrite) {
1086  return emitError(
1087  "`Rewrite` results are only permitted to use core constraints, "
1088  "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
1089  }
1090 
1091  // Otherwise, parse this as an unnamed result variable.
1092  } else {
1093  // If it wasn't a constraint, parse the result similarly to a variable. If
1094  // there is already an existing decl, we will emit an error when defining
1095  // this variable later.
1096  StringRef name = curToken.getSpelling();
1097  SMRange nameLoc = curToken.getLoc();
1098  consumeToken();
1099 
1100  if (failed(parseToken(Token::colon,
1101  "expected `:` before result constraint")))
1102  return failure();
1103 
1104  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1105  if (failed(cst))
1106  return failure();
1107 
1108  return createArgOrResultVariableDecl(name, nameLoc, *cst);
1109  }
1110  }
1111 
1112  // If it isn't named, we parse the constraint directly and create an unnamed
1113  // result variable.
1114  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
1115  if (failed(cst))
1116  return failure();
1117 
1118  return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
1119 }
1120 
1122 Parser::parseUserConstraintDecl(bool isInline) {
1123  // Constraints and rewrites have very similar formats, dispatch to a shared
1124  // interface for parsing.
1125  return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1126  [&](auto &&...args) {
1127  return this->parseUserPDLLConstraintDecl(args...);
1128  },
1129  ParserContext::Constraint, "constraint", isInline);
1130 }
1131 
1132 FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
1134  parseUserConstraintDecl(/*isInline=*/true);
1135  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1136  return failure();
1137 
1138  curDeclScope->add(*decl);
1139  return decl;
1140 }
1141 
1142 FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
1143  const ast::Name &name, bool isInline,
1144  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1145  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1146  // Push the argument scope back onto the list, so that the body can
1147  // reference arguments.
1148  pushDeclScope(argumentScope);
1149 
1150  // Parse the body of the constraint. The body is either defined as a compound
1151  // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
1152  ast::CompoundStmt *body;
1153  if (curToken.is(Token::equal_arrow)) {
1154  FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1155  [&](ast::Stmt *&stmt) -> LogicalResult {
1156  ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
1157  if (!stmtExpr) {
1158  return emitError(stmt->getLoc(),
1159  "expected `Constraint` lambda body to contain a "
1160  "single expression");
1161  }
1162  stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
1163  return success();
1164  },
1165  /*expectTerminalSemicolon=*/!isInline);
1166  if (failed(bodyResult))
1167  return failure();
1168  body = *bodyResult;
1169  } else {
1170  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1171  if (failed(bodyResult))
1172  return failure();
1173  body = *bodyResult;
1174 
1175  // Verify the structure of the body.
1176  auto bodyIt = body->begin(), bodyE = body->end();
1177  for (; bodyIt != bodyE; ++bodyIt)
1178  if (isa<ast::ReturnStmt>(*bodyIt))
1179  break;
1180  if (failed(validateUserConstraintOrRewriteReturn(
1181  "Constraint", body, bodyIt, bodyE, results, resultType)))
1182  return failure();
1183  }
1184  popDeclScope();
1185 
1186  return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
1187  name, arguments, results, resultType, body);
1188 }
1189 
1190 FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
1191  // Constraints and rewrites have very similar formats, dispatch to a shared
1192  // interface for parsing.
1193  return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1194  [&](auto &&...args) { return this->parseUserPDLLRewriteDecl(args...); },
1195  ParserContext::Rewrite, "rewrite", isInline);
1196 }
1197 
1198 FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
1200  parseUserRewriteDecl(/*isInline=*/true);
1201  if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
1202  return failure();
1203 
1204  curDeclScope->add(*decl);
1205  return decl;
1206 }
1207 
1208 FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
1209  const ast::Name &name, bool isInline,
1210  ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
1211  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1212  // Push the argument scope back onto the list, so that the body can
1213  // reference arguments.
1214  curDeclScope = argumentScope;
1215  ast::CompoundStmt *body;
1216  if (curToken.is(Token::equal_arrow)) {
1217  FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
1218  [&](ast::Stmt *&statement) -> LogicalResult {
1219  if (isa<ast::OpRewriteStmt>(statement))
1220  return success();
1221 
1222  ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
1223  if (!statementExpr) {
1224  return emitError(
1225  statement->getLoc(),
1226  "expected `Rewrite` lambda body to contain a single expression "
1227  "or an operation rewrite statement; such as `erase`, "
1228  "`replace`, or `rewrite`");
1229  }
1230  statement =
1231  ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
1232  return success();
1233  },
1234  /*expectTerminalSemicolon=*/!isInline);
1235  if (failed(bodyResult))
1236  return failure();
1237  body = *bodyResult;
1238  } else {
1239  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1240  if (failed(bodyResult))
1241  return failure();
1242  body = *bodyResult;
1243  }
1244  popDeclScope();
1245 
1246  // Verify the structure of the body.
1247  auto bodyIt = body->begin(), bodyE = body->end();
1248  for (; bodyIt != bodyE; ++bodyIt)
1249  if (isa<ast::ReturnStmt>(*bodyIt))
1250  break;
1251  if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
1252  bodyE, results, resultType)))
1253  return failure();
1254  return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
1255  name, arguments, results, resultType, body);
1256 }
1257 
1258 template <typename T, typename ParseUserPDLLDeclFnT>
1259 FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
1260  ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
1261  StringRef anonymousNamePrefix, bool isInline) {
1262  SMRange loc = curToken.getLoc();
1263  consumeToken();
1264  llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext);
1265 
1266  // Parse the name of the decl.
1267  const ast::Name *name = nullptr;
1268  if (curToken.isNot(Token::identifier)) {
1269  // Only inline decls can be un-named. Inline decls are similar to "lambdas"
1270  // in C++, so being unnamed is fine.
1271  if (!isInline)
1272  return emitError("expected identifier name");
1273 
1274  // Create a unique anonymous name to use, as the name for this decl is not
1275  // important.
1276  std::string anonName =
1277  llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
1278  anonymousDeclNameCounter++)
1279  .str();
1280  name = &ast::Name::create(ctx, anonName, loc);
1281  } else {
1282  // If a name was provided, we can use it directly.
1283  name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1284  consumeToken(Token::identifier);
1285  }
1286 
1287  // Parse the functional signature of the decl.
1288  SmallVector<ast::VariableDecl *> arguments, results;
1289  ast::DeclScope *argumentScope;
1290  ast::Type resultType;
1291  if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
1292  argumentScope, resultType)))
1293  return failure();
1294 
1295  // Check to see which type of constraint this is. If the constraint contains a
1296  // compound body, this is a PDLL decl.
1297  if (curToken.isAny(Token::l_brace, Token::equal_arrow))
1298  return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
1299  resultType);
1300 
1301  // Otherwise, this is a native decl.
1302  return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
1303  results, resultType);
1304 }
1305 
1306 template <typename T>
1307 FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
1308  const ast::Name &name, bool isInline,
1310  ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
1311  // If followed by a string, the native code body has also been specified.
1312  std::string codeStrStorage;
1313  Optional<StringRef> optCodeStr;
1314  if (curToken.isString()) {
1315  codeStrStorage = curToken.getStringValue();
1316  optCodeStr = codeStrStorage;
1317  consumeToken();
1318  } else if (isInline) {
1319  return emitError(name.getLoc(),
1320  "external declarations must be declared in global scope");
1321  } else if (curToken.is(Token::error)) {
1322  return failure();
1323  }
1324  if (failed(parseToken(Token::semicolon,
1325  "expected `;` after native declaration")))
1326  return failure();
1327  // TODO: PDL should be able to support constraint results in certain
1328  // situations, we should revise this.
1329  if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
1330  return emitError(
1331  "native Constraints currently do not support returning results");
1332  }
1333  return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
1334 }
1335 
1336 LogicalResult Parser::parseUserConstraintOrRewriteSignature(
1339  ast::DeclScope *&argumentScope, ast::Type &resultType) {
1340  // Parse the argument list of the decl.
1341  if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
1342  return failure();
1343 
1344  argumentScope = pushDeclScope();
1345  if (curToken.isNot(Token::r_paren)) {
1346  do {
1347  FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
1348  if (failed(argument))
1349  return failure();
1350  arguments.emplace_back(*argument);
1351  } while (consumeIf(Token::comma));
1352  }
1353  popDeclScope();
1354  if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
1355  return failure();
1356 
1357  // Parse the results of the decl.
1358  pushDeclScope();
1359  if (consumeIf(Token::arrow)) {
1360  auto parseResultFn = [&]() -> LogicalResult {
1361  FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
1362  if (failed(result))
1363  return failure();
1364  results.emplace_back(*result);
1365  return success();
1366  };
1367 
1368  // Check for a list of results.
1369  if (consumeIf(Token::l_paren)) {
1370  do {
1371  if (failed(parseResultFn()))
1372  return failure();
1373  } while (consumeIf(Token::comma));
1374  if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
1375  return failure();
1376 
1377  // Otherwise, there is only one result.
1378  } else if (failed(parseResultFn())) {
1379  return failure();
1380  }
1381  }
1382  popDeclScope();
1383 
1384  // Compute the result type of the decl.
1385  resultType = createUserConstraintRewriteResultType(results);
1386 
1387  // Verify that results are only named if there are more than one.
1388  if (results.size() == 1 && !results.front()->getName().getName().empty()) {
1389  return emitError(
1390  results.front()->getLoc(),
1391  "cannot create a single-element tuple with an element label");
1392  }
1393  return success();
1394 }
1395 
1396 LogicalResult Parser::validateUserConstraintOrRewriteReturn(
1397  StringRef declType, ast::CompoundStmt *body,
1400  ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
1401  // Handle if a `return` was provided.
1402  if (bodyIt != bodyE) {
1403  // Emit an error if we have trailing statements after the return.
1404  if (std::next(bodyIt) != bodyE) {
1405  return emitError(
1406  (*std::next(bodyIt))->getLoc(),
1407  llvm::formatv("`return` terminated the `{0}` body, but found "
1408  "trailing statements afterwards",
1409  declType));
1410  }
1411 
1412  // Otherwise if a return wasn't provided, check that no results are
1413  // expected.
1414  } else if (!results.empty()) {
1415  return emitError(
1416  {body->getLoc().End, body->getLoc().End},
1417  llvm::formatv("missing return in a `{0}` expected to return `{1}`",
1418  declType, resultType));
1419  }
1420  return success();
1421 }
1422 
1423 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
1424  return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
1425  if (isa<ast::OpRewriteStmt>(statement))
1426  return success();
1427  return emitError(
1428  statement->getLoc(),
1429  "expected Pattern lambda body to contain a single operation "
1430  "rewrite statement, such as `erase`, `replace`, or `rewrite`");
1431  });
1432 }
1433 
1434 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
1435  SMRange loc = curToken.getLoc();
1436  consumeToken(Token::kw_Pattern);
1437  llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
1438  ParserContext::PatternMatch);
1439 
1440  // Check for an optional identifier for the pattern name.
1441  const ast::Name *name = nullptr;
1442  if (curToken.is(Token::identifier)) {
1443  name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
1444  consumeToken(Token::identifier);
1445  }
1446 
1447  // Parse any pattern metadata.
1448  ParsedPatternMetadata metadata;
1449  if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
1450  return failure();
1451 
1452  // Parse the pattern body.
1453  ast::CompoundStmt *body;
1454 
1455  // Handle a lambda body.
1456  if (curToken.is(Token::equal_arrow)) {
1457  FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
1458  if (failed(bodyResult))
1459  return failure();
1460  body = *bodyResult;
1461  } else {
1462  if (curToken.isNot(Token::l_brace))
1463  return emitError("expected `{` or `=>` to start pattern body");
1464  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
1465  if (failed(bodyResult))
1466  return failure();
1467  body = *bodyResult;
1468 
1469  // Verify the body of the pattern.
1470  auto bodyIt = body->begin(), bodyE = body->end();
1471  for (; bodyIt != bodyE; ++bodyIt) {
1472  if (isa<ast::ReturnStmt>(*bodyIt)) {
1473  return emitError((*bodyIt)->getLoc(),
1474  "`return` statements are only permitted within a "
1475  "`Constraint` or `Rewrite` body");
1476  }
1477  // Break when we've found the rewrite statement.
1478  if (isa<ast::OpRewriteStmt>(*bodyIt))
1479  break;
1480  }
1481  if (bodyIt == bodyE) {
1482  return emitError(loc,
1483  "expected Pattern body to terminate with an operation "
1484  "rewrite statement, such as `erase`");
1485  }
1486  if (std::next(bodyIt) != bodyE) {
1487  return emitError((*std::next(bodyIt))->getLoc(),
1488  "Pattern body was terminated by an operation "
1489  "rewrite statement, but found trailing statements");
1490  }
1491  }
1492 
1493  return createPatternDecl(loc, name, metadata, body);
1494 }
1495 
1497 Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
1498  Optional<SMRange> benefitLoc;
1499  Optional<SMRange> hasBoundedRecursionLoc;
1500 
1501  do {
1502  // Handle metadata code completion.
1503  if (curToken.is(Token::code_complete))
1504  return codeCompletePatternMetadata();
1505 
1506  if (curToken.isNot(Token::identifier))
1507  return emitError("expected pattern metadata identifier");
1508  StringRef metadataStr = curToken.getSpelling();
1509  SMRange metadataLoc = curToken.getLoc();
1510  consumeToken(Token::identifier);
1511 
1512  // Parse the benefit metadata: benefit(<integer-value>)
1513  if (metadataStr == "benefit") {
1514  if (benefitLoc) {
1515  return emitErrorAndNote(metadataLoc,
1516  "pattern benefit has already been specified",
1517  *benefitLoc, "see previous definition here");
1518  }
1519  if (failed(parseToken(Token::l_paren,
1520  "expected `(` before pattern benefit")))
1521  return failure();
1522 
1523  uint16_t benefitValue = 0;
1524  if (curToken.isNot(Token::integer))
1525  return emitError("expected integral pattern benefit");
1526  if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
1527  return emitError(
1528  "expected pattern benefit to fit within a 16-bit integer");
1529  consumeToken(Token::integer);
1530 
1531  metadata.benefit = benefitValue;
1532  benefitLoc = metadataLoc;
1533 
1534  if (failed(
1535  parseToken(Token::r_paren, "expected `)` after pattern benefit")))
1536  return failure();
1537  continue;
1538  }
1539 
1540  // Parse the bounded recursion metadata: recursion
1541  if (metadataStr == "recursion") {
1542  if (hasBoundedRecursionLoc) {
1543  return emitErrorAndNote(
1544  metadataLoc,
1545  "pattern recursion metadata has already been specified",
1546  *hasBoundedRecursionLoc, "see previous definition here");
1547  }
1548  metadata.hasBoundedRecursion = true;
1549  hasBoundedRecursionLoc = metadataLoc;
1550  continue;
1551  }
1552 
1553  return emitError(metadataLoc, "unknown pattern metadata");
1554  } while (consumeIf(Token::comma));
1555 
1556  return success();
1557 }
1558 
1559 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
1560  consumeToken(Token::less);
1561 
1562  FailureOr<ast::Expr *> typeExpr = parseExpr();
1563  if (failed(typeExpr) ||
1564  failed(parseToken(Token::greater,
1565  "expected `>` after variable type constraint")))
1566  return failure();
1567  return typeExpr;
1568 }
1569 
1570 LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
1571  assert(curDeclScope && "defining decl outside of a decl scope");
1572  if (ast::Decl *lastDecl = curDeclScope->lookup(name.getName())) {
1573  return emitErrorAndNote(
1574  name.getLoc(), "`" + name.getName() + "` has already been defined",
1575  lastDecl->getName()->getLoc(), "see previous definition here");
1576  }
1577  return success();
1578 }
1579 
1581 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1582  ast::Expr *initExpr,
1583  ArrayRef<ast::ConstraintRef> constraints) {
1584  assert(curDeclScope && "defining variable outside of decl scope");
1585  const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
1586 
1587  // If the name of the variable indicates a special variable, we don't add it
1588  // to the scope. This variable is local to the definition point.
1589  if (name.empty() || name == "_") {
1590  return ast::VariableDecl::create(ctx, nameDecl, type, initExpr,
1591  constraints);
1592  }
1593  if (failed(checkDefineNamedDecl(nameDecl)))
1594  return failure();
1595 
1596  auto *varDecl =
1597  ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
1598  curDeclScope->add(varDecl);
1599  return varDecl;
1600 }
1601 
1603 Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
1604  ArrayRef<ast::ConstraintRef> constraints) {
1605  return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
1606  constraints);
1607 }
1608 
1609 LogicalResult Parser::parseVariableDeclConstraintList(
1610  SmallVectorImpl<ast::ConstraintRef> &constraints) {
1611  Optional<SMRange> typeConstraint;
1612  auto parseSingleConstraint = [&] {
1613  FailureOr<ast::ConstraintRef> constraint = parseConstraint(
1614  typeConstraint, constraints, /*allowInlineTypeConstraints=*/true,
1615  /*allowNonCoreConstraints=*/true);
1616  if (failed(constraint))
1617  return failure();
1618  constraints.push_back(*constraint);
1619  return success();
1620  };
1621 
1622  // Check to see if this is a single constraint, or a list.
1623  if (!consumeIf(Token::l_square))
1624  return parseSingleConstraint();
1625 
1626  do {
1627  if (failed(parseSingleConstraint()))
1628  return failure();
1629  } while (consumeIf(Token::comma));
1630  return parseToken(Token::r_square, "expected `]` after constraint list");
1631 }
1632 
1634 Parser::parseConstraint(Optional<SMRange> &typeConstraint,
1635  ArrayRef<ast::ConstraintRef> existingConstraints,
1636  bool allowInlineTypeConstraints,
1637  bool allowNonCoreConstraints) {
1638  auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
1639  if (!allowInlineTypeConstraints) {
1640  return emitError(
1641  curToken.getLoc(),
1642  "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
1643  "permitted on arguments or results");
1644  }
1645  if (typeConstraint)
1646  return emitErrorAndNote(
1647  curToken.getLoc(),
1648  "the type of this variable has already been constrained",
1649  *typeConstraint, "see previous constraint location here");
1650  FailureOr<ast::Expr *> constraintExpr = parseTypeConstraintExpr();
1651  if (failed(constraintExpr))
1652  return failure();
1653  typeExpr = *constraintExpr;
1654  typeConstraint = typeExpr->getLoc();
1655  return success();
1656  };
1657 
1658  SMRange loc = curToken.getLoc();
1659  switch (curToken.getKind()) {
1660  case Token::kw_Attr: {
1661  consumeToken(Token::kw_Attr);
1662 
1663  // Check for a type constraint.
1664  ast::Expr *typeExpr = nullptr;
1665  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1666  return failure();
1667  return ast::ConstraintRef(
1668  ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
1669  }
1670  case Token::kw_Op: {
1671  consumeToken(Token::kw_Op);
1672 
1673  // Parse an optional operation name. If the name isn't provided, this refers
1674  // to "any" operation.
1676  parseWrappedOperationName(/*allowEmptyName=*/true);
1677  if (failed(opName))
1678  return failure();
1679 
1680  return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
1681  loc);
1682  }
1683  case Token::kw_Type:
1684  consumeToken(Token::kw_Type);
1685  return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
1686  case Token::kw_TypeRange:
1687  consumeToken(Token::kw_TypeRange);
1689  loc);
1690  case Token::kw_Value: {
1691  consumeToken(Token::kw_Value);
1692 
1693  // Check for a type constraint.
1694  ast::Expr *typeExpr = nullptr;
1695  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1696  return failure();
1697 
1698  return ast::ConstraintRef(
1699  ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
1700  }
1701  case Token::kw_ValueRange: {
1702  consumeToken(Token::kw_ValueRange);
1703 
1704  // Check for a type constraint.
1705  ast::Expr *typeExpr = nullptr;
1706  if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr)))
1707  return failure();
1708 
1709  return ast::ConstraintRef(
1710  ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
1711  }
1712 
1713  case Token::kw_Constraint: {
1714  // Handle an inline constraint.
1715  FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1716  if (failed(decl))
1717  return failure();
1718  return ast::ConstraintRef(*decl, loc);
1719  }
1720  case Token::identifier: {
1721  StringRef constraintName = curToken.getSpelling();
1722  consumeToken(Token::identifier);
1723 
1724  // Lookup the referenced constraint.
1725  ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
1726  if (!cstDecl) {
1727  return emitError(loc, "unknown reference to constraint `" +
1728  constraintName + "`");
1729  }
1730 
1731  // Handle a reference to a proper constraint.
1732  if (auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl))
1733  return ast::ConstraintRef(cst, loc);
1734 
1735  return emitErrorAndNote(
1736  loc, "invalid reference to non-constraint", cstDecl->getLoc(),
1737  "see the definition of `" + constraintName + "` here");
1738  }
1739  // Handle single entity constraint code completion.
1740  case Token::code_complete: {
1741  // Try to infer the current type for use by code completion.
1742  ast::Type inferredType;
1743  if (failed(validateVariableConstraints(existingConstraints, inferredType,
1744  allowNonCoreConstraints)))
1745  return failure();
1746 
1747  return codeCompleteConstraintName(inferredType, allowNonCoreConstraints,
1748  allowInlineTypeConstraints);
1749  }
1750  default:
1751  break;
1752  }
1753  return emitError(loc, "expected identifier constraint");
1754 }
1755 
1756 FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
1757  // Constraint arguments may apply more complex constraints via the arguments.
1758  bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
1759 
1760  Optional<SMRange> typeConstraint;
1761  return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
1762  /*allowInlineTypeConstraints=*/false,
1763  allowNonCoreConstraints);
1764 }
1765 
1766 //===----------------------------------------------------------------------===//
1767 // Exprs
1768 
1769 FailureOr<ast::Expr *> Parser::parseExpr() {
1770  if (curToken.is(Token::underscore))
1771  return parseUnderscoreExpr();
1772 
1773  // Parse the LHS expression.
1774  FailureOr<ast::Expr *> lhsExpr;
1775  switch (curToken.getKind()) {
1776  case Token::kw_attr:
1777  lhsExpr = parseAttributeExpr();
1778  break;
1779  case Token::kw_Constraint:
1780  lhsExpr = parseInlineConstraintLambdaExpr();
1781  break;
1782  case Token::identifier:
1783  lhsExpr = parseIdentifierExpr();
1784  break;
1785  case Token::kw_op:
1786  lhsExpr = parseOperationExpr();
1787  break;
1788  case Token::kw_Rewrite:
1789  lhsExpr = parseInlineRewriteLambdaExpr();
1790  break;
1791  case Token::kw_type:
1792  lhsExpr = parseTypeExpr();
1793  break;
1794  case Token::l_paren:
1795  lhsExpr = parseTupleExpr();
1796  break;
1797  default:
1798  return emitError("expected expression");
1799  }
1800  if (failed(lhsExpr))
1801  return failure();
1802 
1803  // Check for an operator expression.
1804  while (true) {
1805  switch (curToken.getKind()) {
1806  case Token::dot:
1807  lhsExpr = parseMemberAccessExpr(*lhsExpr);
1808  break;
1809  case Token::l_paren:
1810  lhsExpr = parseCallExpr(*lhsExpr);
1811  break;
1812  default:
1813  return lhsExpr;
1814  }
1815  if (failed(lhsExpr))
1816  return failure();
1817  }
1818 }
1819 
1820 FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
1821  SMRange loc = curToken.getLoc();
1822  consumeToken(Token::kw_attr);
1823 
1824  // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
1825  // identifier.
1826  if (!consumeIf(Token::less)) {
1827  resetToken(loc);
1828  return parseIdentifierExpr();
1829  }
1830 
1831  if (!curToken.isString())
1832  return emitError("expected string literal containing MLIR attribute");
1833  std::string attrExpr = curToken.getStringValue();
1834  consumeToken();
1835 
1836  loc.End = curToken.getEndLoc();
1837  if (failed(
1838  parseToken(Token::greater, "expected `>` after attribute literal")))
1839  return failure();
1840  return ast::AttributeExpr::create(ctx, loc, attrExpr);
1841 }
1842 
1843 FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
1844  consumeToken(Token::l_paren);
1845 
1846  // Parse the arguments of the call.
1847  SmallVector<ast::Expr *> arguments;
1848  if (curToken.isNot(Token::r_paren)) {
1849  do {
1850  // Handle code completion for the call arguments.
1851  if (curToken.is(Token::code_complete)) {
1852  codeCompleteCallSignature(parentExpr, arguments.size());
1853  return failure();
1854  }
1855 
1856  FailureOr<ast::Expr *> argument = parseExpr();
1857  if (failed(argument))
1858  return failure();
1859  arguments.push_back(*argument);
1860  } while (consumeIf(Token::comma));
1861  }
1862 
1863  SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1864  if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
1865  return failure();
1866 
1867  return createCallExpr(loc, parentExpr, arguments);
1868 }
1869 
1870 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
1871  ast::Decl *decl = curDeclScope->lookup(name);
1872  if (!decl)
1873  return emitError(loc, "undefined reference to `" + name + "`");
1874 
1875  return createDeclRefExpr(loc, decl);
1876 }
1877 
1878 FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
1879  StringRef name = curToken.getSpelling();
1880  SMRange nameLoc = curToken.getLoc();
1881  consumeToken();
1882 
1883  // Check to see if this is a decl ref expression that defines a variable
1884  // inline.
1885  if (consumeIf(Token::colon)) {
1886  SmallVector<ast::ConstraintRef> constraints;
1887  if (failed(parseVariableDeclConstraintList(constraints)))
1888  return failure();
1889  ast::Type type;
1890  if (failed(validateVariableConstraints(constraints, type)))
1891  return failure();
1892  return createInlineVariableExpr(type, name, nameLoc, constraints);
1893  }
1894 
1895  return parseDeclRefExpr(name, nameLoc);
1896 }
1897 
1898 FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
1899  FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
1900  if (failed(decl))
1901  return failure();
1902 
1903  return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1905 }
1906 
1907 FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
1908  FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
1909  if (failed(decl))
1910  return failure();
1911 
1912  return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
1913  ast::RewriteType::get(ctx));
1914 }
1915 
1916 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
1917  SMRange dotLoc = curToken.getLoc();
1918  consumeToken(Token::dot);
1919 
1920  // Check for code completion of the member name.
1921  if (curToken.is(Token::code_complete))
1922  return codeCompleteMemberAccess(parentExpr);
1923 
1924  // Parse the member name.
1925  Token memberNameTok = curToken;
1926  if (memberNameTok.isNot(Token::identifier, Token::integer) &&
1927  !memberNameTok.isKeyword())
1928  return emitError(dotLoc, "expected identifier or numeric member name");
1929  StringRef memberName = memberNameTok.getSpelling();
1930  SMRange loc(parentExpr->getLoc().Start, curToken.getEndLoc());
1931  consumeToken();
1932 
1933  return createMemberAccessExpr(parentExpr, memberName, loc);
1934 }
1935 
1936 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
1937  SMRange loc = curToken.getLoc();
1938 
1939  // Check for code completion for the dialect name.
1940  if (curToken.is(Token::code_complete))
1941  return codeCompleteDialectName();
1942 
1943  // Handle the case of an no operation name.
1944  if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
1945  if (allowEmptyName)
1946  return ast::OpNameDecl::create(ctx, SMRange());
1947  return emitError("expected dialect namespace");
1948  }
1949  StringRef name = curToken.getSpelling();
1950  consumeToken();
1951 
1952  // Otherwise, this is a literal operation name.
1953  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
1954  return failure();
1955 
1956  // Check for code completion for the operation name.
1957  if (curToken.is(Token::code_complete))
1958  return codeCompleteOperationName(name);
1959 
1960  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
1961  return emitError("expected operation name after dialect namespace");
1962 
1963  name = StringRef(name.data(), name.size() + 1);
1964  do {
1965  name = StringRef(name.data(), name.size() + curToken.getSpelling().size());
1966  loc.End = curToken.getEndLoc();
1967  consumeToken();
1968  } while (curToken.isAny(Token::identifier, Token::dot) ||
1969  curToken.isKeyword());
1970  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
1971 }
1972 
1974 Parser::parseWrappedOperationName(bool allowEmptyName) {
1975  if (!consumeIf(Token::less))
1976  return ast::OpNameDecl::create(ctx, SMRange());
1977 
1978  FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
1979  if (failed(opNameDecl))
1980  return failure();
1981 
1982  if (failed(parseToken(Token::greater, "expected `>` after operation name")))
1983  return failure();
1984  return opNameDecl;
1985 }
1986 
1988 Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
1989  SMRange loc = curToken.getLoc();
1990  consumeToken(Token::kw_op);
1991 
1992  // If it isn't followed by a `<`, the `op` keyword is treated as a normal
1993  // identifier.
1994  if (curToken.isNot(Token::less)) {
1995  resetToken(loc);
1996  return parseIdentifierExpr();
1997  }
1998 
1999  // Parse the operation name. The name may be elided, in which case the
2000  // operation refers to "any" operation(i.e. a difference between `MyOp` and
2001  // `Operation*`). Operation names within a rewrite context must be named.
2002  bool allowEmptyName = parserContext != ParserContext::Rewrite;
2003  FailureOr<ast::OpNameDecl *> opNameDecl =
2004  parseWrappedOperationName(allowEmptyName);
2005  if (failed(opNameDecl))
2006  return failure();
2007  Optional<StringRef> opName = (*opNameDecl)->getName();
2008 
2009  // Functor used to create an implicit range variable, used for implicit "all"
2010  // operand or results variables.
2011  auto createImplicitRangeVar = [&](ast::ConstraintDecl *cst, ast::Type type) {
2013  defineVariableDecl("_", loc, type, ast::ConstraintRef(cst, loc));
2014  assert(succeeded(rangeVar) && "expected range variable to be valid");
2015  return ast::DeclRefExpr::create(ctx, loc, *rangeVar, type);
2016  };
2017 
2018  // Check for the optional list of operands.
2019  SmallVector<ast::Expr *> operands;
2020  if (!consumeIf(Token::l_paren)) {
2021  // If the operand list isn't specified and we are in a match context, define
2022  // an inplace unconstrained operand range corresponding to all of the
2023  // operands of the operation. This avoids treating zero operands the same
2024  // way as "unconstrained operands".
2025  if (parserContext != ParserContext::Rewrite) {
2026  operands.push_back(createImplicitRangeVar(
2027  ast::ValueRangeConstraintDecl::create(ctx, loc), valueRangeTy));
2028  }
2029  } else if (!consumeIf(Token::r_paren)) {
2030  // If the operand list was specified and non-empty, parse the operands.
2031  do {
2032  // Check for operand signature code completion.
2033  if (curToken.is(Token::code_complete)) {
2034  codeCompleteOperationOperandsSignature(opName, operands.size());
2035  return failure();
2036  }
2037 
2038  FailureOr<ast::Expr *> operand = parseExpr();
2039  if (failed(operand))
2040  return failure();
2041  operands.push_back(*operand);
2042  } while (consumeIf(Token::comma));
2043 
2044  if (failed(parseToken(Token::r_paren,
2045  "expected `)` after operation operand list")))
2046  return failure();
2047  }
2048 
2049  // Check for the optional list of attributes.
2051  if (consumeIf(Token::l_brace)) {
2052  do {
2054  parseNamedAttributeDecl(opName);
2055  if (failed(decl))
2056  return failure();
2057  attributes.emplace_back(*decl);
2058  } while (consumeIf(Token::comma));
2059 
2060  if (failed(parseToken(Token::r_brace,
2061  "expected `}` after operation attribute list")))
2062  return failure();
2063  }
2064 
2065  // Handle the result types of the operation.
2066  SmallVector<ast::Expr *> resultTypes;
2067  OpResultTypeContext resultTypeContext = inputResultTypeContext;
2068 
2069  // Check for an explicit list of result types.
2070  if (consumeIf(Token::arrow)) {
2071  if (failed(parseToken(Token::l_paren,
2072  "expected `(` before operation result type list")))
2073  return failure();
2074 
2075  // If result types are provided, initially assume that the operation does
2076  // not rely on type inferrence. We don't assert that it isn't, because we
2077  // may be inferring the value of some type/type range variables, but given
2078  // that these variables may be defined in calls we can't always discern when
2079  // this is the case.
2080  resultTypeContext = OpResultTypeContext::Explicit;
2081 
2082  // Handle the case of an empty result list.
2083  if (!consumeIf(Token::r_paren)) {
2084  do {
2085  // Check for result signature code completion.
2086  if (curToken.is(Token::code_complete)) {
2087  codeCompleteOperationResultsSignature(opName, resultTypes.size());
2088  return failure();
2089  }
2090 
2091  FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
2092  if (failed(resultTypeExpr))
2093  return failure();
2094  resultTypes.push_back(*resultTypeExpr);
2095  } while (consumeIf(Token::comma));
2096 
2097  if (failed(parseToken(Token::r_paren,
2098  "expected `)` after operation result type list")))
2099  return failure();
2100  }
2101  } else if (parserContext != ParserContext::Rewrite) {
2102  // If the result list isn't specified and we are in a match context, define
2103  // an inplace unconstrained result range corresponding to all of the results
2104  // of the operation. This avoids treating zero results the same way as
2105  // "unconstrained results".
2106  resultTypes.push_back(createImplicitRangeVar(
2107  ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy));
2108  } else if (resultTypeContext == OpResultTypeContext::Explicit) {
2109  // If the result list isn't specified and we are in a rewrite, try to infer
2110  // them at runtime instead.
2111  resultTypeContext = OpResultTypeContext::Interface;
2112  }
2113 
2114  return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2115  attributes, resultTypes);
2116 }
2117 
2118 FailureOr<ast::Expr *> Parser::parseTupleExpr() {
2119  SMRange loc = curToken.getLoc();
2120  consumeToken(Token::l_paren);
2121 
2122  DenseMap<StringRef, SMRange> usedNames;
2123  SmallVector<StringRef> elementNames;
2124  SmallVector<ast::Expr *> elements;
2125  if (curToken.isNot(Token::r_paren)) {
2126  do {
2127  // Check for the optional element name assignment before the value.
2128  StringRef elementName;
2129  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
2130  Token elementNameTok = curToken;
2131  consumeToken();
2132 
2133  // The element name is only present if followed by an `=`.
2134  if (consumeIf(Token::equal)) {
2135  elementName = elementNameTok.getSpelling();
2136 
2137  // Check to see if this name is already used.
2138  auto elementNameIt =
2139  usedNames.try_emplace(elementName, elementNameTok.getLoc());
2140  if (!elementNameIt.second) {
2141  return emitErrorAndNote(
2142  elementNameTok.getLoc(),
2143  llvm::formatv("duplicate tuple element label `{0}`",
2144  elementName),
2145  elementNameIt.first->getSecond(),
2146  "see previous label use here");
2147  }
2148  } else {
2149  // Otherwise, we treat this as part of an expression so reset the
2150  // lexer.
2151  resetToken(elementNameTok.getLoc());
2152  }
2153  }
2154  elementNames.push_back(elementName);
2155 
2156  // Parse the tuple element value.
2157  FailureOr<ast::Expr *> element = parseExpr();
2158  if (failed(element))
2159  return failure();
2160  elements.push_back(*element);
2161  } while (consumeIf(Token::comma));
2162  }
2163  loc.End = curToken.getEndLoc();
2164  if (failed(
2165  parseToken(Token::r_paren, "expected `)` after tuple element list")))
2166  return failure();
2167  return createTupleExpr(loc, elements, elementNames);
2168 }
2169 
2170 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
2171  SMRange loc = curToken.getLoc();
2172  consumeToken(Token::kw_type);
2173 
2174  // If we aren't followed by a `<`, the `type` keyword is treated as a normal
2175  // identifier.
2176  if (!consumeIf(Token::less)) {
2177  resetToken(loc);
2178  return parseIdentifierExpr();
2179  }
2180 
2181  if (!curToken.isString())
2182  return emitError("expected string literal containing MLIR type");
2183  std::string attrExpr = curToken.getStringValue();
2184  consumeToken();
2185 
2186  loc.End = curToken.getEndLoc();
2187  if (failed(parseToken(Token::greater, "expected `>` after type literal")))
2188  return failure();
2189  return ast::TypeExpr::create(ctx, loc, attrExpr);
2190 }
2191 
2192 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2193  StringRef name = curToken.getSpelling();
2194  SMRange nameLoc = curToken.getLoc();
2195  consumeToken(Token::underscore);
2196 
2197  // Underscore expressions require a constraint list.
2198  if (failed(parseToken(Token::colon, "expected `:` after `_` variable")))
2199  return failure();
2200 
2201  // Parse the constraints for the expression.
2202  SmallVector<ast::ConstraintRef> constraints;
2203  if (failed(parseVariableDeclConstraintList(constraints)))
2204  return failure();
2205 
2206  ast::Type type;
2207  if (failed(validateVariableConstraints(constraints, type)))
2208  return failure();
2209  return createInlineVariableExpr(type, name, nameLoc, constraints);
2210 }
2211 
2212 //===----------------------------------------------------------------------===//
2213 // Stmts
2214 
2215 FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
2217  switch (curToken.getKind()) {
2218  case Token::kw_erase:
2219  stmt = parseEraseStmt();
2220  break;
2221  case Token::kw_let:
2222  stmt = parseLetStmt();
2223  break;
2224  case Token::kw_replace:
2225  stmt = parseReplaceStmt();
2226  break;
2227  case Token::kw_return:
2228  stmt = parseReturnStmt();
2229  break;
2230  case Token::kw_rewrite:
2231  stmt = parseRewriteStmt();
2232  break;
2233  default:
2234  stmt = parseExpr();
2235  break;
2236  }
2237  if (failed(stmt) ||
2238  (expectTerminalSemicolon &&
2239  failed(parseToken(Token::semicolon, "expected `;` after statement"))))
2240  return failure();
2241  return stmt;
2242 }
2243 
2244 FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
2245  SMLoc startLoc = curToken.getStartLoc();
2246  consumeToken(Token::l_brace);
2247 
2248  // Push a new block scope and parse any nested statements.
2249  pushDeclScope();
2250  SmallVector<ast::Stmt *> statements;
2251  while (curToken.isNot(Token::r_brace)) {
2252  FailureOr<ast::Stmt *> statement = parseStmt();
2253  if (failed(statement))
2254  return popDeclScope(), failure();
2255  statements.push_back(*statement);
2256  }
2257  popDeclScope();
2258 
2259  // Consume the end brace.
2260  SMRange location(startLoc, curToken.getEndLoc());
2261  consumeToken(Token::r_brace);
2262 
2263  return ast::CompoundStmt::create(ctx, location, statements);
2264 }
2265 
2266 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
2267  if (parserContext == ParserContext::Constraint)
2268  return emitError("`erase` cannot be used within a Constraint");
2269  SMRange loc = curToken.getLoc();
2270  consumeToken(Token::kw_erase);
2271 
2272  // Parse the root operation expression.
2273  FailureOr<ast::Expr *> rootOp = parseExpr();
2274  if (failed(rootOp))
2275  return failure();
2276 
2277  return createEraseStmt(loc, *rootOp);
2278 }
2279 
2280 FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
2281  SMRange loc = curToken.getLoc();
2282  consumeToken(Token::kw_let);
2283 
2284  // Parse the name of the new variable.
2285  SMRange varLoc = curToken.getLoc();
2286  if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) {
2287  // `_` is a reserved variable name.
2288  if (curToken.is(Token::underscore)) {
2289  return emitError(varLoc,
2290  "`_` may only be used to define \"inline\" variables");
2291  }
2292  return emitError(varLoc,
2293  "expected identifier after `let` to name a new variable");
2294  }
2295  StringRef varName = curToken.getSpelling();
2296  consumeToken();
2297 
2298  // Parse the optional set of constraints.
2299  SmallVector<ast::ConstraintRef> constraints;
2300  if (consumeIf(Token::colon) &&
2301  failed(parseVariableDeclConstraintList(constraints)))
2302  return failure();
2303 
2304  // Parse the optional initializer expression.
2305  ast::Expr *initializer = nullptr;
2306  if (consumeIf(Token::equal)) {
2307  FailureOr<ast::Expr *> initOrFailure = parseExpr();
2308  if (failed(initOrFailure))
2309  return failure();
2310  initializer = *initOrFailure;
2311 
2312  // Check that the constraints are compatible with having an initializer,
2313  // e.g. type constraints cannot be used with initializers.
2314  for (ast::ConstraintRef constraint : constraints) {
2315  LogicalResult result =
2316  TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
2318  ast::ValueRangeConstraintDecl>([&](const auto *cst) {
2319  if (auto *typeConstraintExpr = cst->getTypeExpr()) {
2320  return this->emitError(
2321  constraint.referenceLoc,
2322  "type constraints are not permitted on variables with "
2323  "initializers");
2324  }
2325  return success();
2326  })
2327  .Default(success());
2328  if (failed(result))
2329  return failure();
2330  }
2331  }
2332 
2334  createVariableDecl(varName, varLoc, initializer, constraints);
2335  if (failed(varDecl))
2336  return failure();
2337  return ast::LetStmt::create(ctx, loc, *varDecl);
2338 }
2339 
2340 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
2341  if (parserContext == ParserContext::Constraint)
2342  return emitError("`replace` cannot be used within a Constraint");
2343  SMRange loc = curToken.getLoc();
2344  consumeToken(Token::kw_replace);
2345 
2346  // Parse the root operation expression.
2347  FailureOr<ast::Expr *> rootOp = parseExpr();
2348  if (failed(rootOp))
2349  return failure();
2350 
2351  if (failed(
2352  parseToken(Token::kw_with, "expected `with` after root operation")))
2353  return failure();
2354 
2355  // The replacement portion of this statement is within a rewrite context.
2356  llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
2357  ParserContext::Rewrite);
2358 
2359  // Parse the replacement values.
2360  SmallVector<ast::Expr *> replValues;
2361  if (consumeIf(Token::l_paren)) {
2362  if (consumeIf(Token::r_paren)) {
2363  return emitError(
2364  loc, "expected at least one replacement value, consider using "
2365  "`erase` if no replacement values are desired");
2366  }
2367 
2368  do {
2369  FailureOr<ast::Expr *> replExpr = parseExpr();
2370  if (failed(replExpr))
2371  return failure();
2372  replValues.emplace_back(*replExpr);
2373  } while (consumeIf(Token::comma));
2374 
2375  if (failed(parseToken(Token::r_paren,
2376  "expected `)` after replacement values")))
2377  return failure();
2378  } else {
2379  // Handle replacement with an operation uniquely, as the replacement
2380  // operation supports type inferrence from the root operation.
2381  FailureOr<ast::Expr *> replExpr;
2382  if (curToken.is(Token::kw_op))
2383  replExpr = parseOperationExpr(OpResultTypeContext::Replacement);
2384  else
2385  replExpr = parseExpr();
2386  if (failed(replExpr))
2387  return failure();
2388  replValues.emplace_back(*replExpr);
2389  }
2390 
2391  return createReplaceStmt(loc, *rootOp, replValues);
2392 }
2393 
2394 FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
2395  SMRange loc = curToken.getLoc();
2396  consumeToken(Token::kw_return);
2397 
2398  // Parse the result value.
2399  FailureOr<ast::Expr *> resultExpr = parseExpr();
2400  if (failed(resultExpr))
2401  return failure();
2402 
2403  return ast::ReturnStmt::create(ctx, loc, *resultExpr);
2404 }
2405 
2406 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
2407  if (parserContext == ParserContext::Constraint)
2408  return emitError("`rewrite` cannot be used within a Constraint");
2409  SMRange loc = curToken.getLoc();
2410  consumeToken(Token::kw_rewrite);
2411 
2412  // Parse the root operation.
2413  FailureOr<ast::Expr *> rootOp = parseExpr();
2414  if (failed(rootOp))
2415  return failure();
2416 
2417  if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
2418  return failure();
2419 
2420  if (curToken.isNot(Token::l_brace))
2421  return emitError("expected `{` to start rewrite body");
2422 
2423  // The rewrite body of this statement is within a rewrite context.
2424  llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
2425  ParserContext::Rewrite);
2426 
2427  FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
2428  if (failed(rewriteBody))
2429  return failure();
2430 
2431  // Verify the rewrite body.
2432  for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
2433  if (isa<ast::ReturnStmt>(stmt)) {
2434  return emitError(stmt->getLoc(),
2435  "`return` statements are only permitted within a "
2436  "`Constraint` or `Rewrite` body");
2437  }
2438  }
2439 
2440  return createRewriteStmt(loc, *rootOp, *rewriteBody);
2441 }
2442 
2443 //===----------------------------------------------------------------------===//
2444 // Creation+Analysis
2445 //===----------------------------------------------------------------------===//
2446 
2447 //===----------------------------------------------------------------------===//
2448 // Decls
2449 
2450 ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
2451  // Unwrap reference expressions.
2452  if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
2453  node = init->getDecl();
2454  return dyn_cast<ast::CallableDecl>(node);
2455 }
2456 
2458 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
2459  const ParsedPatternMetadata &metadata,
2460  ast::CompoundStmt *body) {
2461  return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
2462  metadata.hasBoundedRecursion, body);
2463 }
2464 
2465 ast::Type Parser::createUserConstraintRewriteResultType(
2467  // Single result decls use the type of the single result.
2468  if (results.size() == 1)
2469  return results[0]->getType();
2470 
2471  // Multiple results use a tuple type, with the types and names grabbed from
2472  // the result variable decls.
2473  auto resultTypes = llvm::map_range(
2474  results, [&](const auto *result) { return result->getType(); });
2475  auto resultNames = llvm::map_range(
2476  results, [&](const auto *result) { return result->getName().getName(); });
2477  return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
2478  llvm::to_vector(resultNames));
2479 }
2480 
2481 template <typename T>
2482 FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
2483  const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
2484  ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
2485  ast::CompoundStmt *body) {
2486  if (!body->getChildren().empty()) {
2487  if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
2488  ast::Expr *resultExpr = retStmt->getResultExpr();
2489 
2490  // Process the result of the decl. If no explicit signature results
2491  // were provided, check for return type inference. Otherwise, check that
2492  // the return expression can be converted to the expected type.
2493  if (results.empty())
2494  resultType = resultExpr->getType();
2495  else if (failed(convertExpressionTo(resultExpr, resultType)))
2496  return failure();
2497  else
2498  retStmt->setResultExpr(resultExpr);
2499  }
2500  }
2501  return T::createPDLL(ctx, name, arguments, results, body, resultType);
2502 }
2503 
2505 Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
2506  ArrayRef<ast::ConstraintRef> constraints) {
2507  // The type of the variable, which is expected to be inferred by either a
2508  // constraint or an initializer expression.
2509  ast::Type type;
2510  if (failed(validateVariableConstraints(constraints, type)))
2511  return failure();
2512 
2513  if (initializer) {
2514  // Update the variable type based on the initializer, or try to convert the
2515  // initializer to the existing type.
2516  if (!type)
2517  type = initializer->getType();
2518  else if (ast::Type mergedType = type.refineWith(initializer->getType()))
2519  type = mergedType;
2520  else if (failed(convertExpressionTo(initializer, type)))
2521  return failure();
2522 
2523  // Otherwise, if there is no initializer check that the type has already
2524  // been resolved from the constraint list.
2525  } else if (!type) {
2526  return emitErrorAndNote(
2527  loc, "unable to infer type for variable `" + name + "`", loc,
2528  "the type of a variable must be inferable from the constraint "
2529  "list or the initializer");
2530  }
2531 
2532  // Constraint types cannot be used when defining variables.
2533  if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
2534  return emitError(
2535  loc, llvm::formatv("unable to define variable of `{0}` type", type));
2536  }
2537 
2538  // Try to define a variable with the given name.
2540  defineVariableDecl(name, loc, type, initializer, constraints);
2541  if (failed(varDecl))
2542  return failure();
2543 
2544  return *varDecl;
2545 }
2546 
2548 Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
2549  const ast::ConstraintRef &constraint) {
2550  // Constraint arguments may apply more complex constraints via the arguments.
2551  bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
2552  ast::Type argType;
2553  if (failed(validateVariableConstraint(constraint, argType,
2554  allowNonCoreConstraints)))
2555  return failure();
2556  return defineVariableDecl(name, loc, argType, constraint);
2557 }
2558 
2560 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
2561  ast::Type &inferredType,
2562  bool allowNonCoreConstraints) {
2563  for (const ast::ConstraintRef &ref : constraints)
2564  if (failed(validateVariableConstraint(ref, inferredType,
2565  allowNonCoreConstraints)))
2566  return failure();
2567  return success();
2568 }
2569 
2570 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
2571  ast::Type &inferredType,
2572  bool allowNonCoreConstraints) {
2573  ast::Type constraintType;
2574  if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
2575  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2576  if (failed(validateTypeConstraintExpr(typeExpr)))
2577  return failure();
2578  }
2579  constraintType = ast::AttributeType::get(ctx);
2580  } else if (const auto *cst =
2581  dyn_cast<ast::OpConstraintDecl>(ref.constraint)) {
2582  constraintType = ast::OperationType::get(
2583  ctx, cst->getName(), lookupODSOperation(cst->getName()));
2584  } else if (isa<ast::TypeConstraintDecl>(ref.constraint)) {
2585  constraintType = typeTy;
2586  } else if (isa<ast::TypeRangeConstraintDecl>(ref.constraint)) {
2587  constraintType = typeRangeTy;
2588  } else if (const auto *cst =
2589  dyn_cast<ast::ValueConstraintDecl>(ref.constraint)) {
2590  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2591  if (failed(validateTypeConstraintExpr(typeExpr)))
2592  return failure();
2593  }
2594  constraintType = valueTy;
2595  } else if (const auto *cst =
2596  dyn_cast<ast::ValueRangeConstraintDecl>(ref.constraint)) {
2597  if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
2598  if (failed(validateTypeRangeConstraintExpr(typeExpr)))
2599  return failure();
2600  }
2601  constraintType = valueRangeTy;
2602  } else if (const auto *cst =
2603  dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
2604  if (!allowNonCoreConstraints) {
2605  return emitError(ref.referenceLoc,
2606  "`Rewrite` arguments and results are only permitted to "
2607  "use core constraints, such as `Attr`, `Op`, `Type`, "
2608  "`TypeRange`, `Value`, `ValueRange`");
2609  }
2610 
2611  ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
2612  if (inputs.size() != 1) {
2613  return emitErrorAndNote(ref.referenceLoc,
2614  "`Constraint`s applied via a variable constraint "
2615  "list must take a single input, but got " +
2616  Twine(inputs.size()),
2617  cst->getLoc(),
2618  "see definition of constraint here");
2619  }
2620  constraintType = inputs.front()->getType();
2621  } else {
2622  llvm_unreachable("unknown constraint type");
2623  }
2624 
2625  // Check that the constraint type is compatible with the current inferred
2626  // type.
2627  if (!inferredType) {
2628  inferredType = constraintType;
2629  } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) {
2630  inferredType = mergedTy;
2631  } else {
2632  return emitError(ref.referenceLoc,
2633  llvm::formatv("constraint type `{0}` is incompatible "
2634  "with the previously inferred type `{1}`",
2635  constraintType, inferredType));
2636  }
2637  return success();
2638 }
2639 
2640 LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
2641  ast::Type typeExprType = typeExpr->getType();
2642  if (typeExprType != typeTy) {
2643  return emitError(typeExpr->getLoc(),
2644  "expected expression of `Type` in type constraint");
2645  }
2646  return success();
2647 }
2648 
2650 Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
2651  ast::Type typeExprType = typeExpr->getType();
2652  if (typeExprType != typeRangeTy) {
2653  return emitError(typeExpr->getLoc(),
2654  "expected expression of `TypeRange` in type constraint");
2655  }
2656  return success();
2657 }
2658 
2659 //===----------------------------------------------------------------------===//
2660 // Exprs
2661 
2663 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
2664  MutableArrayRef<ast::Expr *> arguments) {
2665  ast::Type parentType = parentExpr->getType();
2666 
2667  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
2668  if (!callableDecl) {
2669  return emitError(loc,
2670  llvm::formatv("expected a reference to a callable "
2671  "`Constraint` or `Rewrite`, but got: `{0}`",
2672  parentType));
2673  }
2674  if (parserContext == ParserContext::Rewrite) {
2675  if (isa<ast::UserConstraintDecl>(callableDecl))
2676  return emitError(
2677  loc, "unable to invoke `Constraint` within a rewrite section");
2678  } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
2679  return emitError(loc, "unable to invoke `Rewrite` within a match section");
2680  }
2681 
2682  // Verify the arguments of the call.
2683  /// Handle size mismatch.
2684  ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
2685  if (callArgs.size() != arguments.size()) {
2686  return emitErrorAndNote(
2687  loc,
2688  llvm::formatv("invalid number of arguments for {0} call; expected "
2689  "{1}, but got {2}",
2690  callableDecl->getCallableType(), callArgs.size(),
2691  arguments.size()),
2692  callableDecl->getLoc(),
2693  llvm::formatv("see the definition of {0} here",
2694  callableDecl->getName()->getName()));
2695  }
2696 
2697  /// Handle argument type mismatch.
2698  auto attachDiagFn = [&](ast::Diagnostic &diag) {
2699  diag.attachNote(llvm::formatv("see the definition of `{0}` here",
2700  callableDecl->getName()->getName()),
2701  callableDecl->getLoc());
2702  };
2703  for (auto it : llvm::zip(callArgs, arguments)) {
2704  if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
2705  attachDiagFn)))
2706  return failure();
2707  }
2708 
2709  return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
2710  callableDecl->getResultType());
2711 }
2712 
2713 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
2714  ast::Decl *decl) {
2715  // Check the type of decl being referenced.
2716  ast::Type declType;
2717  if (isa<ast::ConstraintDecl>(decl))
2718  declType = ast::ConstraintType::get(ctx);
2719  else if (isa<ast::UserRewriteDecl>(decl))
2720  declType = ast::RewriteType::get(ctx);
2721  else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
2722  declType = varDecl->getType();
2723  else
2724  return emitError(loc, "invalid reference to `" +
2725  decl->getName()->getName() + "`");
2726 
2727  return ast::DeclRefExpr::create(ctx, loc, decl, declType);
2728 }
2729 
2731 Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
2732  ArrayRef<ast::ConstraintRef> constraints) {
2734  defineVariableDecl(name, loc, type, constraints);
2735  if (failed(decl))
2736  return failure();
2737  return ast::DeclRefExpr::create(ctx, loc, *decl, type);
2738 }
2739 
2741 Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
2742  SMRange loc) {
2743  // Validate the member name for the given parent expression.
2744  FailureOr<ast::Type> memberType = validateMemberAccess(parentExpr, name, loc);
2745  if (failed(memberType))
2746  return failure();
2747 
2748  return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
2749 }
2750 
2751 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
2752  StringRef name, SMRange loc) {
2753  ast::Type parentType = parentExpr->getType();
2754  if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
2756  return valueRangeTy;
2757 
2758  // Verify member access based on the operation type.
2759  if (const ods::Operation *odsOp = opType.getODSOperation()) {
2760  auto results = odsOp->getResults();
2761 
2762  // Handle indexed results.
2763  unsigned index = 0;
2764  if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2765  index < results.size()) {
2766  return results[index].isVariadic() ? valueRangeTy : valueTy;
2767  }
2768 
2769  // Handle named results.
2770  const auto *it = llvm::find_if(results, [&](const auto &result) {
2771  return result.getName() == name;
2772  });
2773  if (it != results.end())
2774  return it->isVariadic() ? valueRangeTy : valueTy;
2775  } else if (llvm::isDigit(name[0])) {
2776  // Allow unchecked numeric indexing of the results of unregistered
2777  // operations. It returns a single value.
2778  return valueTy;
2779  }
2780  } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
2781  // Handle indexed results.
2782  unsigned index = 0;
2783  if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
2784  index < tupleType.size()) {
2785  return tupleType.getElementTypes()[index];
2786  }
2787 
2788  // Handle named results.
2789  auto elementNames = tupleType.getElementNames();
2790  const auto *it = llvm::find(elementNames, name);
2791  if (it != elementNames.end())
2792  return tupleType.getElementTypes()[it - elementNames.begin()];
2793  }
2794  return emitError(
2795  loc,
2796  llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
2797  name, parentType));
2798 }
2799 
// NOTE(review): this listing is missing original lines 2803-2804. Those lines
// declared the `operands` and `attributes` parameters used below:
//   - `operands` is forwarded to validateOperationOperands and to
//     OperationExpr::create;
//   - `attributes` is iterated as ast::NamedAttributeDecl* and forwarded to
//     OperationExpr::create.
// Their exact types cannot be recovered from here; restore them from the
// upstream source before compiling.
/// Create an operation expression for the operation named by `name`.
///
/// The operands are validated first. Every attribute value must then be an
/// `Attr` expression. Results are validated only when `resultTypeContext` is
/// Explicit. When it is Interface, the operation is instead checked for
/// result type inference support, which may produce a warning.
2800 FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
2801  SMRange loc, const ast::OpNameDecl *name,
2802  OpResultTypeContext resultTypeContext,
2805  MutableArrayRef<ast::Expr *> results) {
2806  Optional<StringRef> opNameRef = name->getName();
2807  const ods::Operation *odsOp = lookupODSOperation(opNameRef);
2808 
2809  // Verify the input operands.
2810  if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
2811  return failure();
2812 
2813  // Verify the attribute list.
2814  for (ast::NamedAttributeDecl *attr : attributes) {
2815  // Check for an attribute type, or a type awaiting resolution.
2816  ast::Type attrType = attr->getValue()->getType();
2817  if (!attrType.isa<ast::AttributeType>()) {
2818  return emitError(
2819  attr->getValue()->getLoc(),
2820  llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
2821  }
2822  }
2823 
2824  assert(
2825  (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) &&
2826  "unexpected inferrence when results were explicitly specified");
2827 
2828  // If we aren't relying on type inference, or explicit results were provided,
2829  // validate them.
2830  if (resultTypeContext == OpResultTypeContext::Explicit) {
2831  if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
2832  return failure();
2833 
2834  // Validate the use of interface based type inference for this operation.
2835  } else if (resultTypeContext == OpResultTypeContext::Interface) {
2836  assert(opNameRef &&
2837  "expected valid operation name when inferring operation results");
2838  checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
2839  }
2840 
2841  return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
2842  attributes);
2843 }
2844 
2846 Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
2847  const ods::Operation *odsOp,
2848  MutableArrayRef<ast::Expr *> operands) {
2849  return validateOperationOperandsOrResults(
2850  "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
2851  operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy,
2852  valueRangeTy);
2853 }
2854 
2856 Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
2857  const ods::Operation *odsOp,
2858  MutableArrayRef<ast::Expr *> results) {
2859  return validateOperationOperandsOrResults(
2860  "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
2861  results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy);
2862 }
2863 
2864 void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
2865  const ods::Operation *odsOp) {
2866  // If the operation might not have inferrence support, emit a warning to the
2867  // user. We don't emit an error because the interface might be added to the
2868  // operation at runtime. It's rare, but it could still happen. We emit a
2869  // warning here instead.
2870 
2871  // Handle inferrence warnings for unknown operations.
2872  if (!odsOp) {
2873  ctx.getDiagEngine().emitWarning(
2874  loc, llvm::formatv(
2875  "operation result types are marked to be inferred, but "
2876  "`{0}` is unknown. Ensure that `{0}` supports zero "
2877  "results or implements `InferTypeOpInterface`. Include "
2878  "the ODS definition of this operation to remove this warning.",
2879  opName));
2880  return;
2881  }
2882 
2883  // Handle inferrence warnings for known operations that expected at least one
2884  // result, but don't have inference support. An elided results list can mean
2885  // "zero-results", and we don't want to warn when that is the expected
2886  // behavior.
2887  bool requiresInferrence =
2888  llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) {
2889  return !result.isVariableLength();
2890  });
2891  if (requiresInferrence && !odsOp->hasResultTypeInferrence()) {
2893  loc,
2894  llvm::formatv("operation result types are marked to be inferred, but "
2895  "`{0}` does not provide an implementation of "
2896  "`InferTypeOpInterface`. Ensure that `{0}` attaches "
2897  "`InferTypeOpInterface` at runtime, or add support to "
2898  "the ODS definition to remove this warning.",
2899  opName));
2900  diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName),
2901  odsOp->getLoc());
2902  return;
2903  }
2904 }
2905 
2906 LogicalResult Parser::validateOperationOperandsOrResults(
2907  StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
2909  ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
2910  ast::Type rangeTy) {
2911  // All operation types accept a single range parameter.
2912  if (values.size() == 1) {
2913  if (failed(convertExpressionTo(values[0], rangeTy)))
2914  return failure();
2915  return success();
2916  }
2917 
2918  /// If the operation has ODS information, we can more accurately verify the
2919  /// values.
2920  if (odsOpLoc) {
2921  if (odsValues.size() != values.size()) {
2922  return emitErrorAndNote(
2923  loc,
2924  llvm::formatv("invalid number of {0} groups for `{1}`; expected "
2925  "{2}, but got {3}",
2926  groupName, *name, odsValues.size(), values.size()),
2927  *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
2928  }
2929  auto diagFn = [&](ast::Diagnostic &diag) {
2930  diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
2931  *odsOpLoc);
2932  };
2933  for (unsigned i = 0, e = values.size(); i < e; ++i) {
2934  ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
2935  if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
2936  return failure();
2937  }
2938  return success();
2939  }
2940 
2941  // Otherwise, accept the value groups as they have been defined and just
2942  // ensure they are one of the expected types.
2943  for (ast::Expr *&valueExpr : values) {
2944  ast::Type valueExprType = valueExpr->getType();
2945 
2946  // Check if this is one of the expected types.
2947  if (valueExprType == rangeTy || valueExprType == singleTy)
2948  continue;
2949 
2950  // If the operand is an Operation, allow converting to a Value or
2951  // ValueRange. This situations arises quite often with nested operation
2952  // expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
2953  if (singleTy == valueTy) {
2954  if (valueExprType.isa<ast::OperationType>()) {
2955  valueExpr = convertOpToValue(valueExpr);
2956  continue;
2957  }
2958  }
2959 
2960  return emitError(
2961  valueExpr->getLoc(),
2962  llvm::formatv(
2963  "expected `{0}` or `{1}` convertible expression, but got `{2}`",
2964  singleTy, rangeTy, valueExprType));
2965  }
2966  return success();
2967 }
2968 
// NOTE(review): this listing is missing two original lines; restore both from
// the upstream source before compiling:
//   - line 2969, the return type (presumably `FailureOr<ast::TupleExpr *>`,
//     matching the sibling create* helpers);
//   - line 2974, the `if` condition that selects which element types are
//     rejected.
/// Create a tuple expression from `elements`, named by `elementNames`.
/// Before building the tuple, each element's type is checked, and an element
/// whose type cannot be stored in a tuple is diagnosed.
2970 Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
2971  ArrayRef<StringRef> elementNames) {
2972  for (const ast::Expr *element : elements) {
2973  ast::Type eleTy = element->getType();
// NOTE(review): the dropped line 2974 (`if (<eleTy is disallowed>) {`)
// belongs here; it opens the block closed by the first `}` below.
2975  return emitError(
2976  element->getLoc(),
2977  llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
2978  }
2979  }
2980  return ast::TupleExpr::create(ctx, loc, elements, elementNames);
2981 }
2982 
2983 //===----------------------------------------------------------------------===//
2984 // Stmts
2985 
2986 FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
2987  ast::Expr *rootOp) {
2988  // Check that root is an Operation.
2989  ast::Type rootType = rootOp->getType();
2990  if (!rootType.isa<ast::OperationType>())
2991  return emitError(rootOp->getLoc(), "expected `Op` expression");
2992 
2993  return ast::EraseStmt::create(ctx, loc, rootOp);
2994 }
2995 
2997 Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
2998  MutableArrayRef<ast::Expr *> replValues) {
2999  // Check that root is an Operation.
3000  ast::Type rootType = rootOp->getType();
3001  if (!rootType.isa<ast::OperationType>()) {
3002  return emitError(
3003  rootOp->getLoc(),
3004  llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3005  }
3006 
3007  // If there are multiple replacement values, we implicitly convert any Op
3008  // expressions to the value form.
3009  bool shouldConvertOpToValues = replValues.size() > 1;
3010  for (ast::Expr *&replExpr : replValues) {
3011  ast::Type replType = replExpr->getType();
3012 
3013  // Check that replExpr is an Operation, Value, or ValueRange.
3014  if (replType.isa<ast::OperationType>()) {
3015  if (shouldConvertOpToValues)
3016  replExpr = convertOpToValue(replExpr);
3017  continue;
3018  }
3019 
3020  if (replType != valueTy && replType != valueRangeTy) {
3021  return emitError(replExpr->getLoc(),
3022  llvm::formatv("expected `Op`, `Value` or `ValueRange` "
3023  "expression, but got `{0}`",
3024  replType));
3025  }
3026  }
3027 
3028  return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
3029 }
3030 
3032 Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
3033  ast::CompoundStmt *rewriteBody) {
3034  // Check that root is an Operation.
3035  ast::Type rootType = rootOp->getType();
3036  if (!rootType.isa<ast::OperationType>()) {
3037  return emitError(
3038  rootOp->getLoc(),
3039  llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
3040  }
3041 
3042  return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
3043 }
3044 
3045 //===----------------------------------------------------------------------===//
3046 // Code Completion
3047 //===----------------------------------------------------------------------===//
3048 
3049 LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
3050  ast::Type parentType = parentExpr->getType();
3051  if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>())
3052  codeCompleteContext->codeCompleteOperationMemberAccess(opType);
3053  else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>())
3054  codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
3055  return failure();
3056 }
3057 
3058 LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) {
3059  if (opName)
3060  codeCompleteContext->codeCompleteOperationAttributeName(*opName);
3061  return failure();
3062 }
3063 
3065 Parser::codeCompleteConstraintName(ast::Type inferredType,
3066  bool allowNonCoreConstraints,
3067  bool allowInlineTypeConstraints) {
3068  codeCompleteContext->codeCompleteConstraintName(
3069  inferredType, allowNonCoreConstraints, allowInlineTypeConstraints,
3070  curDeclScope);
3071  return failure();
3072 }
3073 
3074 LogicalResult Parser::codeCompleteDialectName() {
3075  codeCompleteContext->codeCompleteDialectName();
3076  return failure();
3077 }
3078 
3079 LogicalResult Parser::codeCompleteOperationName(StringRef dialectName) {
3080  codeCompleteContext->codeCompleteOperationName(dialectName);
3081  return failure();
3082 }
3083 
3084 LogicalResult Parser::codeCompletePatternMetadata() {
3085  codeCompleteContext->codeCompletePatternMetadata();
3086  return failure();
3087 }
3088 
3089 LogicalResult Parser::codeCompleteIncludeFilename(StringRef curPath) {
3090  codeCompleteContext->codeCompleteIncludeFilename(curPath);
3091  return failure();
3092 }
3093 
3094 void Parser::codeCompleteCallSignature(ast::Node *parent,
3095  unsigned currentNumArgs) {
3096  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parent);
3097  if (!callableDecl)
3098  return;
3099 
3100  codeCompleteContext->codeCompleteCallSignature(callableDecl, currentNumArgs);
3101 }
3102 
3103 void Parser::codeCompleteOperationOperandsSignature(
3104  Optional<StringRef> opName, unsigned currentNumOperands) {
3105  codeCompleteContext->codeCompleteOperationOperandsSignature(
3106  opName, currentNumOperands);
3107 }
3108 
3109 void Parser::codeCompleteOperationResultsSignature(Optional<StringRef> opName,
3110  unsigned currentNumResults) {
3111  codeCompleteContext->codeCompleteOperationResultsSignature(opName,
3112  currentNumResults);
3113 }
3114 
3115 //===----------------------------------------------------------------------===//
3116 // Parser
3117 //===----------------------------------------------------------------------===//
3118 
3120 mlir::pdll::parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
3121  bool enableDocumentation,
3122  CodeCompleteContext *codeCompleteContext) {
3123  Parser parser(ctx, sourceMgr, enableDocumentation, codeCompleteContext);
3124  return parser.parseModule();
3125 }
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
Definition: Nodes.cpp:213
Include the generated interface declarations.
ArrayRef< OperandOrResult > getOperands() const
Returns the operands of this operation.
Definition: Operation.h:165
static ConstraintType get(Context &context)
Return an instance of the Constraint type.
Definition: Types.cpp:64
virtual void codeCompletePatternMetadata()
Signal code completion for Pattern metadata.
Definition: CodeComplete.h:67
static std::string diag(llvm::Value &v)
std::string getConditionTemplate() const
Definition: Constraint.cpp:50
virtual void codeCompleteOperationName(StringRef dialectName)
Signal code completion for an operation name in the given dialect.
Definition: CodeComplete.h:64
InFlightDiagnostic emitError(SMRange loc, const Twine &msg)
Emit an error to the diagnostic engine.
Definition: Diagnostic.h:145
ods::Context & getODSContext()
Return the ODS context used by the AST.
Definition: Context.h:38
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr *> operands, ArrayRef< Expr *> resultTypes, ArrayRef< NamedAttributeDecl *> attributes)
Definition: Nodes.cpp:301
This class represents the main context of the PDLL AST.
Definition: Context.h:25
This class represents the base of all AST Constraint decls.
Definition: Nodes.h:660
StringRef getCallableType() const
Return the callable type of this decl.
Definition: Nodes.h:1143
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
Definition: Nodes.cpp:477
static OperationType get(Context &context, Optional< StringRef > name=llvm::None, const ods::Operation *odsOp=nullptr)
Return an instance of the Operation type with an optional operation name.
Definition: Types.cpp:72
bool hasResultTypeInferrence() const
Return true if the operation is known to support result type inference.
Definition: Operation.h:171
FailureOr< ast::Module * > parsePDLLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr, bool enableDocumentation=false, CodeCompleteContext *codeCompleteContext=nullptr)
Parse an AST module from the main file of the given source manager.
Definition: Parser.cpp:3120
Token signifying a code completion location.
Definition: Lexer.h:41
void appendAttribute(StringRef name, bool optional, const AttributeConstraint &constraint)
Append an attribute to this operation.
Definition: Operation.h:131
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl *> inputs, ArrayRef< VariableDecl *> results, Optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
Definition: Nodes.h:842
U dyn_cast() const
Definition: Types.h:75
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
bool isNot(Kind k) const
Definition: Token.h:50
Token signifying a code completion location within a string.
Definition: Lexer.h:43
This class represents a PDLL tuple type, i.e. an ordered list of element types, each with an optional name.
Definition: Types.h:242
This class represents a generic ODS Type constraint.
Definition: Constraint.h:84
std::string getUniqueDefName() const
Returns a unique name for the TableGen def of this constraint.
Definition: Constraint.cpp:70
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr *> replExprs)
Definition: Nodes.cpp:220
static StringRef getMemberName()
Return the member name used for the "all-results" access.
Definition: Nodes.h:477
bool isKeyword() const
Return true if this is one of the keyword token kinds (e.g. kw_if).
Definition: Token.cpp:186
bool isa() const
Provide type casting support.
Definition: Types.h:66
ArrayRef< Stmt * >::iterator begin() const
Definition: Nodes.h:190
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
static ValueType get(Context &context)
Return an instance of the Value type.
Definition: Types.cpp:169
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl *> children)
Definition: Nodes.cpp:554
StringRef getDescription() const
Definition: Constraint.cpp:60
SMRange getLoc() const
Return the source location of this operation.
Definition: Operation.h:128
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
Definition: Nodes.cpp:202
This class contains all of the registered ODS operation classes.
Definition: Context.h:32
SMRange getLoc() const
Return the location of this node.
Definition: Nodes.h:130
static const Name & create(Context &ctx, StringRef name, SMRange location)
Definition: Nodes.cpp:32
static constexpr const bool value
This class represents the base Decl node.
Definition: Nodes.h:625
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
Definition: Nodes.h:704
Format context containing substitutions for special placeholders.
Definition: Format.h:40
Paired punctuation.
Definition: Lexer.h:83
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:369
virtual void codeCompleteDialectName()
Signal code completion for a dialect name.
Definition: CodeComplete.h:61
This class represents a scope for named AST decls.
Definition: Nodes.h:63
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
Definition: Nodes.cpp:536
StringRef getName() const
Return the raw string name.
Definition: Nodes.h:40
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
static TupleType get(Context &context, ArrayRef< Type > elementTypes, ArrayRef< StringRef > elementNames)
Return an instance of the Tuple type.
Definition: Types.cpp:138
Type getType() const
Return the type of this expression.
Definition: Nodes.h:347
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
const ConstraintDecl * constraint
Definition: Nodes.h:678
virtual void codeCompleteIncludeFilename(StringRef curPath)
Signal code completion for an include filename.
Definition: CodeComplete.h:70
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class provides support for representing a failure result, or a valid value of type T...
Definition: LogicalResult.h:78
Type refineWith(Type other) const
Try to refine this type with the one provided.
Definition: Types.cpp:32
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
Definition: Nodes.h:802
static RewriteType get(Context &context)
Return an instance of the Rewrite type.
Definition: Types.cpp:130
FmtContext & withSelf(Twine subst)
Definition: Format.cpp:46
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition: Nodes.h:628
This class represents a base AST Statement node.
Definition: Nodes.h:163
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
Definition: Nodes.cpp:244
std::pair< Operation *, bool > insertOperation(StringRef name, StringRef summary, StringRef desc, StringRef nativeClassName, bool supportsResultTypeInferrence, SMLoc loc)
Insert a new operation with the context.
Definition: Context.cpp:63
const Operation * lookupOperation(StringRef name) const
Look up an operation registered with the given name, or return null if no operation with that name is registered.
Definition: Context.cpp:72
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt *> children)
Definition: Nodes.cpp:187
This class represents a PDLL type that corresponds to an mlir::Attribute.
Definition: Types.h:130
void appendResult(StringRef name, VariableLengthKind variableLengthKind, const TypeConstraint &constraint)
Append a result to this operation.
Definition: Operation.h:144
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
Definition: Operation.h:168
This class represents a reference to a constraint, and contains the constraint and the location of the reference.
Definition: Nodes.h:672
This class provides a simple implementation of a PDLL diagnostic.
Definition: Diagnostic.h:29
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:405
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:352
void appendOperand(StringRef name, VariableLengthKind variableLengthKind, const TypeConstraint &constraint)
Append an operand to this operation.
Definition: Operation.h:137
virtual void codeCompleteOperationOperandsSignature(Optional< StringRef > opName, unsigned currentNumOperands)
Signal code completion for the signature of an operation's operands.
Definition: CodeComplete.h:82
const AttributeConstraint & insertAttributeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new attribute constraint with the context.
Definition: Context.cpp:27
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
Definition: Nodes.h:1151
static TypeType get(Context &context)
Return an instance of the Type type.
Definition: Types.cpp:161
This class represents a PDLL type that corresponds to a constraint.
Definition: Types.h:144
General keywords.
Definition: Lexer.h:56
This class represents a base AST node.
Definition: Nodes.h:107
void add(Decl *decl)
Add a new decl to the scope.
Definition: Nodes.cpp:170
The class represents a Value constraint, and constrains a variable to be a Value. ...
Definition: Nodes.h:780
This decl represents a shared interface for all callable decls.
Definition: Nodes.h:1140
virtual void codeCompleteOperationAttributeName(StringRef opName)
Signal code completion for an attribute name of the given operation.
Definition: CodeComplete.h:47
raw_indented_ostream & printReindented(StringRef str)
Prints a string re-indented to the current indent.
This class represents a diagnostic that is in flight and set to be reported.
Definition: Diagnostic.h:83
virtual void codeCompleteOperationMemberAccess(ast::OperationType opType)
Signal code completion for a member access into the given operation type.
This class provides an abstract interface into the parser for hooking in code completion events.
Definition: CodeComplete.h:29
virtual void codeCompleteConstraintName(ast::Type currentType, bool allowNonCoreConstraints, bool allowInlineTypeConstraints, const ast::DeclScope *scope)
Signal code completion for a constraint name with an optional decl scope.
This class represents a PDLL type that corresponds to a rewrite reference.
Definition: Types.h:228
StringRef getSpelling() const
Definition: Token.h:34
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Diagnostic & attachNote(const Twine &msg, Optional< SMRange > noteLoc=llvm::None)
Attach a note to this diagnostic.
Definition: Diagnostic.h:46
This Decl represents a NamedAttribute, and contains a string name and attribute value.
Definition: Nodes.h:945
This class provides an ODS representation of a specific operation.
Definition: Operation.h:125
static ValueRangeType get(Context &context)
Return an instance of the ValueRange type.
Definition: Types.cpp:121
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
Definition: Nodes.cpp:289
This class provides a convenient API for interacting with source names.
Definition: Nodes.h:36
virtual void codeCompleteTupleMemberAccess(ast::TupleType tupleType)
Signal code completion for a member access into the given tuple type.
Optional< StringRef > getName() const
Return the name of this operation, or none if the name is unknown.
Definition: Nodes.h:975
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
Definition: Nodes.cpp:415
virtual void codeCompleteCallSignature(const ast::CallableDecl *callable, unsigned currentNumArgs)
Signal code completion for the signature of a callable.
Definition: CodeComplete.h:77
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
Definition: Nodes.h:69
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr *> elements, ArrayRef< StringRef > elementNames)
Definition: Nodes.cpp:332
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:396
auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals) -> FmtObject< decltype(std::make_tuple(llvm::detail::build_format_adapter(std::forward< Ts >(vals))...))>
Formats text by substituting placeholders in format string with replacement parameters.
Definition: Format.h:263
This class provides an ODS representation of a specific operation operand or result.
Definition: Operation.h:74
virtual void codeCompleteOperationResultsSignature(Optional< StringRef > opName, unsigned currentNumResults)
Signal code completion for the signature of an operation's results.
Definition: CodeComplete.h:87
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
Definition: Nodes.h:184
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
Definition: Nodes.cpp:379
StringRef getSummary() const
Definition: Constraint.cpp:54
Type getResultType() const
Return the result type of this decl.
Definition: Nodes.h:1158
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, Optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
Definition: Nodes.cpp:498
DiagnosticEngine & getDiagEngine()
Return the diagnostic engine of this context.
Definition: Context.h:42
const TypeConstraint & insertTypeConstraint(StringRef name, StringRef summary, StringRef cppClass)
Insert a new type constraint with the context.
Definition: Context.cpp:41
This Decl represents an OperationName.
Definition: Nodes.h:969
static OpNameDecl * create(Context &ctx, const Name &name)
Definition: Nodes.cpp:487
SMLoc getLoc() const
Definition: Token.cpp:18
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr *> arguments, Type resultType)
Definition: Nodes.cpp:263
static AttributeType get(Context &context)
Return an instance of the Attribute type.
Definition: Types.cpp:56
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:253
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
Definition: Nodes.cpp:279
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
Definition: Nodes.cpp:234
raw_ostream subclass that simplifies indenting a sequence of code.
This statement represents a compound statement, which contains a collection of other statements...
Definition: Nodes.h:177
SMRange getLoc() const
Get the location of this name.
Definition: Nodes.h:43
Punctuation.
Definition: Lexer.h:75
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:426
This class represents a base AST Expression node.
Definition: Nodes.h:344
static TypeRangeType get(Context &context)
Return an instance of the TypeRange type.
Definition: Types.cpp:109
InFlightDiagnostic emitWarning(SMRange loc, const Twine &msg)
Definition: Diagnostic.h:149
ArrayRef< Stmt * >::iterator end() const
Definition: Nodes.h:191
This class represents a PDLL type that corresponds to an mlir::Operation.
Definition: Types.h:157
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
Definition: Nodes.h:479
This class breaks up the current file into a token stream.
Definition: Lexer.h:23
This represents a token in the MLIR syntax.
Definition: Token.h:20