MLIR  21.0.0git
Nodes.h
Go to the documentation of this file.
1 //===- Nodes.h --------------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_TOOLS_PDLL_AST_NODES_H_
10 #define MLIR_TOOLS_PDLL_AST_NODES_H_
11 
#include "mlir/Support/LLVM.h"
#include "mlir/Tools/PDLL/AST/Types.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TrailingObjects.h"
#include <optional>
#include <type_traits>
#include <utility>
20 
21 namespace mlir {
22 namespace pdll {
23 namespace ast {
// Forward declarations for AST classes that are referenced before (or
// without) their full definitions in this header.
class Decl;
class Expr;
class VariableDecl;
class OpNameDecl;
class NamedAttributeDecl;
class Context;
30 
31 //===----------------------------------------------------------------------===//
32 // Name
33 //===----------------------------------------------------------------------===//
34 
35 /// This class provides a convenient API for interacting with source names. It
36 /// contains a string name as well as the source location for that name.
37 struct Name {
38  static const Name &create(Context &ctx, StringRef name, SMRange location);
39 
40  /// Return the raw string name.
41  StringRef getName() const { return name; }
42 
43  /// Get the location of this name.
44  SMRange getLoc() const { return location; }
45 
46 private:
47  Name() = delete;
48  Name(const Name &) = delete;
49  Name &operator=(const Name &) = delete;
50  Name(StringRef name, SMRange location) : name(name), location(location) {}
51 
52  /// The string name of the decl.
53  StringRef name;
54  /// The location of the decl name.
55  SMRange location;
56 };
57 
58 //===----------------------------------------------------------------------===//
59 // DeclScope
60 //===----------------------------------------------------------------------===//
61 
62 /// This class represents a scope for named AST decls. A scope determines the
63 /// visibility and lifetime of a named declaration.
64 class DeclScope {
65 public:
66  /// Create a new scope with an optional parent scope.
67  DeclScope(DeclScope *parent = nullptr) : parent(parent) {}
68 
69  /// Return the parent scope of this scope, or nullptr if there is no parent.
70  DeclScope *getParentScope() { return parent; }
71  const DeclScope *getParentScope() const { return parent; }
72 
73  /// Return all of the decls within this scope.
74  auto getDecls() const { return llvm::make_second_range(decls); }
75 
76  /// Add a new decl to the scope.
77  void add(Decl *decl);
78 
79  /// Lookup a decl with the given name starting from this scope. Returns
80  /// nullptr if no decl could be found.
81  Decl *lookup(StringRef name);
82  template <typename T>
83  T *lookup(StringRef name) {
84  return dyn_cast_or_null<T>(lookup(name));
85  }
86  const Decl *lookup(StringRef name) const {
87  return const_cast<DeclScope *>(this)->lookup(name);
88  }
89  template <typename T>
90  const T *lookup(StringRef name) const {
91  return dyn_cast_or_null<T>(lookup(name));
92  }
93 
94 private:
95  /// The parent scope, or null if this is a top-level scope.
96  DeclScope *parent;
97  /// The decls defined within this scope.
98  llvm::StringMap<Decl *> decls;
99 };
100 
101 //===----------------------------------------------------------------------===//
102 // Node
103 //===----------------------------------------------------------------------===//
104 
105 /// This class represents a base AST node. All AST nodes are derived from this
106 /// class, and it contains many of the base functionality for interacting with
107 /// nodes.
108 class Node {
109 public:
110  /// This CRTP class provides several utilies when defining new AST nodes.
111  template <typename T, typename BaseT>
112  class NodeBase : public BaseT {
113  public:
115 
116  /// Provide type casting support.
117  static bool classof(const Node *node) {
118  return node->getTypeID() == TypeID::get<T>();
119  }
120 
121  protected:
122  template <typename... Args>
123  explicit NodeBase(SMRange loc, Args &&...args)
124  : BaseT(TypeID::get<T>(), loc, std::forward<Args>(args)...) {}
125  };
126 
127  /// Return the type identifier of this node.
128  TypeID getTypeID() const { return typeID; }
129 
130  /// Return the location of this node.
131  SMRange getLoc() const { return loc; }
132 
133  /// Print this node to the given stream.
134  void print(raw_ostream &os) const;
135 
136  /// Walk all of the nodes including, and nested under, this node in pre-order.
137  void walk(function_ref<void(const Node *)> walkFn) const;
138  template <typename WalkFnT, typename ArgT = typename llvm::function_traits<
139  WalkFnT>::template arg_t<0>>
140  std::enable_if_t<!std::is_convertible<const Node *, ArgT>::value>
141  walk(WalkFnT &&walkFn) const {
142  walk([&](const Node *node) {
143  if (const ArgT *derivedNode = dyn_cast<ArgT>(node))
144  walkFn(derivedNode);
145  });
146  }
147 
148 protected:
149  Node(TypeID typeID, SMRange loc) : typeID(typeID), loc(loc) {}
150 
151 private:
152  /// A unique type identifier for this node.
153  TypeID typeID;
154 
155  /// The location of this node.
156  SMRange loc;
157 };
158 
159 //===----------------------------------------------------------------------===//
160 // Stmt
161 //===----------------------------------------------------------------------===//
162 
163 /// This class represents a base AST Statement node.
164 class Stmt : public Node {
165 public:
166  using Node::Node;
167 
168  /// Provide type casting support.
169  static bool classof(const Node *node);
170 };
171 
172 //===----------------------------------------------------------------------===//
173 // CompoundStmt
174 //===----------------------------------------------------------------------===//
175 
176 /// This statement represents a compound statement, which contains a collection
177 /// of other statements.
178 class CompoundStmt final : public Node::NodeBase<CompoundStmt, Stmt>,
179  private llvm::TrailingObjects<CompoundStmt, Stmt *> {
180 public:
181  static CompoundStmt *create(Context &ctx, SMRange location,
182  ArrayRef<Stmt *> children);
183 
184  /// Return the children of this compound statement.
186  return {getTrailingObjects<Stmt *>(), numChildren};
187  }
189  return const_cast<CompoundStmt *>(this)->getChildren();
190  }
191  ArrayRef<Stmt *>::iterator begin() const { return getChildren().begin(); }
192  ArrayRef<Stmt *>::iterator end() const { return getChildren().end(); }
193 
194 private:
195  CompoundStmt(SMRange location, unsigned numChildren)
196  : Base(location), numChildren(numChildren) {}
197 
198  /// The number of held children statements.
199  unsigned numChildren;
200 
201  // Allow access to various privates.
202  friend class llvm::TrailingObjects<CompoundStmt, Stmt *>;
203 };
204 
205 //===----------------------------------------------------------------------===//
206 // LetStmt
207 //===----------------------------------------------------------------------===//
208 
209 /// This statement represents a `let` statement in PDLL. This statement is used
210 /// to define variables.
211 class LetStmt final : public Node::NodeBase<LetStmt, Stmt> {
212 public:
213  static LetStmt *create(Context &ctx, SMRange loc, VariableDecl *varDecl);
214 
215  /// Return the variable defined by this statement.
216  VariableDecl *getVarDecl() const { return varDecl; }
217 
218 private:
219  LetStmt(SMRange loc, VariableDecl *varDecl) : Base(loc), varDecl(varDecl) {}
220 
221  /// The variable defined by this statement.
222  VariableDecl *varDecl;
223 };
224 
225 //===----------------------------------------------------------------------===//
226 // OpRewriteStmt
227 //===----------------------------------------------------------------------===//
228 
229 /// This class represents a base operation rewrite statement. Operation rewrite
230 /// statements perform a set of transformations on a given root operation.
231 class OpRewriteStmt : public Stmt {
232 public:
233  /// Provide type casting support.
234  static bool classof(const Node *node);
235 
236  /// Return the root operation of this rewrite.
237  Expr *getRootOpExpr() const { return rootOp; }
238 
239 protected:
240  OpRewriteStmt(TypeID typeID, SMRange loc, Expr *rootOp)
241  : Stmt(typeID, loc), rootOp(rootOp) {}
242 
243 protected:
244  /// The root operation being rewritten.
246 };
247 
248 //===----------------------------------------------------------------------===//
249 // EraseStmt
250 //===----------------------------------------------------------------------===//
251 
252 /// This statement represents the `erase` statement in PDLL. This statement
253 /// erases the given root operation, corresponding roughly to the
254 /// PatternRewriter::eraseOp API.
255 class EraseStmt final : public Node::NodeBase<EraseStmt, OpRewriteStmt> {
256 public:
257  static EraseStmt *create(Context &ctx, SMRange loc, Expr *rootOp);
258 
259 private:
260  EraseStmt(SMRange loc, Expr *rootOp) : Base(loc, rootOp) {}
261 };
262 
263 //===----------------------------------------------------------------------===//
264 // ReplaceStmt
265 //===----------------------------------------------------------------------===//
266 
267 /// This statement represents the `replace` statement in PDLL. This statement
268 /// replace the given root operation with a set of values, corresponding roughly
269 /// to the PatternRewriter::replaceOp API.
270 class ReplaceStmt final : public Node::NodeBase<ReplaceStmt, OpRewriteStmt>,
271  private llvm::TrailingObjects<ReplaceStmt, Expr *> {
272 public:
273  static ReplaceStmt *create(Context &ctx, SMRange loc, Expr *rootOp,
274  ArrayRef<Expr *> replExprs);
275 
276  /// Return the replacement values of this statement.
278  return {getTrailingObjects<Expr *>(), numReplExprs};
279  }
281  return const_cast<ReplaceStmt *>(this)->getReplExprs();
282  }
283 
284 private:
285  ReplaceStmt(SMRange loc, Expr *rootOp, unsigned numReplExprs)
286  : Base(loc, rootOp), numReplExprs(numReplExprs) {}
287 
288  /// The number of replacement values within this statement.
289  unsigned numReplExprs;
290 
291  /// TrailingObject utilities.
292  friend class llvm::TrailingObjects<ReplaceStmt, Expr *>;
293 };
294 
295 //===----------------------------------------------------------------------===//
296 // RewriteStmt
297 //===----------------------------------------------------------------------===//
298 
299 /// This statement represents an operation rewrite that contains a block of
300 /// nested rewrite commands. This allows for building more complex operation
301 /// rewrites that span across multiple statements, which may be unconnected.
302 class RewriteStmt final : public Node::NodeBase<RewriteStmt, OpRewriteStmt> {
303 public:
304  static RewriteStmt *create(Context &ctx, SMRange loc, Expr *rootOp,
305  CompoundStmt *rewriteBody);
306 
307  /// Return the compound rewrite body.
308  CompoundStmt *getRewriteBody() const { return rewriteBody; }
309 
310 private:
311  RewriteStmt(SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
312  : Base(loc, rootOp), rewriteBody(rewriteBody) {}
313 
314  /// The body of nested rewriters within this statement.
315  CompoundStmt *rewriteBody;
316 };
317 
318 //===----------------------------------------------------------------------===//
319 // ReturnStmt
320 //===----------------------------------------------------------------------===//
321 
322 /// This statement represents a return from a "callable" like decl, e.g. a
323 /// Constraint or a Rewrite.
324 class ReturnStmt final : public Node::NodeBase<ReturnStmt, Stmt> {
325 public:
326  static ReturnStmt *create(Context &ctx, SMRange loc, Expr *resultExpr);
327 
328  /// Return the result expression of this statement.
329  Expr *getResultExpr() { return resultExpr; }
330  const Expr *getResultExpr() const { return resultExpr; }
331 
332  /// Set the result expression of this statement.
333  void setResultExpr(Expr *expr) { resultExpr = expr; }
334 
335 private:
336  ReturnStmt(SMRange loc, Expr *resultExpr)
337  : Base(loc), resultExpr(resultExpr) {}
338 
339  // The result expression of this statement.
340  Expr *resultExpr;
341 };
342 
343 //===----------------------------------------------------------------------===//
344 // Expr
345 //===----------------------------------------------------------------------===//
346 
347 /// This class represents a base AST Expression node.
348 class Expr : public Stmt {
349 public:
350  /// Return the type of this expression.
351  Type getType() const { return type; }
352 
353  /// Provide type casting support.
354  static bool classof(const Node *node);
355 
356 protected:
357  Expr(TypeID typeID, SMRange loc, Type type) : Stmt(typeID, loc), type(type) {}
358 
359 private:
360  /// The type of this expression.
361  Type type;
362 };
363 
364 //===----------------------------------------------------------------------===//
365 // AttributeExpr
366 //===----------------------------------------------------------------------===//
367 
368 /// This expression represents a literal MLIR Attribute, and contains the
369 /// textual assembly format of that attribute.
370 class AttributeExpr : public Node::NodeBase<AttributeExpr, Expr> {
371 public:
372  static AttributeExpr *create(Context &ctx, SMRange loc, StringRef value);
373 
374  /// Get the raw value of this expression. This is the textual assembly format
375  /// of the MLIR Attribute.
376  StringRef getValue() const { return value; }
377 
378 private:
379  AttributeExpr(Context &ctx, SMRange loc, StringRef value)
380  : Base(loc, AttributeType::get(ctx)), value(value) {}
381 
382  /// The value referenced by this expression.
383  StringRef value;
384 };
385 
386 //===----------------------------------------------------------------------===//
387 // CallExpr
388 //===----------------------------------------------------------------------===//
389 
390 /// This expression represents a call to a decl, such as a
391 /// UserConstraintDecl/UserRewriteDecl.
392 class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
393  private llvm::TrailingObjects<CallExpr, Expr *> {
394 public:
395  static CallExpr *create(Context &ctx, SMRange loc, Expr *callable,
396  ArrayRef<Expr *> arguments, Type resultType,
397  bool isNegated = false);
398 
399  /// Return the callable of this call.
400  Expr *getCallableExpr() const { return callable; }
401 
402  /// Return the arguments of this call.
404  return {getTrailingObjects<Expr *>(), numArgs};
405  }
407  return const_cast<CallExpr *>(this)->getArguments();
408  }
409 
410  /// Returns whether the result of this call is to be negated.
411  bool getIsNegated() const { return isNegated; }
412 
413 private:
414  CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs,
415  bool isNegated)
416  : Base(loc, type), callable(callable), numArgs(numArgs),
417  isNegated(isNegated) {}
418 
419  /// The callable of this call.
420  Expr *callable;
421 
422  /// The number of arguments of the call.
423  unsigned numArgs;
424 
425  /// TrailingObject utilities.
426  friend llvm::TrailingObjects<CallExpr, Expr *>;
427 
428  // Is the result of this call to be negated.
429  bool isNegated;
430 };
431 
432 //===----------------------------------------------------------------------===//
433 // DeclRefExpr
434 //===----------------------------------------------------------------------===//
435 
436 /// This expression represents a reference to a Decl node.
437 class DeclRefExpr : public Node::NodeBase<DeclRefExpr, Expr> {
438 public:
439  static DeclRefExpr *create(Context &ctx, SMRange loc, Decl *decl, Type type);
440 
441  /// Get the decl referenced by this expression.
442  Decl *getDecl() const { return decl; }
443 
444 private:
445  DeclRefExpr(SMRange loc, Decl *decl, Type type)
446  : Base(loc, type), decl(decl) {}
447 
448  /// The decl referenced by this expression.
449  Decl *decl;
450 };
451 
452 //===----------------------------------------------------------------------===//
453 // MemberAccessExpr
454 //===----------------------------------------------------------------------===//
455 
456 /// This expression represents a named member or field access of a given parent
457 /// expression.
458 class MemberAccessExpr : public Node::NodeBase<MemberAccessExpr, Expr> {
459 public:
460  static MemberAccessExpr *create(Context &ctx, SMRange loc,
461  const Expr *parentExpr, StringRef memberName,
462  Type type);
463 
464  /// Get the parent expression of this access.
465  const Expr *getParentExpr() const { return parentExpr; }
466 
467  /// Return the name of the member being accessed.
468  StringRef getMemberName() const { return memberName; }
469 
470 private:
471  MemberAccessExpr(SMRange loc, const Expr *parentExpr, StringRef memberName,
472  Type type)
473  : Base(loc, type), parentExpr(parentExpr), memberName(memberName) {}
474 
475  /// The parent expression of this access.
476  const Expr *parentExpr;
477 
478  /// The name of the member being accessed from the parent.
479  StringRef memberName;
480 };
481 
482 //===----------------------------------------------------------------------===//
483 // AllResultsMemberAccessExpr
484 //===----------------------------------------------------------------------===//
485 
486 /// This class represents an instance of MemberAccessExpr that references all
487 /// results of an operation.
489 public:
490  /// Return the member name used for the "all-results" access.
491  static StringRef getMemberName() { return "$results"; }
492 
493  static AllResultsMemberAccessExpr *create(Context &ctx, SMRange loc,
494  const Expr *parentExpr, Type type) {
495  return cast<AllResultsMemberAccessExpr>(
496  MemberAccessExpr::create(ctx, loc, parentExpr, getMemberName(), type));
497  }
498 
499  /// Provide type casting support.
500  static bool classof(const Node *node) {
501  const MemberAccessExpr *memAccess = dyn_cast<MemberAccessExpr>(node);
502  return memAccess && memAccess->getMemberName() == getMemberName();
503  }
504 };
505 
506 //===----------------------------------------------------------------------===//
507 // OperationExpr
508 //===----------------------------------------------------------------------===//
509 
510 /// This expression represents the structural form of an MLIR Operation. It
511 /// represents either an input operation to match, or an operation to create
512 /// within a rewrite.
513 class OperationExpr final
514  : public Node::NodeBase<OperationExpr, Expr>,
515  private llvm::TrailingObjects<OperationExpr, Expr *,
516  NamedAttributeDecl *> {
517 public:
518  static OperationExpr *create(Context &ctx, SMRange loc,
519  const ods::Operation *odsOp,
520  const OpNameDecl *nameDecl,
521  ArrayRef<Expr *> operands,
522  ArrayRef<Expr *> resultTypes,
523  ArrayRef<NamedAttributeDecl *> attributes);
524 
525  /// Return the name of the operation, or std::nullopt if there isn't one.
526  std::optional<StringRef> getName() const;
527 
528  /// Return the declaration of the operation name.
529  const OpNameDecl *getNameDecl() const { return nameDecl; }
530 
531  /// Return the location of the name of the operation expression, or an invalid
532  /// location if there isn't a name.
533  SMRange getNameLoc() const { return nameLoc; }
534 
535  /// Return the operands of this operation.
537  return {getTrailingObjects<Expr *>(), numOperands};
538  }
540  return const_cast<OperationExpr *>(this)->getOperands();
541  }
542 
543  /// Return the result types of this operation.
545  return {getTrailingObjects<Expr *>() + numOperands, numResultTypes};
546  }
548  return const_cast<OperationExpr *>(this)->getResultTypes();
549  }
550 
551  /// Return the attributes of this operation.
553  return {getTrailingObjects<NamedAttributeDecl *>(), numAttributes};
554  }
556  return const_cast<OperationExpr *>(this)->getAttributes();
557  }
558 
559 private:
560  OperationExpr(SMRange loc, Type type, const OpNameDecl *nameDecl,
561  unsigned numOperands, unsigned numResultTypes,
562  unsigned numAttributes, SMRange nameLoc)
563  : Base(loc, type), nameDecl(nameDecl), numOperands(numOperands),
564  numResultTypes(numResultTypes), numAttributes(numAttributes),
565  nameLoc(nameLoc) {}
566 
567  /// The name decl of this expression.
568  const OpNameDecl *nameDecl;
569 
570  /// The number of operands, result types, and attributes of the operation.
571  unsigned numOperands, numResultTypes, numAttributes;
572 
573  /// The location of the operation name in the expression if it has a name.
574  SMRange nameLoc;
575 
576  /// TrailingObject utilities.
577  friend llvm::TrailingObjects<OperationExpr, Expr *, NamedAttributeDecl *>;
578  size_t numTrailingObjects(OverloadToken<Expr *>) const {
579  return numOperands + numResultTypes;
580  }
581 };
582 
583 //===----------------------------------------------------------------------===//
584 // RangeExpr
585 //===----------------------------------------------------------------------===//
586 
587 /// This expression builds a range from a set of element values (which may be
588 /// ranges themselves).
589 class RangeExpr final : public Node::NodeBase<RangeExpr, Expr>,
590  private llvm::TrailingObjects<RangeExpr, Expr *> {
591 public:
592  static RangeExpr *create(Context &ctx, SMRange loc, ArrayRef<Expr *> elements,
593  RangeType type);
594 
595  /// Return the element expressions of this range.
597  return {getTrailingObjects<Expr *>(), numElements};
598  }
600  return const_cast<RangeExpr *>(this)->getElements();
601  }
602 
603  /// Return the range result type of this expression.
604  RangeType getType() const { return mlir::cast<RangeType>(Base::getType()); }
605 
606 private:
607  RangeExpr(SMRange loc, RangeType type, unsigned numElements)
608  : Base(loc, type), numElements(numElements) {}
609 
610  /// The number of element values for this range.
611  unsigned numElements;
612 
613  /// TrailingObject utilities.
614  friend class llvm::TrailingObjects<RangeExpr, Expr *>;
615 };
616 
617 //===----------------------------------------------------------------------===//
618 // TupleExpr
619 //===----------------------------------------------------------------------===//
620 
621 /// This expression builds a tuple from a set of element values.
622 class TupleExpr final : public Node::NodeBase<TupleExpr, Expr>,
623  private llvm::TrailingObjects<TupleExpr, Expr *> {
624 public:
625  static TupleExpr *create(Context &ctx, SMRange loc, ArrayRef<Expr *> elements,
626  ArrayRef<StringRef> elementNames);
627 
628  /// Return the element expressions of this tuple.
630  return {getTrailingObjects<Expr *>(), getType().size()};
631  }
633  return const_cast<TupleExpr *>(this)->getElements();
634  }
635 
636  /// Return the tuple result type of this expression.
637  TupleType getType() const { return mlir::cast<TupleType>(Base::getType()); }
638 
639 private:
640  TupleExpr(SMRange loc, TupleType type) : Base(loc, type) {}
641 
642  /// TrailingObject utilities.
643  friend class llvm::TrailingObjects<TupleExpr, Expr *>;
644 };
645 
646 //===----------------------------------------------------------------------===//
647 // TypeExpr
648 //===----------------------------------------------------------------------===//
649 
650 /// This expression represents a literal MLIR Type, and contains the textual
651 /// assembly format of that type.
652 class TypeExpr : public Node::NodeBase<TypeExpr, Expr> {
653 public:
654  static TypeExpr *create(Context &ctx, SMRange loc, StringRef value);
655 
656  /// Get the raw value of this expression. This is the textual assembly format
657  /// of the MLIR Type.
658  StringRef getValue() const { return value; }
659 
660 private:
661  TypeExpr(Context &ctx, SMRange loc, StringRef value)
662  : Base(loc, TypeType::get(ctx)), value(value) {}
663 
664  /// The value referenced by this expression.
665  StringRef value;
666 };
667 
668 //===----------------------------------------------------------------------===//
669 // Decl
670 //===----------------------------------------------------------------------===//
671 
672 /// This class represents the base Decl node.
673 class Decl : public Node {
674 public:
675  /// Return the name of the decl, or nullptr if it doesn't have one.
676  const Name *getName() const { return name; }
677 
678  /// Provide type casting support.
679  static bool classof(const Node *node);
680 
681  /// Set the documentation comment for this decl.
682  void setDocComment(Context &ctx, StringRef comment);
683 
684  /// Return the documentation comment attached to this decl if it has been set.
685  /// Otherwise, returns std::nullopt.
686  std::optional<StringRef> getDocComment() const { return docComment; }
687 
688 protected:
689  Decl(TypeID typeID, SMRange loc, const Name *name = nullptr)
690  : Node(typeID, loc), name(name) {}
691 
692 private:
693  /// The name of the decl. This is optional for some decls, such as
694  /// PatternDecl.
695  const Name *name;
696 
697  /// The documentation comment attached to this decl. Defaults to std::nullopt
698  /// if the comment is unset/unknown.
699  std::optional<StringRef> docComment;
700 };
701 
702 //===----------------------------------------------------------------------===//
703 // ConstraintDecl
704 //===----------------------------------------------------------------------===//
705 
706 /// This class represents the base of all AST Constraint decls. Constraints
707 /// apply matcher conditions to, and define the type of PDLL variables.
708 class ConstraintDecl : public Decl {
709 public:
710  /// Provide type casting support.
711  static bool classof(const Node *node);
712 
713 protected:
714  ConstraintDecl(TypeID typeID, SMRange loc, const Name *name = nullptr)
715  : Decl(typeID, loc, name) {}
716 };
717 
718 /// This class represents a reference to a constraint, and contains a constraint
719 /// and the location of the reference.
721  ConstraintRef(const ConstraintDecl *constraint, SMRange refLoc)
722  : constraint(constraint), referenceLoc(refLoc) {}
724  : ConstraintRef(constraint, constraint->getLoc()) {}
725 
727  SMRange referenceLoc;
728 };
729 
730 //===----------------------------------------------------------------------===//
731 // CoreConstraintDecl
732 //===----------------------------------------------------------------------===//
733 
734 /// This class represents the base of all "core" constraints. Core constraints
735 /// are those that generally represent a concrete IR construct, such as
736 /// `Type`s or `Value`s.
738 public:
739  /// Provide type casting support.
740  static bool classof(const Node *node);
741 
742 protected:
743  CoreConstraintDecl(TypeID typeID, SMRange loc, const Name *name = nullptr)
744  : ConstraintDecl(typeID, loc, name) {}
745 };
746 
747 //===----------------------------------------------------------------------===//
748 // AttrConstraintDecl
749 //===----------------------------------------------------------------------===//
750 
751 /// The class represents an Attribute constraint, and constrains a variable to
752 /// be an Attribute.
754  : public Node::NodeBase<AttrConstraintDecl, CoreConstraintDecl> {
755 public:
756  static AttrConstraintDecl *create(Context &ctx, SMRange loc,
757  Expr *typeExpr = nullptr);
758 
759  /// Return the optional type the attribute is constrained to.
760  Expr *getTypeExpr() { return typeExpr; }
761  const Expr *getTypeExpr() const { return typeExpr; }
762 
763 protected:
765  : Base(loc), typeExpr(typeExpr) {}
766 
767  /// An optional type that the attribute is constrained to.
769 };
770 
771 //===----------------------------------------------------------------------===//
772 // OpConstraintDecl
773 //===----------------------------------------------------------------------===//
774 
775 /// The class represents an Operation constraint, and constrains a variable to
776 /// be an Operation.
778  : public Node::NodeBase<OpConstraintDecl, CoreConstraintDecl> {
779 public:
780  static OpConstraintDecl *create(Context &ctx, SMRange loc,
781  const OpNameDecl *nameDecl = nullptr);
782 
783  /// Return the name of the operation, or std::nullopt if there isn't one.
784  std::optional<StringRef> getName() const;
785 
786  /// Return the declaration of the operation name.
787  const OpNameDecl *getNameDecl() const { return nameDecl; }
788 
789 protected:
790  explicit OpConstraintDecl(SMRange loc, const OpNameDecl *nameDecl)
791  : Base(loc), nameDecl(nameDecl) {}
792 
793  /// The operation name of this constraint.
795 };
796 
797 //===----------------------------------------------------------------------===//
798 // TypeConstraintDecl
799 //===----------------------------------------------------------------------===//
800 
801 /// The class represents a Type constraint, and constrains a variable to be a
802 /// Type.
804  : public Node::NodeBase<TypeConstraintDecl, CoreConstraintDecl> {
805 public:
806  static TypeConstraintDecl *create(Context &ctx, SMRange loc);
807 
808 protected:
809  using Base::Base;
810 };
811 
812 //===----------------------------------------------------------------------===//
813 // TypeRangeConstraintDecl
814 //===----------------------------------------------------------------------===//
815 
816 /// The class represents a TypeRange constraint, and constrains a variable to be
817 /// a TypeRange.
819  : public Node::NodeBase<TypeRangeConstraintDecl, CoreConstraintDecl> {
820 public:
821  static TypeRangeConstraintDecl *create(Context &ctx, SMRange loc);
822 
823 protected:
824  using Base::Base;
825 };
826 
827 //===----------------------------------------------------------------------===//
828 // ValueConstraintDecl
829 //===----------------------------------------------------------------------===//
830 
831 /// The class represents a Value constraint, and constrains a variable to be a
832 /// Value.
834  : public Node::NodeBase<ValueConstraintDecl, CoreConstraintDecl> {
835 public:
836  static ValueConstraintDecl *create(Context &ctx, SMRange loc, Expr *typeExpr);
837 
838  /// Return the optional type the value is constrained to.
839  Expr *getTypeExpr() { return typeExpr; }
840  const Expr *getTypeExpr() const { return typeExpr; }
841 
842 protected:
844  : Base(loc), typeExpr(typeExpr) {}
845 
846  /// An optional type that the value is constrained to.
848 };
849 
850 //===----------------------------------------------------------------------===//
851 // ValueRangeConstraintDecl
852 //===----------------------------------------------------------------------===//
853 
854 /// The class represents a ValueRange constraint, and constrains a variable to
855 /// be a ValueRange.
857  : public Node::NodeBase<ValueRangeConstraintDecl, CoreConstraintDecl> {
858 public:
859  static ValueRangeConstraintDecl *create(Context &ctx, SMRange loc,
860  Expr *typeExpr = nullptr);
861 
862  /// Return the optional type the value range is constrained to.
863  Expr *getTypeExpr() { return typeExpr; }
864  const Expr *getTypeExpr() const { return typeExpr; }
865 
866 protected:
868  : Base(loc), typeExpr(typeExpr) {}
869 
870  /// An optional type that the value range is constrained to.
872 };
873 
874 //===----------------------------------------------------------------------===//
875 // UserConstraintDecl
876 //===----------------------------------------------------------------------===//
877 
878 /// This decl represents a user defined constraint. This is either:
879 /// * an imported native constraint
880 /// - Similar to an external function declaration. This is a native
881 /// constraint defined externally, and imported into PDLL via a
882 /// declaration.
883 /// * a native constraint defined in PDLL
884 /// - This is a native constraint, i.e. a constraint whose implementation is
885 /// defined in C++ (or potentially some other non-PDLL language). The
886 /// implementation of this constraint is specified as a string code block
887 /// in PDLL.
888 /// * a PDLL constraint
889 /// - This is a constraint which is defined using only PDLL constructs.
891  : public Node::NodeBase<UserConstraintDecl, ConstraintDecl>,
892  llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef> {
893 public:
894  /// Create a native constraint with the given optional code block.
895  static UserConstraintDecl *
897  ArrayRef<VariableDecl *> results,
898  std::optional<StringRef> codeBlock, Type resultType,
899  ArrayRef<StringRef> nativeInputTypes = {}) {
900  return createImpl(ctx, name, inputs, nativeInputTypes, results, codeBlock,
901  /*body=*/nullptr, resultType);
902  }
903 
904  /// Create a PDLL constraint with the given body.
905  static UserConstraintDecl *createPDLL(Context &ctx, const Name &name,
907  ArrayRef<VariableDecl *> results,
908  const CompoundStmt *body,
909  Type resultType) {
910  return createImpl(ctx, name, inputs, /*nativeInputTypes=*/std::nullopt,
911  results, /*codeBlock=*/std::nullopt, body, resultType);
912  }
913 
914  /// Return the name of the constraint.
915  const Name &getName() const { return *Decl::getName(); }
916 
917  /// Return the input arguments of this constraint.
919  return {getTrailingObjects<VariableDecl *>(), numInputs};
920  }
922  return const_cast<UserConstraintDecl *>(this)->getInputs();
923  }
924 
925  /// Return the explicit native type to use for the given input. Returns
926  /// std::nullopt if no explicit type was set.
927  std::optional<StringRef> getNativeInputType(unsigned index) const;
928 
929  /// Return the explicit results of the constraint declaration. May be empty,
930  /// even if the constraint has results (e.g. in the case of inferred results).
932  return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
933  }
935  return const_cast<UserConstraintDecl *>(this)->getResults();
936  }
937 
938  /// Return the optional code block of this constraint, if this is a native
939  /// constraint with a provided implementation.
940  std::optional<StringRef> getCodeBlock() const { return codeBlock; }
941 
942  /// Return the body of this constraint if this constraint is a PDLL
943  /// constraint, otherwise returns nullptr.
944  const CompoundStmt *getBody() const { return constraintBody; }
945 
946  /// Return the result type of this constraint.
947  Type getResultType() const { return resultType; }
948 
949  /// Returns true if this constraint is external.
950  bool isExternal() const { return !constraintBody && !codeBlock; }
951 
952 private:
953  /// Create either a PDLL constraint or a native constraint with the given
954  /// components.
955  static UserConstraintDecl *createImpl(Context &ctx, const Name &name,
957  ArrayRef<StringRef> nativeInputTypes,
958  ArrayRef<VariableDecl *> results,
959  std::optional<StringRef> codeBlock,
960  const CompoundStmt *body,
961  Type resultType);
962 
963  UserConstraintDecl(const Name &name, unsigned numInputs,
964  bool hasNativeInputTypes, unsigned numResults,
965  std::optional<StringRef> codeBlock,
966  const CompoundStmt *body, Type resultType)
967  : Base(name.getLoc(), &name), numInputs(numInputs),
968  numResults(numResults), codeBlock(codeBlock), constraintBody(body),
969  resultType(resultType), hasNativeInputTypes(hasNativeInputTypes) {}
970 
971  /// The number of inputs to this constraint.
972  unsigned numInputs;
973 
974  /// The number of explicit results to this constraint.
975  unsigned numResults;
976 
977  /// The optional code block of this constraint.
978  std::optional<StringRef> codeBlock;
979 
980  /// The optional body of this constraint.
981  const CompoundStmt *constraintBody;
982 
983  /// The result type of the constraint.
984  Type resultType;
985 
986  /// Flag indicating if this constraint has explicit native input types.
987  bool hasNativeInputTypes;
988 
989  /// Allow access to various internals.
990  friend llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef>;
991  size_t numTrailingObjects(OverloadToken<VariableDecl *>) const {
992  return numInputs + numResults;
993  }
994 };
995 
996 //===----------------------------------------------------------------------===//
997 // NamedAttributeDecl
998 //===----------------------------------------------------------------------===//
999 
1000 /// This Decl represents a NamedAttribute, and contains a string name and
1001 /// attribute value.
1002 class NamedAttributeDecl : public Node::NodeBase<NamedAttributeDecl, Decl> {
1003 public:
1004  static NamedAttributeDecl *create(Context &ctx, const Name &name,
1005  Expr *value);
1006 
1007  /// Return the name of the attribute.
1008  const Name &getName() const { return *Decl::getName(); }
1009 
1010  /// Return value of the attribute.
1011  Expr *getValue() const { return value; }
1012 
1013 private:
1014  NamedAttributeDecl(const Name &name, Expr *value)
1015  : Base(name.getLoc(), &name), value(value) {}
1016 
1017  /// The value of the attribute.
1018  Expr *value;
1019 };
1020 
1021 //===----------------------------------------------------------------------===//
1022 // OpNameDecl
1023 //===----------------------------------------------------------------------===//
1024 
1025 /// This Decl represents an OperationName.
1026 class OpNameDecl : public Node::NodeBase<OpNameDecl, Decl> {
1027 public:
1028  static OpNameDecl *create(Context &ctx, const Name &name);
1029  static OpNameDecl *create(Context &ctx, SMRange loc);
1030 
1031  /// Return the name of this operation, or std::nullopt if the name is unknown.
1032  std::optional<StringRef> getName() const {
1033  const Name *name = Decl::getName();
1034  return name ? std::optional<StringRef>(name->getName()) : std::nullopt;
1035  }
1036 
1037 private:
1038  explicit OpNameDecl(const Name &name) : Base(name.getLoc(), &name) {}
1039  explicit OpNameDecl(SMRange loc) : Base(loc) {}
1040 };
1041 
1042 //===----------------------------------------------------------------------===//
1043 // PatternDecl
1044 //===----------------------------------------------------------------------===//
1045 
1046 /// This Decl represents a single Pattern.
1047 class PatternDecl : public Node::NodeBase<PatternDecl, Decl> {
1048 public:
1049  static PatternDecl *create(Context &ctx, SMRange location, const Name *name,
1050  std::optional<uint16_t> benefit,
1051  bool hasBoundedRecursion,
1052  const CompoundStmt *body);
1053 
1054  /// Return the benefit of this pattern if specified, or std::nullopt.
1055  std::optional<uint16_t> getBenefit() const { return benefit; }
1056 
1057  /// Return if this pattern has bounded rewrite recursion.
1058  bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; }
1059 
1060  /// Return the body of this pattern.
1061  const CompoundStmt *getBody() const { return patternBody; }
1062 
1063  /// Return the root rewrite statement of this pattern.
1065  return cast<OpRewriteStmt>(patternBody->getChildren().back());
1066  }
1067 
1068 private:
1069  PatternDecl(SMRange loc, const Name *name, std::optional<uint16_t> benefit,
1070  bool hasBoundedRecursion, const CompoundStmt *body)
1071  : Base(loc, name), benefit(benefit),
1072  hasBoundedRecursion(hasBoundedRecursion), patternBody(body) {}
1073 
1074  /// The benefit of the pattern if it was explicitly specified, std::nullopt
1075  /// otherwise.
1076  std::optional<uint16_t> benefit;
1077 
1078  /// If the pattern has properly bounded rewrite recursion or not.
1079  bool hasBoundedRecursion;
1080 
1081  /// The compound statement representing the body of the pattern.
1082  const CompoundStmt *patternBody;
1083 };
1084 
1085 //===----------------------------------------------------------------------===//
1086 // UserRewriteDecl
1087 //===----------------------------------------------------------------------===//
1088 
1089 /// This decl represents a user defined rewrite. This is either:
1090 /// * an imported native rewrite
1091 /// - Similar to an external function declaration. This is a native
1092 /// rewrite defined externally, and imported into PDLL via a declaration.
1093 /// * a native rewrite defined in PDLL
1094 /// - This is a native rewrite, i.e. a rewrite whose implementation is
1095 /// defined in C++ (or potentially some other non-PDLL language). The
1096 /// implementation of this rewrite is specified as a string code block
1097 /// in PDLL.
1098 /// * a PDLL rewrite
1099 /// - This is a rewrite which is defined using only PDLL constructs.
1100 class UserRewriteDecl final
1101  : public Node::NodeBase<UserRewriteDecl, Decl>,
1102  llvm::TrailingObjects<UserRewriteDecl, VariableDecl *> {
1103 public:
1104  /// Create a native rewrite with the given optional code block.
1105  static UserRewriteDecl *createNative(Context &ctx, const Name &name,
1106  ArrayRef<VariableDecl *> inputs,
1107  ArrayRef<VariableDecl *> results,
1108  std::optional<StringRef> codeBlock,
1109  Type resultType) {
1110  return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/nullptr,
1111  resultType);
1112  }
1113 
1114  /// Create a PDLL rewrite with the given body.
1115  static UserRewriteDecl *createPDLL(Context &ctx, const Name &name,
1116  ArrayRef<VariableDecl *> inputs,
1117  ArrayRef<VariableDecl *> results,
1118  const CompoundStmt *body,
1119  Type resultType) {
1120  return createImpl(ctx, name, inputs, results, /*codeBlock=*/std::nullopt,
1121  body, resultType);
1122  }
1123 
1124  /// Return the name of the rewrite.
1125  const Name &getName() const { return *Decl::getName(); }
1126 
1127  /// Return the input arguments of this rewrite.
1129  return {getTrailingObjects<VariableDecl *>(), numInputs};
1130  }
1132  return const_cast<UserRewriteDecl *>(this)->getInputs();
1133  }
1134 
1135  /// Return the explicit results of the rewrite declaration. May be empty,
1136  /// even if the rewrite has results (e.g. in the case of inferred results).
1138  return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
1139  }
1141  return const_cast<UserRewriteDecl *>(this)->getResults();
1142  }
1143 
1144  /// Return the optional code block of this rewrite, if this is a native
1145  /// rewrite with a provided implementation.
1146  std::optional<StringRef> getCodeBlock() const { return codeBlock; }
1147 
1148  /// Return the body of this rewrite if this rewrite is a PDLL rewrite,
1149  /// otherwise returns nullptr.
1150  const CompoundStmt *getBody() const { return rewriteBody; }
1151 
1152  /// Return the result type of this rewrite.
1153  Type getResultType() const { return resultType; }
1154 
1155  /// Returns true if this rewrite is external.
1156  bool isExternal() const { return !rewriteBody && !codeBlock; }
1157 
1158 private:
1159  /// Create either a PDLL rewrite or a native rewrite with the given
1160  /// components.
1161  static UserRewriteDecl *createImpl(Context &ctx, const Name &name,
1162  ArrayRef<VariableDecl *> inputs,
1163  ArrayRef<VariableDecl *> results,
1164  std::optional<StringRef> codeBlock,
1165  const CompoundStmt *body, Type resultType);
1166 
1167  UserRewriteDecl(const Name &name, unsigned numInputs, unsigned numResults,
1168  std::optional<StringRef> codeBlock, const CompoundStmt *body,
1169  Type resultType)
1170  : Base(name.getLoc(), &name), numInputs(numInputs),
1171  numResults(numResults), codeBlock(codeBlock), rewriteBody(body),
1172  resultType(resultType) {}
1173 
1174  /// The number of inputs to this rewrite.
1175  unsigned numInputs;
1176 
1177  /// The number of explicit results to this rewrite.
1178  unsigned numResults;
1179 
1180  /// The optional code block of this rewrite.
1181  std::optional<StringRef> codeBlock;
1182 
1183  /// The optional body of this rewrite.
1184  const CompoundStmt *rewriteBody;
1185 
1186  /// The result type of the rewrite.
1187  Type resultType;
1188 
1189  /// Allow access to various internals.
1190  friend llvm::TrailingObjects<UserRewriteDecl, VariableDecl *>;
1191 };
1192 
1193 //===----------------------------------------------------------------------===//
1194 // CallableDecl
1195 //===----------------------------------------------------------------------===//
1196 
1197 /// This decl represents a shared interface for all callable decls.
1198 class CallableDecl : public Decl {
1199 public:
1200  /// Return the callable type of this decl.
1201  StringRef getCallableType() const {
1202  if (isa<UserConstraintDecl>(this))
1203  return "constraint";
1204  assert(isa<UserRewriteDecl>(this) && "unknown callable type");
1205  return "rewrite";
1206  }
1207 
1208  /// Return the inputs of this decl.
1210  if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1211  return cst->getInputs();
1212  return cast<UserRewriteDecl>(this)->getInputs();
1213  }
1214 
1215  /// Return the result type of this decl.
1217  if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1218  return cst->getResultType();
1219  return cast<UserRewriteDecl>(this)->getResultType();
1220  }
1221 
1222  /// Return the explicit results of the declaration. Note that these may be
1223  /// empty, even if the callable has results (e.g. in the case of inferred
1224  /// results).
1226  if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1227  return cst->getResults();
1228  return cast<UserRewriteDecl>(this)->getResults();
1229  }
1230 
1231  /// Return the optional code block of this callable, if this is a native
1232  /// callable with a provided implementation.
1233  std::optional<StringRef> getCodeBlock() const {
1234  if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1235  return cst->getCodeBlock();
1236  return cast<UserRewriteDecl>(this)->getCodeBlock();
1237  }
1238 
1239  /// Support LLVM type casting facilities.
1240  static bool classof(const Node *decl) {
1241  return isa<UserConstraintDecl, UserRewriteDecl>(decl);
1242  }
1243 };
1244 
1245 //===----------------------------------------------------------------------===//
1246 // VariableDecl
1247 //===----------------------------------------------------------------------===//
1248 
1249 /// This Decl represents the definition of a PDLL variable.
1250 class VariableDecl final
1251  : public Node::NodeBase<VariableDecl, Decl>,
1252  private llvm::TrailingObjects<VariableDecl, ConstraintRef> {
1253 public:
1254  static VariableDecl *create(Context &ctx, const Name &name, Type type,
1255  Expr *initExpr,
1256  ArrayRef<ConstraintRef> constraints);
1257 
1258  /// Return the constraints of this variable.
1260  return {getTrailingObjects<ConstraintRef>(), numConstraints};
1261  }
1263  return const_cast<VariableDecl *>(this)->getConstraints();
1264  }
1265 
1266  /// Return the initializer expression of this statement, or nullptr if there
1267  /// was no initializer.
1268  Expr *getInitExpr() const { return initExpr; }
1269 
1270  /// Return the name of the decl.
1271  const Name &getName() const { return *Decl::getName(); }
1272 
1273  /// Return the type of the decl.
1274  Type getType() const { return type; }
1275 
1276 private:
1277  VariableDecl(const Name &name, Type type, Expr *initExpr,
1278  unsigned numConstraints)
1279  : Base(name.getLoc(), &name), type(type), initExpr(initExpr),
1280  numConstraints(numConstraints) {}
1281 
1282  /// The type of the variable.
1283  Type type;
1284 
1285  /// The optional initializer expression of this statement.
1286  Expr *initExpr;
1287 
1288  /// The number of constraints attached to this variable.
1289  unsigned numConstraints;
1290 
1291  /// Allow access to various internals.
1292  friend llvm::TrailingObjects<VariableDecl, ConstraintRef>;
1293 };
1294 
1295 //===----------------------------------------------------------------------===//
1296 // Module
1297 //===----------------------------------------------------------------------===//
1298 
1299 /// This class represents a top-level AST module.
1300 class Module final : public Node::NodeBase<Module, Node>,
1301  private llvm::TrailingObjects<Module, Decl *> {
1302 public:
1303  static Module *create(Context &ctx, SMLoc loc, ArrayRef<Decl *> children);
1304 
1305  /// Return the children of this module.
1307  return {getTrailingObjects<Decl *>(), numChildren};
1308  }
1310  return const_cast<Module *>(this)->getChildren();
1311  }
1312 
1313 private:
1314  Module(SMLoc loc, unsigned numChildren)
1315  : Base(SMRange{loc, loc}), numChildren(numChildren) {}
1316 
1317  /// The number of decls held by this module.
1318  unsigned numChildren;
1319 
1320  /// Allow access to various internals.
1321  friend llvm::TrailingObjects<Module, Decl *>;
1322 };
1323 
1324 //===----------------------------------------------------------------------===//
1325 // Deferred Method Definitions
1326 //===----------------------------------------------------------------------===//
1327 
1328 inline bool Decl::classof(const Node *node) {
1330  UserRewriteDecl, VariableDecl>(node);
1331 }
1332 
1333 inline bool ConstraintDecl::classof(const Node *node) {
1334  return isa<CoreConstraintDecl, UserConstraintDecl>(node);
1335 }
1336 
1337 inline bool CoreConstraintDecl::classof(const Node *node) {
1340  ValueRangeConstraintDecl>(node);
1341 }
1342 
1343 inline bool Expr::classof(const Node *node) {
1346 }
1347 
1348 inline bool OpRewriteStmt::classof(const Node *node) {
1349  return isa<EraseStmt, ReplaceStmt, RewriteStmt>(node);
1350 }
1351 
1352 inline bool Stmt::classof(const Node *node) {
1353  return isa<CompoundStmt, LetStmt, OpRewriteStmt, Expr>(node);
1354 }
1355 
1356 } // namespace ast
1357 } // namespace pdll
1358 } // namespace mlir
1359 
1360 #endif // MLIR_TOOLS_PDLL_AST_NODES_H_
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of MemberAccessExpr that references all results of an operation.
Definition: Nodes.h:488
static StringRef getMemberName()
Return the member name used for the "all-results" access.
Definition: Nodes.h:491
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
Definition: Nodes.h:493
static bool classof(const Node *node)
Provide type casting support.
Definition: Nodes.h:500
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
Definition: Nodes.h:754
Expr * typeExpr
An optional type that the attribute is constrained to.
Definition: Nodes.h:768
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:393
Expr * getTypeExpr()
Return the optional type the attribute is constrained to.
Definition: Nodes.h:760
const Expr * getTypeExpr() const
Definition: Nodes.h:761
AttrConstraintDecl(SMRange loc, Expr *typeExpr)
Definition: Nodes.h:764
This expression represents a literal MLIR Attribute, and contains the textual assembly format of that...
Definition: Nodes.h:370
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:261
StringRef getValue() const
Get the raw value of this expression.
Definition: Nodes.h:376
This class represents a PDLL type that corresponds to an mlir::Attribute.
Definition: Types.h:107
This expression represents a call to a decl, such as a UserConstraintDecl/UserRewriteDecl.
Definition: Nodes.h:393
Expr * getCallableExpr() const
Return the callable of this call.
Definition: Nodes.h:400
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
Definition: Nodes.h:403
ArrayRef< Expr * > getArguments() const
Definition: Nodes.h:406
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr * > arguments, Type resultType, bool isNegated=false)
Definition: Nodes.cpp:271
bool getIsNegated() const
Returns whether the result of this call is to be negated.
Definition: Nodes.h:411
This decl represents a shared interface for all callable decls.
Definition: Nodes.h:1198
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this callable, if this is a native callable with a provided impleme...
Definition: Nodes.h:1233
Type getResultType() const
Return the result type of this decl.
Definition: Nodes.h:1216
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
Definition: Nodes.h:1209
StringRef getCallableType() const
Return the callable type of this decl.
Definition: Nodes.h:1201
ArrayRef< VariableDecl * > getResults() const
Return the explicit results of the declaration.
Definition: Nodes.h:1225
static bool classof(const Node *decl)
Support LLVM type casting facilities.
Definition: Nodes.h:1240
This statement represents a compound statement, which contains a collection of other statements.
Definition: Nodes.h:179
ArrayRef< Stmt * > getChildren() const
Definition: Nodes.h:188
ArrayRef< Stmt * >::iterator begin() const
Definition: Nodes.h:191
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
Definition: Nodes.h:185
ArrayRef< Stmt * >::iterator end() const
Definition: Nodes.h:192
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt * > children)
Definition: Nodes.cpp:192
This class represents the base of all AST Constraint decls.
Definition: Nodes.h:708
static bool classof(const Node *node)
Provide type casting support.
Definition: Nodes.h:1333
ConstraintDecl(TypeID typeID, SMRange loc, const Name *name=nullptr)
Definition: Nodes.h:714
This class represents the main context of the PDLL AST.
Definition: Context.h:25
This class represents the base of all "core" constraints.
Definition: Nodes.h:737
static bool classof(const Node *node)
Provide type casting support.
Definition: Nodes.h:1337
CoreConstraintDecl(TypeID typeID, SMRange loc, const Name *name=nullptr)
Definition: Nodes.h:743
This expression represents a reference to a Decl node.
Definition: Nodes.h:437
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
Definition: Nodes.cpp:288
Decl * getDecl() const
Get the decl referenced by this expression.
Definition: Nodes.h:442
This class represents a scope for named AST decls.
Definition: Nodes.h:64
auto getDecls() const
Return all of the decls within this scope.
Definition: Nodes.h:74
const Decl * lookup(StringRef name) const
Definition: Nodes.h:86
const DeclScope * getParentScope() const
Definition: Nodes.h:71
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
Definition: Nodes.h:70
T * lookup(StringRef name)
Definition: Nodes.h:83
Decl * lookup(StringRef name)
Lookup a decl with the given name starting from this scope.
Definition: Nodes.cpp:182
void add(Decl *decl)
Add a new decl to the scope.
Definition: Nodes.cpp:175
DeclScope(DeclScope *parent=nullptr)
Create a new scope with an optional parent scope.
Definition: Nodes.h:67
const T * lookup(StringRef name) const
Definition: Nodes.h:90
This class represents the base Decl node.
Definition: Nodes.h:673
std::optional< StringRef > getDocComment() const
Return the documentation comment attached to this decl if it has been set.
Definition: Nodes.h:686
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition: Nodes.h:676
Decl(TypeID typeID, SMRange loc, const Name *name=nullptr)
Definition: Nodes.h:689
static bool classof(const Node *node)
Provide type casting support.
Definition: Nodes.h:1328
void setDocComment(Context &ctx, StringRef comment)
Set the documentation comment for this decl.
Definition: Nodes.cpp:385
This statement represents the erase statement in PDLL.
Definition: Nodes.h:255
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
Definition: Nodes.cpp:219
This class represents a base AST Expression node.
Definition: Nodes.h:348
Expr(TypeID typeID, SMRange loc, Type type)
Definition: Nodes.h:357
static bool classof(const Node *node)
Provide type casting support.
Definition: Nodes.h:1343
Type getType() const
Return the type of this expression.
Definition: Nodes.h:351
This statement represents a let statement in PDLL.
Definition: Nodes.h:211
VariableDecl * getVarDecl() const
Return the variable defined by this statement.
Definition: Nodes.h:216
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
Definition: Nodes.cpp:207
This expression represents a named member or field access of a given parent expression.
Definition: Nodes.h:458
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
Definition: Nodes.cpp:298
const Expr * getParentExpr() const
Get the parent expression of this access.
Definition: Nodes.h:465
StringRef getMemberName() const
Return the name of the member being accessed.
Definition: Nodes.h:468
This class represents a top-level AST module.
Definition: Nodes.h:1301
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
Definition: Nodes.h:1306
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl * > children)
Definition: Nodes.cpp:579
ArrayRef< Decl * > getChildren() const
Definition: Nodes.h:1309
This Decl represents a NamedAttribute, and contains a string name and attribute value.
Definition: Nodes.h:1002
Expr * getValue() const
Return value of the attribute.
Definition: Nodes.h:1011
const Name & getName() const
Return the name of the attribute.
Definition: Nodes.h:1008
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
Definition: Nodes.cpp:502
This CRTP class provides several utilities when defining new AST nodes.
Definition: Nodes.h:112
static bool classof(const Node *node)
Provide type casting support.
Definition: Nodes.h:117
NodeBase< T, BaseT > Base
Definition: Nodes.h:114
NodeBase(SMRange loc, Args &&...args)
Definition: Nodes.h:123
This class represents a base AST node.
Definition: Nodes.h:108
Node(TypeID typeID, SMRange loc)
Definition: Nodes.h:149
void walk(function_ref< void(const Node *)> walkFn) const
Walk all of the nodes including, and nested under, this node in pre-order.
Definition: Nodes.cpp:167
SMRange getLoc() const
Return the location of this node.
Definition: Nodes.h:131
std::enable_if_t<!std::is_convertible< const Node *, ArgT >::value > walk(WalkFnT &&walkFn) const
Definition: Nodes.h:141
void print(raw_ostream &os) const
Print this node to the given stream.
TypeID getTypeID() const
Return the type identifier of this node.
Definition: Nodes.h:128
The class represents an Operation constraint, and constrains a variable to be an Operation.
Definition: Nodes.h:778
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
Definition: Nodes.cpp:403
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
Definition: Nodes.cpp:412
OpConstraintDecl(SMRange loc, const OpNameDecl *nameDecl)
Definition: Nodes.h:790
const OpNameDecl * getNameDecl() const
Return the declaration of the operation name.
Definition: Nodes.h:787
const OpNameDecl * nameDecl
The operation name of this constraint.
Definition: Nodes.h:794
This Decl represents an OperationName.
Definition: Nodes.h:1026
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
Definition: Nodes.h:1032
static OpNameDecl * create(Context &ctx, const Name &name)
Definition: Nodes.cpp:512
This class represents a base operation rewrite statement.
Definition: Nodes.h:231
static bool classof(const Node *node)
Provide type casting support.
Definition: Nodes.h:1348
OpRewriteStmt(TypeID typeID, SMRange loc, Expr *rootOp)
Definition: Nodes.h:240
Expr * rootOp
The root operation being rewritten.
Definition: Nodes.h:245
Expr * getRootOpExpr() const
Return the root operation of this rewrite.
Definition: Nodes.h:237
This expression represents the structural form of an MLIR Operation.
Definition: Nodes.h:516
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
Definition: Nodes.h:544
MutableArrayRef< Expr * > getResultTypes() const
Definition: Nodes.h:547
MutableArrayRef< NamedAttributeDecl * > getAttributes()
Return the attributes of this operation.
Definition: Nodes.h:552
const OpNameDecl * getNameDecl() const
Return the declaration of the operation name.
Definition: Nodes.h:529
SMRange getNameLoc() const
Return the location of the name of the operation expression, or an invalid location if there isn't a ...
Definition: Nodes.h:533
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
Definition: Nodes.h:536
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr * > operands, ArrayRef< Expr * > resultTypes, ArrayRef< NamedAttributeDecl * > attributes)
Definition: Nodes.cpp:310
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
Definition: Nodes.cpp:333
MutableArrayRef< NamedAttributeDecl * > getAttributes() const
Definition: Nodes.h:555
ArrayRef< Expr * > getOperands() const
Definition: Nodes.h:539
This Decl represents a single Pattern.
Definition: Nodes.h:1047
const CompoundStmt * getBody() const
Return the body of this pattern.
Definition: Nodes.h:1061
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
Definition: Nodes.cpp:523
const OpRewriteStmt * getRootRewriteStmt() const
Return the root rewrite statement of this pattern.
Definition: Nodes.h:1064
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
Definition: Nodes.h:1055
bool hasBoundedRewriteRecursion() const
Return if this pattern has bounded rewrite recursion.
Definition: Nodes.h:1058
This expression builds a range from a set of element values (which may be ranges themselves).
Definition: Nodes.h:590
static RangeExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, RangeType type)
Definition: Nodes.cpp:341
MutableArrayRef< Expr * > getElements()
Return the element expressions of this range.
Definition: Nodes.h:596
ArrayRef< Expr * > getElements() const
Definition: Nodes.h:599
RangeType getType() const
Return the range result type of this expression.
Definition: Nodes.h:604
This class represents a PDLL type that corresponds to a range of elements with a given element type.
Definition: Types.h:159
This statement represents the replace statement in PDLL.
Definition: Nodes.h:271
ArrayRef< Expr * > getReplExprs() const
Definition: Nodes.h:280
MutableArrayRef< Expr * > getReplExprs()
Return the replacement values of this statement.
Definition: Nodes.h:277
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr * > replExprs)
Definition: Nodes.cpp:227
This statement represents a return from a "callable" like decl, e.g.
Definition: Nodes.h:324
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
Definition: Nodes.cpp:252
void setResultExpr(Expr *expr)
Set the result expression of this statement.
Definition: Nodes.h:333
Expr * getResultExpr()
Return the result expression of this statement.
Definition: Nodes.h:329
const Expr * getResultExpr() const
Definition: Nodes.h:330
This statement represents an operation rewrite that contains a block of nested rewrite commands.
Definition: Nodes.h:302
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
Definition: Nodes.cpp:242
CompoundStmt * getRewriteBody() const
Return the compound rewrite body.
Definition: Nodes.h:308
This class represents a base AST Statement node.
Definition: Nodes.h:164
static bool classof(const Node *node)
Provide type casting support.
Definition: Nodes.h:1352
This expression builds a tuple from a set of element values.
Definition: Nodes.h:623
TupleType getType() const
Return the tuple result type of this expression.
Definition: Nodes.h:637
MutableArrayRef< Expr * > getElements()
Return the element expressions of this tuple.
Definition: Nodes.h:629
ArrayRef< Expr * > getElements() const
Definition: Nodes.h:632
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, ArrayRef< StringRef > elementNames)
Definition: Nodes.cpp:356
This class represents a PDLL tuple type, i.e.
Definition: Types.h:222
size_t size() const
Return the number of elements within this tuple.
Definition: Types.h:239
The class represents a Type constraint, and constrains a variable to be a Type.
Definition: Nodes.h:804
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:420
This expression represents a literal MLIR Type, and contains the textual assembly format of that type...
Definition: Nodes.h:652
StringRef getValue() const
Get the raw value of this expression.
Definition: Nodes.h:658
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:376
The class represents a TypeRange constraint, and constrains a variable to be a TypeRange.
Definition: Nodes.h:819
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:429
This class represents a PDLL type that corresponds to an mlir::Type.
Definition: Types.h:250
This decl represents a user defined constraint.
Definition: Nodes.h:892
ArrayRef< VariableDecl * > getResults() const
Definition: Nodes.h:934
bool isExternal() const
Returns true if this constraint is external.
Definition: Nodes.h:950
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this constraint.
Definition: Nodes.h:918
std::optional< StringRef > getNativeInputType(unsigned index) const
Return the explicit native type to use for the given input.
Definition: Nodes.cpp:460
const Name & getName() const
Return the name of the constraint.
Definition: Nodes.h:915
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this constraint, if this is a native constraint with a provided imp...
Definition: Nodes.h:940
static UserConstraintDecl * createPDLL(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, const CompoundStmt *body, Type resultType)
Create a PDLL constraint with the given body.
Definition: Nodes.h:905
Type getResultType() const
Return the result type of this constraint.
Definition: Nodes.h:947
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the constraint declaration.
Definition: Nodes.h:931
ArrayRef< VariableDecl * > getInputs() const
Definition: Nodes.h:921
const CompoundStmt * getBody() const
Return the body of this constraint if this constraint is a PDLL constraint, otherwise returns nullptr...
Definition: Nodes.h:944
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
Definition: Nodes.h:896
This decl represents a user defined rewrite.
Definition: Nodes.h:1102
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this rewrite, if this is a native rewrite with a provided implement...
Definition: Nodes.h:1146
Type getResultType() const
Return the result type of this rewrite.
Definition: Nodes.h:1153
const Name & getName() const
Return the name of the rewrite.
Definition: Nodes.h:1125
ArrayRef< VariableDecl * > getResults() const
Definition: Nodes.h:1140
const CompoundStmt * getBody() const
Return the body of this rewrite if this rewrite is a PDLL rewrite, otherwise returns nullptr.
Definition: Nodes.h:1150
static UserRewriteDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType)
Create a native rewrite with the given optional code block.
Definition: Nodes.h:1105
static UserRewriteDecl * createPDLL(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, const CompoundStmt *body, Type resultType)
Create a PDLL rewrite with the given body.
Definition: Nodes.h:1115
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this rewrite.
Definition: Nodes.h:1128
ArrayRef< VariableDecl * > getInputs() const
Definition: Nodes.h:1131
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the rewrite declaration.
Definition: Nodes.h:1137
bool isExternal() const
Returns true if this rewrite is external.
Definition: Nodes.h:1156
The class represents a Value constraint, and constrains a variable to be a Value.
Definition: Nodes.h:834
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
Definition: Nodes.cpp:439
Expr * typeExpr
An optional type that the value is constrained to.
Definition: Nodes.h:847
const Expr * getTypeExpr() const
Definition: Nodes.h:840
ValueConstraintDecl(SMRange loc, Expr *typeExpr)
Definition: Nodes.h:843
Expr * getTypeExpr()
Return the optional type the value is constrained to.
Definition: Nodes.h:839
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
Definition: Nodes.h:857
const Expr * getTypeExpr() const
Definition: Nodes.h:864
Expr * getTypeExpr()
Return the optional type the value range is constrained to.
Definition: Nodes.h:863
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:450
ValueRangeConstraintDecl(SMRange loc, Expr *typeExpr)
Definition: Nodes.h:867
Expr * typeExpr
An optional type that the value range is constrained to.
Definition: Nodes.h:871
This Decl represents the definition of a PDLL variable.
Definition: Nodes.h:1252
const Name & getName() const
Return the name of the decl.
Definition: Nodes.h:1271
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
Definition: Nodes.h:1268
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
Definition: Nodes.cpp:561
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Definition: Nodes.h:1259
Type getType() const
Return the type of the decl.
Definition: Nodes.h:1274
ArrayRef< ConstraintRef > getConstraints() const
Definition: Nodes.h:1262
This class provides an ODS representation of a specific operation.
Definition: Operation.h:125
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents a reference to a constraint, and contains a constraint and the location of the ...
Definition: Nodes.h:720
ConstraintRef(const ConstraintDecl *constraint, SMRange refLoc)
Definition: Nodes.h:721
const ConstraintDecl * constraint
Definition: Nodes.h:726
ConstraintRef(const ConstraintDecl *constraint)
Definition: Nodes.h:723
This class provides a convenient API for interacting with source names.
Definition: Nodes.h:37
StringRef getName() const
Return the raw string name.
Definition: Nodes.h:41
SMRange getLoc() const
Get the location of this name.
Definition: Nodes.h:44
static const Name & create(Context &ctx, StringRef name, SMRange location)
Definition: Nodes.cpp:33