MLIR  21.0.0git
Nodes.h
Go to the documentation of this file.
1 //===- Nodes.h --------------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_TOOLS_PDLL_AST_NODES_H_
10 #define MLIR_TOOLS_PDLL_AST_NODES_H_
11 
12 #include "mlir/Support/LLVM.h"
14 #include "llvm/ADT/StringMap.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/Support/SMLoc.h"
17 #include "llvm/Support/SourceMgr.h"
18 #include "llvm/Support/TrailingObjects.h"
19 #include <optional>
20 
21 namespace mlir {
22 namespace pdll {
23 namespace ast {
24 class Context;
25 class Decl;
26 class Expr;
27 class NamedAttributeDecl;
28 class OpNameDecl;
29 class VariableDecl;
30 
31 //===----------------------------------------------------------------------===//
32 // Name
33 //===----------------------------------------------------------------------===//
34 
35 /// This class provides a convenient API for interacting with source names. It
36 /// contains a string name as well as the source location for that name.
37 struct Name {
38  static const Name &create(Context &ctx, StringRef name, SMRange location);
39 
40  /// Return the raw string name.
41  StringRef getName() const { return name; }
42 
43  /// Get the location of this name.
44  SMRange getLoc() const { return location; }
45 
46 private:
47  Name() = delete;
48  Name(const Name &) = delete;
49  Name &operator=(const Name &) = delete;
50  Name(StringRef name, SMRange location) : name(name), location(location) {}
51 
52  /// The string name of the decl.
53  StringRef name;
54  /// The location of the decl name.
55  SMRange location;
56 };
57 
58 //===----------------------------------------------------------------------===//
59 // DeclScope
60 //===----------------------------------------------------------------------===//
61 
62 /// This class represents a scope for named AST decls. A scope determines the
63 /// visibility and lifetime of a named declaration.
64 class DeclScope {
65 public:
66  /// Create a new scope with an optional parent scope.
67  DeclScope(DeclScope *parent = nullptr) : parent(parent) {}
68 
69  /// Return the parent scope of this scope, or nullptr if there is no parent.
70  DeclScope *getParentScope() { return parent; }
71  const DeclScope *getParentScope() const { return parent; }
72 
73  /// Return all of the decls within this scope.
74  auto getDecls() const { return llvm::make_second_range(decls); }
75 
76  /// Add a new decl to the scope.
77  void add(Decl *decl);
78 
79  /// Lookup a decl with the given name starting from this scope. Returns
80  /// nullptr if no decl could be found.
81  Decl *lookup(StringRef name);
82  template <typename T>
83  T *lookup(StringRef name) {
84  return dyn_cast_or_null<T>(lookup(name));
85  }
86  const Decl *lookup(StringRef name) const {
87  return const_cast<DeclScope *>(this)->lookup(name);
88  }
89  template <typename T>
90  const T *lookup(StringRef name) const {
91  return dyn_cast_or_null<T>(lookup(name));
92  }
93 
94 private:
95  /// The parent scope, or null if this is a top-level scope.
96  DeclScope *parent;
97  /// The decls defined within this scope.
98  llvm::StringMap<Decl *> decls;
99 };
100 
101 //===----------------------------------------------------------------------===//
102 // Node
103 //===----------------------------------------------------------------------===//
104 
105 /// This class represents a base AST node. All AST nodes are derived from this
106 /// class, and it contains many of the base functionality for interacting with
107 /// nodes.
108 class Node {
109 public:
110  /// This CRTP class provides several utilies when defining new AST nodes.
111  template <typename T, typename BaseT>
112  class NodeBase : public BaseT {
113  public:
115 
116  /// Provide type casting support.
117  static bool classof(const Node *node) {
118  return node->getTypeID() == TypeID::get<T>();
119  }
120 
121  protected:
122  template <typename... Args>
123  explicit NodeBase(SMRange loc, Args &&...args)
124  : BaseT(TypeID::get<T>(), loc, std::forward<Args>(args)...) {}
125  };
126 
127  /// Return the type identifier of this node.
128  TypeID getTypeID() const { return typeID; }
129 
130  /// Return the location of this node.
131  SMRange getLoc() const { return loc; }
132 
133  /// Print this node to the given stream.
134  void print(raw_ostream &os) const;
135 
136  /// Walk all of the nodes including, and nested under, this node in pre-order.
137  void walk(function_ref<void(const Node *)> walkFn) const;
138  template <typename WalkFnT, typename ArgT = typename llvm::function_traits<
139  WalkFnT>::template arg_t<0>>
140  std::enable_if_t<!std::is_convertible<const Node *, ArgT>::value>
141  walk(WalkFnT &&walkFn) const {
142  walk([&](const Node *node) {
143  if (const ArgT *derivedNode = dyn_cast<ArgT>(node))
144  walkFn(derivedNode);
145  });
146  }
147 
148 protected:
149  Node(TypeID typeID, SMRange loc) : typeID(typeID), loc(loc) {}
150 
151 private:
152  /// A unique type identifier for this node.
153  TypeID typeID;
154 
155  /// The location of this node.
156  SMRange loc;
157 };
158 
159 //===----------------------------------------------------------------------===//
160 // Stmt
161 //===----------------------------------------------------------------------===//
162 
163 /// This class represents a base AST Statement node.
164 class Stmt : public Node {
165 public:
166  using Node::Node;
167 
168  /// Provide type casting support.
169  static bool classof(const Node *node);
170 };
171 
172 //===----------------------------------------------------------------------===//
173 // CompoundStmt
174 //===----------------------------------------------------------------------===//
175 
176 /// This statement represents a compound statement, which contains a collection
177 /// of other statements.
178 class CompoundStmt final : public Node::NodeBase<CompoundStmt, Stmt>,
179  private llvm::TrailingObjects<CompoundStmt, Stmt *> {
180 public:
181  static CompoundStmt *create(Context &ctx, SMRange location,
182  ArrayRef<Stmt *> children);
183 
184  /// Return the children of this compound statement.
186  return getTrailingObjects(numChildren);
187  }
189  return getTrailingObjects(numChildren);
190  }
191  ArrayRef<Stmt *>::iterator begin() const { return getChildren().begin(); }
192  ArrayRef<Stmt *>::iterator end() const { return getChildren().end(); }
193 
194 private:
195  CompoundStmt(SMRange location, unsigned numChildren)
196  : Base(location), numChildren(numChildren) {}
197 
198  /// The number of held children statements.
199  unsigned numChildren;
200 
201  // Allow access to various privates.
202  friend class llvm::TrailingObjects<CompoundStmt, Stmt *>;
203 };
204 
205 //===----------------------------------------------------------------------===//
206 // LetStmt
207 //===----------------------------------------------------------------------===//
208 
209 /// This statement represents a `let` statement in PDLL. This statement is used
210 /// to define variables.
211 class LetStmt final : public Node::NodeBase<LetStmt, Stmt> {
212 public:
213  static LetStmt *create(Context &ctx, SMRange loc, VariableDecl *varDecl);
214 
215  /// Return the variable defined by this statement.
216  VariableDecl *getVarDecl() const { return varDecl; }
217 
218 private:
219  LetStmt(SMRange loc, VariableDecl *varDecl) : Base(loc), varDecl(varDecl) {}
220 
221  /// The variable defined by this statement.
222  VariableDecl *varDecl;
223 };
224 
225 //===----------------------------------------------------------------------===//
226 // OpRewriteStmt
227 //===----------------------------------------------------------------------===//
228 
229 /// This class represents a base operation rewrite statement. Operation rewrite
230 /// statements perform a set of transformations on a given root operation.
231 class OpRewriteStmt : public Stmt {
232 public:
233  /// Provide type casting support.
234  static bool classof(const Node *node);
235 
236  /// Return the root operation of this rewrite.
237  Expr *getRootOpExpr() const { return rootOp; }
238 
239 protected:
240  OpRewriteStmt(TypeID typeID, SMRange loc, Expr *rootOp)
241  : Stmt(typeID, loc), rootOp(rootOp) {}
242 
243 protected:
244  /// The root operation being rewritten.
246 };
247 
248 //===----------------------------------------------------------------------===//
249 // EraseStmt
250 //===----------------------------------------------------------------------===//
251 
252 /// This statement represents the `erase` statement in PDLL. This statement
253 /// erases the given root operation, corresponding roughly to the
254 /// PatternRewriter::eraseOp API.
255 class EraseStmt final : public Node::NodeBase<EraseStmt, OpRewriteStmt> {
256 public:
257  static EraseStmt *create(Context &ctx, SMRange loc, Expr *rootOp);
258 
259 private:
260  EraseStmt(SMRange loc, Expr *rootOp) : Base(loc, rootOp) {}
261 };
262 
263 //===----------------------------------------------------------------------===//
264 // ReplaceStmt
265 //===----------------------------------------------------------------------===//
266 
267 /// This statement represents the `replace` statement in PDLL. This statement
268 /// replace the given root operation with a set of values, corresponding roughly
269 /// to the PatternRewriter::replaceOp API.
270 class ReplaceStmt final : public Node::NodeBase<ReplaceStmt, OpRewriteStmt>,
271  private llvm::TrailingObjects<ReplaceStmt, Expr *> {
272 public:
273  static ReplaceStmt *create(Context &ctx, SMRange loc, Expr *rootOp,
274  ArrayRef<Expr *> replExprs);
275 
276  /// Return the replacement values of this statement.
278  return getTrailingObjects(numReplExprs);
279  }
281  return getTrailingObjects(numReplExprs);
282  }
283 
284 private:
285  ReplaceStmt(SMRange loc, Expr *rootOp, unsigned numReplExprs)
286  : Base(loc, rootOp), numReplExprs(numReplExprs) {}
287 
288  /// The number of replacement values within this statement.
289  unsigned numReplExprs;
290 
291  /// TrailingObject utilities.
292  friend class llvm::TrailingObjects<ReplaceStmt, Expr *>;
293 };
294 
295 //===----------------------------------------------------------------------===//
296 // RewriteStmt
297 //===----------------------------------------------------------------------===//
298 
299 /// This statement represents an operation rewrite that contains a block of
300 /// nested rewrite commands. This allows for building more complex operation
301 /// rewrites that span across multiple statements, which may be unconnected.
302 class RewriteStmt final : public Node::NodeBase<RewriteStmt, OpRewriteStmt> {
303 public:
304  static RewriteStmt *create(Context &ctx, SMRange loc, Expr *rootOp,
305  CompoundStmt *rewriteBody);
306 
307  /// Return the compound rewrite body.
308  CompoundStmt *getRewriteBody() const { return rewriteBody; }
309 
310 private:
311  RewriteStmt(SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
312  : Base(loc, rootOp), rewriteBody(rewriteBody) {}
313 
314  /// The body of nested rewriters within this statement.
315  CompoundStmt *rewriteBody;
316 };
317 
318 //===----------------------------------------------------------------------===//
319 // ReturnStmt
320 //===----------------------------------------------------------------------===//
321 
322 /// This statement represents a return from a "callable" like decl, e.g. a
323 /// Constraint or a Rewrite.
324 class ReturnStmt final : public Node::NodeBase<ReturnStmt, Stmt> {
325 public:
326  static ReturnStmt *create(Context &ctx, SMRange loc, Expr *resultExpr);
327 
328  /// Return the result expression of this statement.
329  Expr *getResultExpr() { return resultExpr; }
330  const Expr *getResultExpr() const { return resultExpr; }
331 
332  /// Set the result expression of this statement.
333  void setResultExpr(Expr *expr) { resultExpr = expr; }
334 
335 private:
336  ReturnStmt(SMRange loc, Expr *resultExpr)
337  : Base(loc), resultExpr(resultExpr) {}
338 
339  // The result expression of this statement.
340  Expr *resultExpr;
341 };
342 
343 //===----------------------------------------------------------------------===//
344 // Expr
345 //===----------------------------------------------------------------------===//
346 
347 /// This class represents a base AST Expression node.
348 class Expr : public Stmt {
349 public:
350  /// Return the type of this expression.
351  Type getType() const { return type; }
352 
353  /// Provide type casting support.
354  static bool classof(const Node *node);
355 
356 protected:
357  Expr(TypeID typeID, SMRange loc, Type type) : Stmt(typeID, loc), type(type) {}
358 
359 private:
360  /// The type of this expression.
361  Type type;
362 };
363 
364 //===----------------------------------------------------------------------===//
365 // AttributeExpr
366 //===----------------------------------------------------------------------===//
367 
368 /// This expression represents a literal MLIR Attribute, and contains the
369 /// textual assembly format of that attribute.
370 class AttributeExpr : public Node::NodeBase<AttributeExpr, Expr> {
371 public:
372  static AttributeExpr *create(Context &ctx, SMRange loc, StringRef value);
373 
374  /// Get the raw value of this expression. This is the textual assembly format
375  /// of the MLIR Attribute.
376  StringRef getValue() const { return value; }
377 
378 private:
379  AttributeExpr(Context &ctx, SMRange loc, StringRef value)
380  : Base(loc, AttributeType::get(ctx)), value(value) {}
381 
382  /// The value referenced by this expression.
383  StringRef value;
384 };
385 
386 //===----------------------------------------------------------------------===//
387 // CallExpr
388 //===----------------------------------------------------------------------===//
389 
390 /// This expression represents a call to a decl, such as a
391 /// UserConstraintDecl/UserRewriteDecl.
392 class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
393  private llvm::TrailingObjects<CallExpr, Expr *> {
394 public:
395  static CallExpr *create(Context &ctx, SMRange loc, Expr *callable,
396  ArrayRef<Expr *> arguments, Type resultType,
397  bool isNegated = false);
398 
399  /// Return the callable of this call.
400  Expr *getCallableExpr() const { return callable; }
401 
402  /// Return the arguments of this call.
403  MutableArrayRef<Expr *> getArguments() { return getTrailingObjects(numArgs); }
404  ArrayRef<Expr *> getArguments() const { return getTrailingObjects(numArgs); }
405 
406  /// Returns whether the result of this call is to be negated.
407  bool getIsNegated() const { return isNegated; }
408 
409 private:
410  CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs,
411  bool isNegated)
412  : Base(loc, type), callable(callable), numArgs(numArgs),
413  isNegated(isNegated) {}
414 
415  /// The callable of this call.
416  Expr *callable;
417 
418  /// The number of arguments of the call.
419  unsigned numArgs;
420 
421  /// TrailingObject utilities.
422  friend llvm::TrailingObjects<CallExpr, Expr *>;
423 
424  // Is the result of this call to be negated.
425  bool isNegated;
426 };
427 
428 //===----------------------------------------------------------------------===//
429 // DeclRefExpr
430 //===----------------------------------------------------------------------===//
431 
432 /// This expression represents a reference to a Decl node.
433 class DeclRefExpr : public Node::NodeBase<DeclRefExpr, Expr> {
434 public:
435  static DeclRefExpr *create(Context &ctx, SMRange loc, Decl *decl, Type type);
436 
437  /// Get the decl referenced by this expression.
438  Decl *getDecl() const { return decl; }
439 
440 private:
441  DeclRefExpr(SMRange loc, Decl *decl, Type type)
442  : Base(loc, type), decl(decl) {}
443 
444  /// The decl referenced by this expression.
445  Decl *decl;
446 };
447 
448 //===----------------------------------------------------------------------===//
449 // MemberAccessExpr
450 //===----------------------------------------------------------------------===//
451 
452 /// This expression represents a named member or field access of a given parent
453 /// expression.
454 class MemberAccessExpr : public Node::NodeBase<MemberAccessExpr, Expr> {
455 public:
456  static MemberAccessExpr *create(Context &ctx, SMRange loc,
457  const Expr *parentExpr, StringRef memberName,
458  Type type);
459 
460  /// Get the parent expression of this access.
461  const Expr *getParentExpr() const { return parentExpr; }
462 
463  /// Return the name of the member being accessed.
464  StringRef getMemberName() const { return memberName; }
465 
466 private:
467  MemberAccessExpr(SMRange loc, const Expr *parentExpr, StringRef memberName,
468  Type type)
469  : Base(loc, type), parentExpr(parentExpr), memberName(memberName) {}
470 
471  /// The parent expression of this access.
472  const Expr *parentExpr;
473 
474  /// The name of the member being accessed from the parent.
475  StringRef memberName;
476 };
477 
478 //===----------------------------------------------------------------------===//
479 // AllResultsMemberAccessExpr
480 //===----------------------------------------------------------------------===//
481 
482 /// This class represents an instance of MemberAccessExpr that references all
483 /// results of an operation.
485 public:
486  /// Return the member name used for the "all-results" access.
487  static StringRef getMemberName() { return "$results"; }
488 
489  static AllResultsMemberAccessExpr *create(Context &ctx, SMRange loc,
490  const Expr *parentExpr, Type type) {
491  return cast<AllResultsMemberAccessExpr>(
492  MemberAccessExpr::create(ctx, loc, parentExpr, getMemberName(), type));
493  }
494 
495  /// Provide type casting support.
496  static bool classof(const Node *node) {
497  const MemberAccessExpr *memAccess = dyn_cast<MemberAccessExpr>(node);
498  return memAccess && memAccess->getMemberName() == getMemberName();
499  }
500 };
501 
502 //===----------------------------------------------------------------------===//
503 // OperationExpr
504 //===----------------------------------------------------------------------===//
505 
506 /// This expression represents the structural form of an MLIR Operation. It
507 /// represents either an input operation to match, or an operation to create
508 /// within a rewrite.
509 class OperationExpr final
510  : public Node::NodeBase<OperationExpr, Expr>,
511  private llvm::TrailingObjects<OperationExpr, Expr *,
512  NamedAttributeDecl *> {
513 public:
514  static OperationExpr *create(Context &ctx, SMRange loc,
515  const ods::Operation *odsOp,
516  const OpNameDecl *nameDecl,
517  ArrayRef<Expr *> operands,
518  ArrayRef<Expr *> resultTypes,
519  ArrayRef<NamedAttributeDecl *> attributes);
520 
521  /// Return the name of the operation, or std::nullopt if there isn't one.
522  std::optional<StringRef> getName() const;
523 
524  /// Return the declaration of the operation name.
525  const OpNameDecl *getNameDecl() const { return nameDecl; }
526 
527  /// Return the location of the name of the operation expression, or an invalid
528  /// location if there isn't a name.
529  SMRange getNameLoc() const { return nameLoc; }
530 
531  /// Return the operands of this operation.
533  return getTrailingObjects<Expr *>(numOperands);
534  }
536  return getTrailingObjects<Expr *>(numOperands);
537  }
538 
539  /// Return the result types of this operation.
541  return {getTrailingObjects<Expr *>() + numOperands, numResultTypes};
542  }
544  return const_cast<OperationExpr *>(this)->getResultTypes();
545  }
546 
547  /// Return the attributes of this operation.
549  return getTrailingObjects<NamedAttributeDecl *>(numAttributes);
550  }
552  return getTrailingObjects<NamedAttributeDecl *>(numAttributes);
553  }
554 
555 private:
556  OperationExpr(SMRange loc, Type type, const OpNameDecl *nameDecl,
557  unsigned numOperands, unsigned numResultTypes,
558  unsigned numAttributes, SMRange nameLoc)
559  : Base(loc, type), nameDecl(nameDecl), numOperands(numOperands),
560  numResultTypes(numResultTypes), numAttributes(numAttributes),
561  nameLoc(nameLoc) {}
562 
563  /// The name decl of this expression.
564  const OpNameDecl *nameDecl;
565 
566  /// The number of operands, result types, and attributes of the operation.
567  unsigned numOperands, numResultTypes, numAttributes;
568 
569  /// The location of the operation name in the expression if it has a name.
570  SMRange nameLoc;
571 
572  /// TrailingObject utilities.
573  friend llvm::TrailingObjects<OperationExpr, Expr *, NamedAttributeDecl *>;
574  size_t numTrailingObjects(OverloadToken<Expr *>) const {
575  return numOperands + numResultTypes;
576  }
577 };
578 
579 //===----------------------------------------------------------------------===//
580 // RangeExpr
581 //===----------------------------------------------------------------------===//
582 
583 /// This expression builds a range from a set of element values (which may be
584 /// ranges themselves).
585 class RangeExpr final : public Node::NodeBase<RangeExpr, Expr>,
586  private llvm::TrailingObjects<RangeExpr, Expr *> {
587 public:
588  static RangeExpr *create(Context &ctx, SMRange loc, ArrayRef<Expr *> elements,
589  RangeType type);
590 
591  /// Return the element expressions of this range.
593  return getTrailingObjects(numElements);
594  }
596  return getTrailingObjects(numElements);
597  }
598 
599  /// Return the range result type of this expression.
600  RangeType getType() const { return mlir::cast<RangeType>(Base::getType()); }
601 
602 private:
603  RangeExpr(SMRange loc, RangeType type, unsigned numElements)
604  : Base(loc, type), numElements(numElements) {}
605 
606  /// The number of element values for this range.
607  unsigned numElements;
608 
609  /// TrailingObject utilities.
610  friend class llvm::TrailingObjects<RangeExpr, Expr *>;
611 };
612 
613 //===----------------------------------------------------------------------===//
614 // TupleExpr
615 //===----------------------------------------------------------------------===//
616 
617 /// This expression builds a tuple from a set of element values.
618 class TupleExpr final : public Node::NodeBase<TupleExpr, Expr>,
619  private llvm::TrailingObjects<TupleExpr, Expr *> {
620 public:
621  static TupleExpr *create(Context &ctx, SMRange loc, ArrayRef<Expr *> elements,
622  ArrayRef<StringRef> elementNames);
623 
624  /// Return the element expressions of this tuple.
626  return getTrailingObjects(getType().size());
627  }
629  return getTrailingObjects(getType().size());
630  }
631 
632  /// Return the tuple result type of this expression.
633  TupleType getType() const { return mlir::cast<TupleType>(Base::getType()); }
634 
635 private:
636  TupleExpr(SMRange loc, TupleType type) : Base(loc, type) {}
637 
638  /// TrailingObject utilities.
639  friend class llvm::TrailingObjects<TupleExpr, Expr *>;
640 };
641 
642 //===----------------------------------------------------------------------===//
643 // TypeExpr
644 //===----------------------------------------------------------------------===//
645 
646 /// This expression represents a literal MLIR Type, and contains the textual
647 /// assembly format of that type.
648 class TypeExpr : public Node::NodeBase<TypeExpr, Expr> {
649 public:
650  static TypeExpr *create(Context &ctx, SMRange loc, StringRef value);
651 
652  /// Get the raw value of this expression. This is the textual assembly format
653  /// of the MLIR Type.
654  StringRef getValue() const { return value; }
655 
656 private:
657  TypeExpr(Context &ctx, SMRange loc, StringRef value)
658  : Base(loc, TypeType::get(ctx)), value(value) {}
659 
660  /// The value referenced by this expression.
661  StringRef value;
662 };
663 
664 //===----------------------------------------------------------------------===//
665 // Decl
666 //===----------------------------------------------------------------------===//
667 
668 /// This class represents the base Decl node.
669 class Decl : public Node {
670 public:
671  /// Return the name of the decl, or nullptr if it doesn't have one.
672  const Name *getName() const { return name; }
673 
674  /// Provide type casting support.
675  static bool classof(const Node *node);
676 
677  /// Set the documentation comment for this decl.
678  void setDocComment(Context &ctx, StringRef comment);
679 
680  /// Return the documentation comment attached to this decl if it has been set.
681  /// Otherwise, returns std::nullopt.
682  std::optional<StringRef> getDocComment() const { return docComment; }
683 
684 protected:
685  Decl(TypeID typeID, SMRange loc, const Name *name = nullptr)
686  : Node(typeID, loc), name(name) {}
687 
688 private:
689  /// The name of the decl. This is optional for some decls, such as
690  /// PatternDecl.
691  const Name *name;
692 
693  /// The documentation comment attached to this decl. Defaults to std::nullopt
694  /// if the comment is unset/unknown.
695  std::optional<StringRef> docComment;
696 };
697 
698 //===----------------------------------------------------------------------===//
699 // ConstraintDecl
700 //===----------------------------------------------------------------------===//
701 
702 /// This class represents the base of all AST Constraint decls. Constraints
703 /// apply matcher conditions to, and define the type of PDLL variables.
704 class ConstraintDecl : public Decl {
705 public:
706  /// Provide type casting support.
707  static bool classof(const Node *node);
708 
709 protected:
710  ConstraintDecl(TypeID typeID, SMRange loc, const Name *name = nullptr)
711  : Decl(typeID, loc, name) {}
712 };
713 
714 /// This class represents a reference to a constraint, and contains a constraint
715 /// and the location of the reference.
717  ConstraintRef(const ConstraintDecl *constraint, SMRange refLoc)
718  : constraint(constraint), referenceLoc(refLoc) {}
720  : ConstraintRef(constraint, constraint->getLoc()) {}
721 
723  SMRange referenceLoc;
724 };
725 
726 //===----------------------------------------------------------------------===//
727 // CoreConstraintDecl
728 //===----------------------------------------------------------------------===//
729 
730 /// This class represents the base of all "core" constraints. Core constraints
731 /// are those that generally represent a concrete IR construct, such as
732 /// `Type`s or `Value`s.
734 public:
735  /// Provide type casting support.
736  static bool classof(const Node *node);
737 
738 protected:
739  CoreConstraintDecl(TypeID typeID, SMRange loc, const Name *name = nullptr)
740  : ConstraintDecl(typeID, loc, name) {}
741 };
742 
743 //===----------------------------------------------------------------------===//
744 // AttrConstraintDecl
745 //===----------------------------------------------------------------------===//
746 
747 /// The class represents an Attribute constraint, and constrains a variable to
748 /// be an Attribute.
750  : public Node::NodeBase<AttrConstraintDecl, CoreConstraintDecl> {
751 public:
752  static AttrConstraintDecl *create(Context &ctx, SMRange loc,
753  Expr *typeExpr = nullptr);
754 
755  /// Return the optional type the attribute is constrained to.
756  Expr *getTypeExpr() { return typeExpr; }
757  const Expr *getTypeExpr() const { return typeExpr; }
758 
759 protected:
761  : Base(loc), typeExpr(typeExpr) {}
762 
763  /// An optional type that the attribute is constrained to.
765 };
766 
767 //===----------------------------------------------------------------------===//
768 // OpConstraintDecl
769 //===----------------------------------------------------------------------===//
770 
771 /// The class represents an Operation constraint, and constrains a variable to
772 /// be an Operation.
774  : public Node::NodeBase<OpConstraintDecl, CoreConstraintDecl> {
775 public:
776  static OpConstraintDecl *create(Context &ctx, SMRange loc,
777  const OpNameDecl *nameDecl = nullptr);
778 
779  /// Return the name of the operation, or std::nullopt if there isn't one.
780  std::optional<StringRef> getName() const;
781 
782  /// Return the declaration of the operation name.
783  const OpNameDecl *getNameDecl() const { return nameDecl; }
784 
785 protected:
786  explicit OpConstraintDecl(SMRange loc, const OpNameDecl *nameDecl)
787  : Base(loc), nameDecl(nameDecl) {}
788 
789  /// The operation name of this constraint.
791 };
792 
793 //===----------------------------------------------------------------------===//
794 // TypeConstraintDecl
795 //===----------------------------------------------------------------------===//
796 
797 /// The class represents a Type constraint, and constrains a variable to be a
798 /// Type.
800  : public Node::NodeBase<TypeConstraintDecl, CoreConstraintDecl> {
801 public:
802  static TypeConstraintDecl *create(Context &ctx, SMRange loc);
803 
804 protected:
805  using Base::Base;
806 };
807 
808 //===----------------------------------------------------------------------===//
809 // TypeRangeConstraintDecl
810 //===----------------------------------------------------------------------===//
811 
812 /// The class represents a TypeRange constraint, and constrains a variable to be
813 /// a TypeRange.
815  : public Node::NodeBase<TypeRangeConstraintDecl, CoreConstraintDecl> {
816 public:
817  static TypeRangeConstraintDecl *create(Context &ctx, SMRange loc);
818 
819 protected:
820  using Base::Base;
821 };
822 
823 //===----------------------------------------------------------------------===//
824 // ValueConstraintDecl
825 //===----------------------------------------------------------------------===//
826 
827 /// The class represents a Value constraint, and constrains a variable to be a
828 /// Value.
830  : public Node::NodeBase<ValueConstraintDecl, CoreConstraintDecl> {
831 public:
832  static ValueConstraintDecl *create(Context &ctx, SMRange loc, Expr *typeExpr);
833 
834  /// Return the optional type the value is constrained to.
835  Expr *getTypeExpr() { return typeExpr; }
836  const Expr *getTypeExpr() const { return typeExpr; }
837 
838 protected:
840  : Base(loc), typeExpr(typeExpr) {}
841 
842  /// An optional type that the value is constrained to.
844 };
845 
846 //===----------------------------------------------------------------------===//
847 // ValueRangeConstraintDecl
848 //===----------------------------------------------------------------------===//
849 
850 /// The class represents a ValueRange constraint, and constrains a variable to
851 /// be a ValueRange.
853  : public Node::NodeBase<ValueRangeConstraintDecl, CoreConstraintDecl> {
854 public:
855  static ValueRangeConstraintDecl *create(Context &ctx, SMRange loc,
856  Expr *typeExpr = nullptr);
857 
858  /// Return the optional type the value range is constrained to.
859  Expr *getTypeExpr() { return typeExpr; }
860  const Expr *getTypeExpr() const { return typeExpr; }
861 
862 protected:
864  : Base(loc), typeExpr(typeExpr) {}
865 
866  /// An optional type that the value range is constrained to.
868 };
869 
870 //===----------------------------------------------------------------------===//
871 // UserConstraintDecl
872 //===----------------------------------------------------------------------===//
873 
874 /// This decl represents a user defined constraint. This is either:
875 /// * an imported native constraint
876 /// - Similar to an external function declaration. This is a native
877 /// constraint defined externally, and imported into PDLL via a
878 /// declaration.
879 /// * a native constraint defined in PDLL
880 /// - This is a native constraint, i.e. a constraint whose implementation is
881 /// defined in C++(or potentially some other non-PDLL language). The
882 /// implementation of this constraint is specified as a string code block
883 /// in PDLL.
884 /// * a PDLL constraint
885 /// - This is a constraint which is defined using only PDLL constructs.
887  : public Node::NodeBase<UserConstraintDecl, ConstraintDecl>,
888  llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef> {
889 public:
890  /// Create a native constraint with the given optional code block.
891  static UserConstraintDecl *
893  ArrayRef<VariableDecl *> results,
894  std::optional<StringRef> codeBlock, Type resultType,
895  ArrayRef<StringRef> nativeInputTypes = {}) {
896  return createImpl(ctx, name, inputs, nativeInputTypes, results, codeBlock,
897  /*body=*/nullptr, resultType);
898  }
899 
900  /// Create a PDLL constraint with the given body.
901  static UserConstraintDecl *createPDLL(Context &ctx, const Name &name,
903  ArrayRef<VariableDecl *> results,
904  const CompoundStmt *body,
905  Type resultType) {
906  return createImpl(ctx, name, inputs, /*nativeInputTypes=*/std::nullopt,
907  results, /*codeBlock=*/std::nullopt, body, resultType);
908  }
909 
910  /// Return the name of the constraint.
911  const Name &getName() const { return *Decl::getName(); }
912 
913  /// Return the input arguments of this constraint.
915  return getTrailingObjects<VariableDecl *>(numInputs);
916  }
918  return getTrailingObjects<VariableDecl *>(numInputs);
919  }
920 
921  /// Return the explicit native type to use for the given input. Returns
922  /// std::nullopt if no explicit type was set.
923  std::optional<StringRef> getNativeInputType(unsigned index) const;
924 
925  /// Return the explicit results of the constraint declaration. May be empty,
926  /// even if the constraint has results (e.g. in the case of inferred results).
928  return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
929  }
931  return const_cast<UserConstraintDecl *>(this)->getResults();
932  }
933 
934  /// Return the optional code block of this constraint, if this is a native
935  /// constraint with a provided implementation.
936  std::optional<StringRef> getCodeBlock() const { return codeBlock; }
937 
938  /// Return the body of this constraint if this constraint is a PDLL
939  /// constraint, otherwise returns nullptr.
940  const CompoundStmt *getBody() const { return constraintBody; }
941 
942  /// Return the result type of this constraint.
943  Type getResultType() const { return resultType; }
944 
945  /// Returns true if this constraint is external.
946  bool isExternal() const { return !constraintBody && !codeBlock; }
947 
948 private:
949  /// Create either a PDLL constraint or a native constraint with the given
950  /// components.
951  static UserConstraintDecl *createImpl(Context &ctx, const Name &name,
953  ArrayRef<StringRef> nativeInputTypes,
954  ArrayRef<VariableDecl *> results,
955  std::optional<StringRef> codeBlock,
956  const CompoundStmt *body,
957  Type resultType);
958 
959  UserConstraintDecl(const Name &name, unsigned numInputs,
960  bool hasNativeInputTypes, unsigned numResults,
961  std::optional<StringRef> codeBlock,
962  const CompoundStmt *body, Type resultType)
963  : Base(name.getLoc(), &name), numInputs(numInputs),
964  numResults(numResults), codeBlock(codeBlock), constraintBody(body),
965  resultType(resultType), hasNativeInputTypes(hasNativeInputTypes) {}
966 
967  /// The number of inputs to this constraint.
968  unsigned numInputs;
969 
970  /// The number of explicit results to this constraint.
971  unsigned numResults;
972 
973  /// The optional code block of this constraint.
974  std::optional<StringRef> codeBlock;
975 
976  /// The optional body of this constraint.
977  const CompoundStmt *constraintBody;
978 
979  /// The result type of the constraint.
980  Type resultType;
981 
982  /// Flag indicating if this constraint has explicit native input types.
983  bool hasNativeInputTypes;
984 
985  /// Allow access to various internals.
986  friend llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef>;
987  size_t numTrailingObjects(OverloadToken<VariableDecl *>) const {
988  return numInputs + numResults;
989  }
990 };
991 
992 //===----------------------------------------------------------------------===//
993 // NamedAttributeDecl
994 //===----------------------------------------------------------------------===//
995 
996 /// This Decl represents a NamedAttribute, and contains a string name and
997 /// attribute value.
998 class NamedAttributeDecl : public Node::NodeBase<NamedAttributeDecl, Decl> {
999 public:
1000  static NamedAttributeDecl *create(Context &ctx, const Name &name,
1001  Expr *value); // Create a decl pairing `name` with the attribute `value`.
1002 
1003  /// Return the name of the attribute (always set by the constructor).
1004  const Name &getName() const { return *Decl::getName(); }
1005 
1006  /// Return the value expression of the attribute.
1007  Expr *getValue() const { return value; }
1008 
1009 private:
1010  NamedAttributeDecl(const Name &name, Expr *value)
1011  : Base(name.getLoc(), &name), value(value) {} // Node is located at `name`.
1012 
1013  /// The expression giving the value of the attribute.
1014  Expr *value;
1015 };
1016 
1017 //===----------------------------------------------------------------------===//
1018 // OpNameDecl
1019 //===----------------------------------------------------------------------===//
1020 
1021 /// This Decl represents an OperationName.
1022 class OpNameDecl : public Node::NodeBase<OpNameDecl, Decl> {
1023 public:
1024  static OpNameDecl *create(Context &ctx, const Name &name); // Known operation name.
1025  static OpNameDecl *create(Context &ctx, SMRange loc); // Name-less overload; getName() is then expected to be std::nullopt.
1026 
1027  /// Return the name of this operation, or std::nullopt if the name is unknown.
1028  std::optional<StringRef> getName() const {
1029  const Name *name = Decl::getName(); // Null when built from a bare SMRange.
1030  return name ? std::optional<StringRef>(name->getName()) : std::nullopt;
1031  }
1032 
1033 private:
1034  explicit OpNameDecl(const Name &name) : Base(name.getLoc(), &name) {} // Named; located at `name`.
1035  explicit OpNameDecl(SMRange loc) : Base(loc) {} // Unnamed: no Name is passed to the base.
1036 };
1037 
1038 //===----------------------------------------------------------------------===//
1039 // PatternDecl
1040 //===----------------------------------------------------------------------===//
1041 
1042 /// This Decl represents a single Pattern.
1043 class PatternDecl : public Node::NodeBase<PatternDecl, Decl> {
1044 public:
1045  static PatternDecl *create(Context &ctx, SMRange location, const Name *name,
1046  std::optional<uint16_t> benefit,
1047  bool hasBoundedRecursion,
1048  const CompoundStmt *body);
1049 
1050  /// Return the benefit of this pattern if specified, or std::nullopt.
1051  std::optional<uint16_t> getBenefit() const { return benefit; }
1052 
1053  /// Return if this pattern has bounded rewrite recursion.
1054  bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; }
1055 
1056  /// Return the body of this pattern.
1057  const CompoundStmt *getBody() const { return patternBody; }
1058 
1059  /// Return the root rewrite statement of this pattern.
1061  return cast<OpRewriteStmt>(patternBody->getChildren().back());
1062  }
1063 
1064 private:
1065  PatternDecl(SMRange loc, const Name *name, std::optional<uint16_t> benefit,
1066  bool hasBoundedRecursion, const CompoundStmt *body)
1067  : Base(loc, name), benefit(benefit),
1068  hasBoundedRecursion(hasBoundedRecursion), patternBody(body) {}
1069 
1070  /// The benefit of the pattern if it was explicitly specified, std::nullopt
1071  /// otherwise.
1072  std::optional<uint16_t> benefit;
1073 
1074  /// If the pattern has properly bounded rewrite recursion or not.
1075  bool hasBoundedRecursion;
1076 
1077  /// The compound statement representing the body of the pattern.
1078  const CompoundStmt *patternBody;
1079 };
1080 
1081 //===----------------------------------------------------------------------===//
1082 // UserRewriteDecl
1083 //===----------------------------------------------------------------------===//
1084 
1085 /// This decl represents a user defined rewrite. This is either:
1086 /// * an imported native rewrite
1087 /// - Similar to an external function declaration. This is a native
1088 /// rewrite defined externally, and imported into PDLL via a declaration.
1089 /// * a native rewrite defined in PDLL
1090 /// - This is a native rewrite, i.e. a rewrite whose implementation is
1091 /// defined in C++ (or potentially some other non-PDLL language). The
1092 /// implementation of this rewrite is specified as a string code block
1093 /// in PDLL.
1094 /// * a PDLL rewrite
1095 /// - This is a rewrite which is defined using only PDLL constructs.
1096 class UserRewriteDecl final
1097  : public Node::NodeBase<UserRewriteDecl, Decl>,
1098  llvm::TrailingObjects<UserRewriteDecl, VariableDecl *> {
1099 public:
1100  /// Create a native rewrite with the given optional code block.
1101  static UserRewriteDecl *createNative(Context &ctx, const Name &name,
1102  ArrayRef<VariableDecl *> inputs,
1103  ArrayRef<VariableDecl *> results,
1104  std::optional<StringRef> codeBlock,
1105  Type resultType) {
1106  return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/nullptr,
1107  resultType);
1108  }
1109 
1110  /// Create a PDLL rewrite with the given body.
1111  static UserRewriteDecl *createPDLL(Context &ctx, const Name &name,
1112  ArrayRef<VariableDecl *> inputs,
1113  ArrayRef<VariableDecl *> results,
1114  const CompoundStmt *body,
1115  Type resultType) {
1116  return createImpl(ctx, name, inputs, results, /*codeBlock=*/std::nullopt,
1117  body, resultType);
1118  }
1119 
1120  /// Return the name of the rewrite.
1121  const Name &getName() const { return *Decl::getName(); }
1122 
1123  /// Return the input arguments of this rewrite.
1125  return getTrailingObjects(numInputs);
1126  }
1128  return getTrailingObjects(numInputs);
1129  }
1130 
1131  /// Return the explicit results of the rewrite declaration. May be empty,
1132  /// even if the rewrite has results (e.g. in the case of inferred results).
1134  return {getTrailingObjects() + numInputs, numResults};
1135  }
1137  return const_cast<UserRewriteDecl *>(this)->getResults();
1138  }
1139 
1140  /// Return the optional code block of this rewrite, if this is a native
1141  /// rewrite with a provided implementation.
1142  std::optional<StringRef> getCodeBlock() const { return codeBlock; }
1143 
1144  /// Return the body of this rewrite if this rewrite is a PDLL rewrite,
1145  /// otherwise returns nullptr.
1146  const CompoundStmt *getBody() const { return rewriteBody; }
1147 
1148  /// Return the result type of this rewrite.
1149  Type getResultType() const { return resultType; }
1150 
1151  /// Returns true if this rewrite is external.
1152  bool isExternal() const { return !rewriteBody && !codeBlock; }
1153 
1154 private:
1155  /// Create either a PDLL rewrite or a native rewrite with the given
1156  /// components.
1157  static UserRewriteDecl *createImpl(Context &ctx, const Name &name,
1158  ArrayRef<VariableDecl *> inputs,
1159  ArrayRef<VariableDecl *> results,
1160  std::optional<StringRef> codeBlock,
1161  const CompoundStmt *body, Type resultType);
1162 
1163  UserRewriteDecl(const Name &name, unsigned numInputs, unsigned numResults,
1164  std::optional<StringRef> codeBlock, const CompoundStmt *body,
1165  Type resultType)
1166  : Base(name.getLoc(), &name), numInputs(numInputs),
1167  numResults(numResults), codeBlock(codeBlock), rewriteBody(body),
1168  resultType(resultType) {}
1169 
1170  /// The number of inputs to this rewrite.
1171  unsigned numInputs;
1172 
1173  /// The number of explicit results to this rewrite.
1174  unsigned numResults;
1175 
1176  /// The optional code block of this rewrite.
1177  std::optional<StringRef> codeBlock;
1178 
1179  /// The optional body of this rewrite.
1180  const CompoundStmt *rewriteBody;
1181 
1182  /// The result type of the rewrite.
1183  Type resultType;
1184 
1185  /// Allow access to various internals.
1186  friend llvm::TrailingObjects<UserRewriteDecl, VariableDecl *>;
1187 };
1188 
1189 //===----------------------------------------------------------------------===//
1190 // CallableDecl
1191 //===----------------------------------------------------------------------===//
1192 
1193 /// This decl represents a shared interface for all callable decls.
1194 class CallableDecl : public Decl {
1195 public:
1196  /// Return the callable type of this decl.
1197  StringRef getCallableType() const {
1198  if (isa<UserConstraintDecl>(this))
1199  return "constraint";
1200  assert(isa<UserRewriteDecl>(this) && "unknown callable type");
1201  return "rewrite";
1202  }
1203 
1204  /// Return the inputs of this decl.
1206  if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1207  return cst->getInputs();
1208  return cast<UserRewriteDecl>(this)->getInputs();
1209  }
1210 
1211  /// Return the result type of this decl.
1213  if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1214  return cst->getResultType();
1215  return cast<UserRewriteDecl>(this)->getResultType();
1216  }
1217 
1218  /// Return the explicit results of the declaration. Note that these may be
1219  /// empty, even if the callable has results (e.g. in the case of inferred
1220  /// results).
1222  if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1223  return cst->getResults();
1224  return cast<UserRewriteDecl>(this)->getResults();
1225  }
1226 
1227  /// Return the optional code block of this callable, if this is a native
1228  /// callable with a provided implementation.
1229  std::optional<StringRef> getCodeBlock() const {
1230  if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1231  return cst->getCodeBlock();
1232  return cast<UserRewriteDecl>(this)->getCodeBlock();
1233  }
1234 
1235  /// Support LLVM type casting facilities.
1236  static bool classof(const Node *decl) {
1237  return isa<UserConstraintDecl, UserRewriteDecl>(decl);
1238  }
1239 };
1240 
1241 //===----------------------------------------------------------------------===//
1242 // VariableDecl
1243 //===----------------------------------------------------------------------===//
1244 
1245 /// This Decl represents the definition of a PDLL variable.
1246 class VariableDecl final
1247  : public Node::NodeBase<VariableDecl, Decl>,
1248  private llvm::TrailingObjects<VariableDecl, ConstraintRef> {
1249 public:
1250  static VariableDecl *create(Context &ctx, const Name &name, Type type,
1251  Expr *initExpr,
1252  ArrayRef<ConstraintRef> constraints);
1253 
1254  /// Return the constraints of this variable.
1256  return getTrailingObjects(numConstraints);
1257  }
1259  return getTrailingObjects(numConstraints);
1260  }
1261 
1262  /// Return the initializer expression of this statement, or nullptr if there
1263  /// was no initializer.
1264  Expr *getInitExpr() const { return initExpr; }
1265 
1266  /// Return the name of the decl.
1267  const Name &getName() const { return *Decl::getName(); }
1268 
1269  /// Return the type of the decl.
1270  Type getType() const { return type; }
1271 
1272 private:
1273  VariableDecl(const Name &name, Type type, Expr *initExpr,
1274  unsigned numConstraints)
1275  : Base(name.getLoc(), &name), type(type), initExpr(initExpr),
1276  numConstraints(numConstraints) {}
1277 
1278  /// The type of the variable.
1279  Type type;
1280 
1281  /// The optional initializer expression of this statement.
1282  Expr *initExpr;
1283 
1284  /// The number of constraints attached to this variable.
1285  unsigned numConstraints;
1286 
1287  /// Allow access to various internals.
1288  friend llvm::TrailingObjects<VariableDecl, ConstraintRef>;
1289 };
1290 
1291 //===----------------------------------------------------------------------===//
1292 // Module
1293 //===----------------------------------------------------------------------===//
1294 
1295 /// This class represents a top-level AST module.
1296 class Module final : public Node::NodeBase<Module, Node>,
1297  private llvm::TrailingObjects<Module, Decl *> {
1298 public:
1299  static Module *create(Context &ctx, SMLoc loc, ArrayRef<Decl *> children);
1300 
1301  /// Return the children of this module.
1303  return getTrailingObjects(numChildren);
1304  }
1306  return getTrailingObjects(numChildren);
1307  }
1308 
1309 private:
1310  Module(SMLoc loc, unsigned numChildren)
1311  : Base(SMRange{loc, loc}), numChildren(numChildren) {}
1312 
1313  /// The number of decls held by this module.
1314  unsigned numChildren;
1315 
1316  /// Allow access to various internals.
1317  friend llvm::TrailingObjects<Module, Decl *>;
1318 };
1319 
1320 //===----------------------------------------------------------------------===//
1322 // Deferred Method Definitions
1322 //===----------------------------------------------------------------------===//
1323 
1324 inline bool Decl::classof(const Node *node) {
1326  UserRewriteDecl, VariableDecl>(node);
1327 }
1328 
1329 inline bool ConstraintDecl::classof(const Node *node) {
1330  return isa<CoreConstraintDecl, UserConstraintDecl>(node); // Core or user-defined constraints.
1331 }
1332 
1333 inline bool CoreConstraintDecl::classof(const Node *node) {
1336  ValueRangeConstraintDecl>(node);
1337 }
1338 
1339 inline bool Expr::classof(const Node *node) {
1342 }
1343 
1344 inline bool OpRewriteStmt::classof(const Node *node) {
1345  return isa<EraseStmt, ReplaceStmt, RewriteStmt>(node); // All root-operation rewrite statements.
1346 }
1347 
1348 inline bool Stmt::classof(const Node *node) {
1349  return isa<CompoundStmt, LetStmt, OpRewriteStmt, Expr>(node); // Expressions count as statements too.
1350 }
1351 
1352 } // namespace ast
1353 } // namespace pdll
1354 } // namespace mlir
1355 
1356 #endif // MLIR_TOOLS_PDLL_AST_NODES_H_
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of MemberAccessExpr that references all results of an operation.
Definition: Nodes.h:484
static StringRef getMemberName()
Return the member name used for the "all-results" access.
Definition: Nodes.h:487
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
Definition: Nodes.h:489
static bool classof(const Node *node)
Provide type casting support.
Definition: Nodes.h:496
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
Definition: Nodes.h:750
Expr * typeExpr
An optional type that the attribute is constrained to.
Definition: Nodes.h:764
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:385
Expr * getTypeExpr()
Return the optional type the attribute is constrained to.
Definition: Nodes.h:756
const Expr * getTypeExpr() const
Definition: Nodes.h:757
AttrConstraintDecl(SMRange loc, Expr *typeExpr)
Definition: Nodes.h:760
This expression represents a literal MLIR Attribute, and contains the textual assembly format of that...
Definition: Nodes.h:370
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:259
StringRef getValue() const
Get the raw value of this expression.
Definition: Nodes.h:376
This class represents a PDLL type that corresponds to an mlir::Attribute.
Definition: Types.h:107
This expression represents a call to a decl, such as a UserConstraintDecl/UserRewriteDecl.
Definition: Nodes.h:393
Expr * getCallableExpr() const
Return the callable of this call.
Definition: Nodes.h:400
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
Definition: Nodes.h:403
ArrayRef< Expr * > getArguments() const
Definition: Nodes.h:404
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr * > arguments, Type resultType, bool isNegated=false)
Definition: Nodes.cpp:269
bool getIsNegated() const
Returns whether the result of this call is to be negated.
Definition: Nodes.h:407
This decl represents a shared interface for all callable decls.
Definition: Nodes.h:1194
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this callable, if this is a native callable with a provided impleme...
Definition: Nodes.h:1229
Type getResultType() const
Return the result type of this decl.
Definition: Nodes.h:1212
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
Definition: Nodes.h:1205
StringRef getCallableType() const
Return the callable type of this decl.
Definition: Nodes.h:1197
ArrayRef< VariableDecl * > getResults() const
Return the explicit results of the declaration.
Definition: Nodes.h:1221
static bool classof(const Node *decl)
Support LLVM type casting facilities.
Definition: Nodes.h:1236
This statement represents a compound statement, which contains a collection of other statements.
Definition: Nodes.h:179
ArrayRef< Stmt * > getChildren() const
Definition: Nodes.h:188
ArrayRef< Stmt * >::iterator begin() const
Definition: Nodes.h:191
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
Definition: Nodes.h:185
ArrayRef< Stmt * >::iterator end() const
Definition: Nodes.h:192
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt * > children)
Definition: Nodes.cpp:192
This class represents the base of all AST Constraint decls.
Definition: Nodes.h:704
static bool classof(const Node *node)
Provide type casting support.
Definition: Nodes.h:1329
ConstraintDecl(TypeID typeID, SMRange loc, const Name *name=nullptr)
Definition: Nodes.h:710
This class represents the main context of the PDLL AST.
Definition: Context.h:25
This class represents the base of all "core" constraints.
Definition: Nodes.h:733
static bool classof(const Node *node)
Provide type casting support.
Definition: Nodes.h:1333
CoreConstraintDecl(TypeID typeID, SMRange loc, const Name *name=nullptr)
Definition: Nodes.h:739
This expression represents a reference to a Decl node.
Definition: Nodes.h:433
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
Definition: Nodes.cpp:285
Decl * getDecl() const
Get the decl referenced by this expression.
Definition: Nodes.h:438
This class represents a scope for named AST decls.
Definition: Nodes.h:64
auto getDecls() const
Return all of the decls within this scope.
Definition: Nodes.h:74
const Decl * lookup(StringRef name) const
Definition: Nodes.h:86
const DeclScope * getParentScope() const
Definition: Nodes.h:71
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
Definition: Nodes.h:70
T * lookup(StringRef name)
Definition: Nodes.h:83
Decl * lookup(StringRef name)
Lookup a decl with the given name starting from this scope.
Definition: Nodes.cpp:182
void add(Decl *decl)
Add a new decl to the scope.
Definition: Nodes.cpp:175
DeclScope(DeclScope *parent=nullptr)
Create a new scope with an optional parent scope.
Definition: Nodes.h:67
const T * lookup(StringRef name) const
Definition: Nodes.h:90
This class represents the base Decl node.
Definition: Nodes.h:669
std::optional< StringRef > getDocComment() const
Return the documentation comment attached to this decl if it has been set.
Definition: Nodes.h:682
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition: Nodes.h:672
Decl(TypeID typeID, SMRange loc, const Name *name=nullptr)
Definition: Nodes.h:685
static bool classof(const Node *node)
Provide type casting support.
Definition: Nodes.h:1324
void setDocComment(Context &ctx, StringRef comment)
Set the documentation comment for this decl.
Definition: Nodes.cpp:377
This statement represents the erase statement in PDLL.
Definition: Nodes.h:255
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
Definition: Nodes.cpp:218
This class represents a base AST Expression node.
Definition: Nodes.h:348
Expr(TypeID typeID, SMRange loc, Type type)
Definition: Nodes.h:357
static bool classof(const Node *node)
Provide type casting support.
Definition: Nodes.h:1339
Type getType() const
Return the type of this expression.
Definition: Nodes.h:351
This statement represents a let statement in PDLL.
Definition: Nodes.h:211
VariableDecl * getVarDecl() const
Return the variable defined by this statement.
Definition: Nodes.h:216
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
Definition: Nodes.cpp:206
This expression represents a named member or field access of a given parent expression.
Definition: Nodes.h:454
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
Definition: Nodes.cpp:295
const Expr * getParentExpr() const
Get the parent expression of this access.
Definition: Nodes.h:461
StringRef getMemberName() const
Return the name of the member being accessed.
Definition: Nodes.h:464
This class represents a top-level AST module.
Definition: Nodes.h:1297
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
Definition: Nodes.h:1302
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl * > children)
Definition: Nodes.cpp:566
ArrayRef< Decl * > getChildren() const
Definition: Nodes.h:1305
This Decl represents a NamedAttribute, and contains a string name and attribute value.
Definition: Nodes.h:998
Expr * getValue() const
Return value of the attribute.
Definition: Nodes.h:1007
const Name & getName() const
Return the name of the attribute.
Definition: Nodes.h:1004
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
Definition: Nodes.cpp:492
This CRTP class provides several utilities when defining new AST nodes.
Definition: Nodes.h:112
static bool classof(const Node *node)
Provide type casting support.
Definition: Nodes.h:117
NodeBase< T, BaseT > Base
Definition: Nodes.h:114
NodeBase(SMRange loc, Args &&...args)
Definition: Nodes.h:123
This class represents a base AST node.
Definition: Nodes.h:108
Node(TypeID typeID, SMRange loc)
Definition: Nodes.h:149
void walk(function_ref< void(const Node *)> walkFn) const
Walk all of the nodes including, and nested under, this node in pre-order.
Definition: Nodes.cpp:167
SMRange getLoc() const
Return the location of this node.
Definition: Nodes.h:131
std::enable_if_t<!std::is_convertible< const Node *, ArgT >::value > walk(WalkFnT &&walkFn) const
Definition: Nodes.h:141
void print(raw_ostream &os) const
Print this node to the given stream.
TypeID getTypeID() const
Return the type identifier of this node.
Definition: Nodes.h:128
The class represents an Operation constraint, and constrains a variable to be an Operation.
Definition: Nodes.h:774
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
Definition: Nodes.cpp:395
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
Definition: Nodes.cpp:404
OpConstraintDecl(SMRange loc, const OpNameDecl *nameDecl)
Definition: Nodes.h:786
const OpNameDecl * getNameDecl() const
Return the declaration of the operation name.
Definition: Nodes.h:783
const OpNameDecl * nameDecl
The operation name of this constraint.
Definition: Nodes.h:790
This Decl represents an OperationName.
Definition: Nodes.h:1022
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
Definition: Nodes.h:1028
static OpNameDecl * create(Context &ctx, const Name &name)
Definition: Nodes.cpp:502
This class represents a base operation rewrite statement.
Definition: Nodes.h:231
static bool classof(const Node *node)
Provide type casting support.
Definition: Nodes.h:1344
OpRewriteStmt(TypeID typeID, SMRange loc, Expr *rootOp)
Definition: Nodes.h:240
Expr * rootOp
The root operation being rewritten.
Definition: Nodes.h:245
Expr * getRootOpExpr() const
Return the root operation of this rewrite.
Definition: Nodes.h:237
This expression represents the structural form of an MLIR Operation.
Definition: Nodes.h:512
ArrayRef< NamedAttributeDecl * > getAttributes() const
Definition: Nodes.h:551
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
Definition: Nodes.h:540
MutableArrayRef< Expr * > getResultTypes() const
Definition: Nodes.h:543
MutableArrayRef< NamedAttributeDecl * > getAttributes()
Return the attributes of this operation.
Definition: Nodes.h:548
const OpNameDecl * getNameDecl() const
Return the declaration of the operation name.
Definition: Nodes.h:525
SMRange getNameLoc() const
Return the location of the name of the operation expression, or an invalid location if there isn't a ...
Definition: Nodes.h:529
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
Definition: Nodes.h:532
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr * > operands, ArrayRef< Expr * > resultTypes, ArrayRef< NamedAttributeDecl * > attributes)
Definition: Nodes.cpp:307
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
Definition: Nodes.cpp:327
ArrayRef< Expr * > getOperands() const
Definition: Nodes.h:535
This Decl represents a single Pattern.
Definition: Nodes.h:1043
const CompoundStmt * getBody() const
Return the body of this pattern.
Definition: Nodes.h:1057
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
Definition: Nodes.cpp:513
const OpRewriteStmt * getRootRewriteStmt() const
Return the root rewrite statement of this pattern.
Definition: Nodes.h:1060
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
Definition: Nodes.h:1051
bool hasBoundedRewriteRecursion() const
Return true if this pattern has bounded rewrite recursion.
Definition: Nodes.h:1054
This expression builds a range from a set of element values (which may be ranges themselves).
Definition: Nodes.h:586
static RangeExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, RangeType type)
Definition: Nodes.cpp:335
MutableArrayRef< Expr * > getElements()
Return the element expressions of this range.
Definition: Nodes.h:592
ArrayRef< Expr * > getElements() const
Definition: Nodes.h:595
RangeType getType() const
Return the range result type of this expression.
Definition: Nodes.h:600
This class represents a PDLL type that corresponds to a range of elements with a given element type.
Definition: Types.h:159
This statement represents the replace statement in PDLL.
Definition: Nodes.h:271
ArrayRef< Expr * > getReplExprs() const
Definition: Nodes.h:280
MutableArrayRef< Expr * > getReplExprs()
Return the replacement values of this statement.
Definition: Nodes.h:277
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr * > replExprs)
Definition: Nodes.cpp:226
This statement represents a return from a "callable" like decl, e.g.
Definition: Nodes.h:324
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
Definition: Nodes.cpp:250
void setResultExpr(Expr *expr)
Set the result expression of this statement.
Definition: Nodes.h:333
Expr * getResultExpr()
Return the result expression of this statement.
Definition: Nodes.h:329
const Expr * getResultExpr() const
Definition: Nodes.h:330
This statement represents an operation rewrite that contains a block of nested rewrite commands.
Definition: Nodes.h:302
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
Definition: Nodes.cpp:240
CompoundStmt * getRewriteBody() const
Return the compound rewrite body.
Definition: Nodes.h:308
This class represents a base AST Statement node.
Definition: Nodes.h:164
static bool classof(const Node *node)
Provide type casting support.
Definition: Nodes.h:1348
This expression builds a tuple from a set of element values.
Definition: Nodes.h:619
TupleType getType() const
Return the tuple result type of this expression.
Definition: Nodes.h:633
MutableArrayRef< Expr * > getElements()
Return the element expressions of this tuple.
Definition: Nodes.h:625
ArrayRef< Expr * > getElements() const
Definition: Nodes.h:628
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, ArrayRef< StringRef > elementNames)
Definition: Nodes.cpp:349
This class represents a PDLL tuple type, i.e.
Definition: Types.h:222
The class represents a Type constraint, and constrains a variable to be a Type.
Definition: Nodes.h:800
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:412
This expression represents a literal MLIR Type, and contains the textual assembly format of that type...
Definition: Nodes.h:648
StringRef getValue() const
Get the raw value of this expression.
Definition: Nodes.h:654
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition: Nodes.cpp:368
The class represents a TypeRange constraint, and constrains a variable to be a TypeRange.
Definition: Nodes.h:815
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
Definition: Nodes.cpp:421
This class represents a PDLL type that corresponds to an mlir::Type.
Definition: Types.h:250
This decl represents a user defined constraint.
Definition: Nodes.h:888
ArrayRef< VariableDecl * > getResults() const
Definition: Nodes.h:930
bool isExternal() const
Returns true if this constraint is external.
Definition: Nodes.h:946
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this constraint.
Definition: Nodes.h:914
std::optional< StringRef > getNativeInputType(unsigned index) const
Return the explicit native type to use for the given input.
Definition: Nodes.cpp:452
const Name & getName() const
Return the name of the constraint.
Definition: Nodes.h:911
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this constraint, if this is a native constraint with a provided imp...
Definition: Nodes.h:936
static UserConstraintDecl * createPDLL(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, const CompoundStmt *body, Type resultType)
Create a PDLL constraint with the given body.
Definition: Nodes.h:901
Type getResultType() const
Return the result type of this constraint.
Definition: Nodes.h:943
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the constraint declaration.
Definition: Nodes.h:927
ArrayRef< VariableDecl * > getInputs() const
Definition: Nodes.h:917
const CompoundStmt * getBody() const
Return the body of this constraint if this constraint is a PDLL constraint, otherwise returns nullptr...
Definition: Nodes.h:940
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
Definition: Nodes.h:892
This decl represents a user defined rewrite.
Definition: Nodes.h:1098
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this rewrite, if this is a native rewrite with a provided implement...
Definition: Nodes.h:1142
Type getResultType() const
Return the result type of this rewrite.
Definition: Nodes.h:1149
const Name & getName() const
Return the name of the rewrite.
Definition: Nodes.h:1121
ArrayRef< VariableDecl * > getResults() const
Definition: Nodes.h:1136
const CompoundStmt * getBody() const
Return the body of this rewrite if this rewrite is a PDLL rewrite, otherwise returns nullptr.
Definition: Nodes.h:1146
static UserRewriteDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType)
Create a native rewrite with the given optional code block.
Definition: Nodes.h:1101
static UserRewriteDecl * createPDLL(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, const CompoundStmt *body, Type resultType)
Create a PDLL rewrite with the given body.
Definition: Nodes.h:1111
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this rewrite.
Definition: Nodes.h:1124
ArrayRef< VariableDecl * > getInputs() const
Definition: Nodes.h:1127
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the rewrite declaration.
Definition: Nodes.h:1133
bool isExternal() const
Returns true if this rewrite is external.
Definition: Nodes.h:1152
The class represents a Value constraint, and constrains a variable to be a Value.
Definition: Nodes.h:830
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
Definition: Nodes.cpp:431
Expr * typeExpr
An optional type that the value is constrained to.
Definition: Nodes.h:843
const Expr * getTypeExpr() const
Definition: Nodes.h:836
ValueConstraintDecl(SMRange loc, Expr *typeExpr)
Definition: Nodes.h:839
Expr * getTypeExpr()
Return the optional type the value is constrained to.
Definition: Nodes.h:835
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
Definition: Nodes.h:853
const Expr * getTypeExpr() const
Definition: Nodes.h:860
Expr * getTypeExpr()
Return the optional type the value range is constrained to.
Definition: Nodes.h:859
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition: Nodes.cpp:442
ValueRangeConstraintDecl(SMRange loc, Expr *typeExpr)
Definition: Nodes.h:863
Expr * typeExpr
An optional type that the value range is constrained to.
Definition: Nodes.h:867
This Decl represents the definition of a PDLL variable.
Definition: Nodes.h:1248
const Name & getName() const
Return the name of the decl.
Definition: Nodes.h:1267
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
Definition: Nodes.h:1264
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
Definition: Nodes.cpp:549
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Definition: Nodes.h:1255
Type getType() const
Return the type of the decl.
Definition: Nodes.h:1270
ArrayRef< ConstraintRef > getConstraints() const
Definition: Nodes.h:1258
This class provides an ODS representation of a specific operation.
Definition: Operation.h:125
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents a reference to a constraint, and contains a constraint and the location of the ...
Definition: Nodes.h:716
ConstraintRef(const ConstraintDecl *constraint, SMRange refLoc)
Definition: Nodes.h:717
const ConstraintDecl * constraint
Definition: Nodes.h:722
ConstraintRef(const ConstraintDecl *constraint)
Definition: Nodes.h:719
This class provides a convenient API for interacting with source names.
Definition: Nodes.h:37
StringRef getName() const
Return the raw string name.
Definition: Nodes.h:41
SMRange getLoc() const
Get the location of this name.
Definition: Nodes.h:44
static const Name & create(Context &ctx, StringRef name, SMRange location)
Definition: Nodes.cpp:33