MLIR 22.0.0git
Nodes.h
Go to the documentation of this file.
1//===- Nodes.h --------------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_TOOLS_PDLL_AST_NODES_H_
10#define MLIR_TOOLS_PDLL_AST_NODES_H_
11
12#include "mlir/Support/LLVM.h"
14#include "llvm/ADT/StringMap.h"
15#include "llvm/ADT/StringRef.h"
16#include "llvm/Support/SMLoc.h"
17#include "llvm/Support/SourceMgr.h"
18#include "llvm/Support/TrailingObjects.h"
19#include <optional>
20
21namespace mlir {
22namespace pdll {
23namespace ast {
24class Context;
25class Decl;
26class Expr;
28class OpNameDecl;
29class VariableDecl;
30
31//===----------------------------------------------------------------------===//
32// Name
33//===----------------------------------------------------------------------===//
34
35/// This class provides a convenient API for interacting with source names. It
36/// contains a string name as well as the source location for that name.
37struct Name {
38 static const Name &create(Context &ctx, StringRef name, SMRange location);
39
40 /// Return the raw string name.
41 StringRef getName() const { return name; }
42
43 /// Get the location of this name.
44 SMRange getLoc() const { return location; }
45
46private:
47 Name() = delete;
48 Name(const Name &) = delete;
49 Name &operator=(const Name &) = delete;
50 Name(StringRef name, SMRange location) : name(name), location(location) {}
51
52 /// The string name of the decl.
53 StringRef name;
54 /// The location of the decl name.
55 SMRange location;
56};
57
58//===----------------------------------------------------------------------===//
59// DeclScope
60//===----------------------------------------------------------------------===//
61
62/// This class represents a scope for named AST decls. A scope determines the
63/// visibility and lifetime of a named declaration.
64class DeclScope {
65public:
66 /// Create a new scope with an optional parent scope.
67 DeclScope(DeclScope *parent = nullptr) : parent(parent) {}
68
69 /// Return the parent scope of this scope, or nullptr if there is no parent.
70 DeclScope *getParentScope() { return parent; }
71 const DeclScope *getParentScope() const { return parent; }
72
73 /// Return all of the decls within this scope.
74 auto getDecls() const { return llvm::make_second_range(decls); }
75
76 /// Add a new decl to the scope.
77 void add(Decl *decl);
78
79 /// Lookup a decl with the given name starting from this scope. Returns
80 /// nullptr if no decl could be found.
81 Decl *lookup(StringRef name);
82 template <typename T>
83 T *lookup(StringRef name) {
84 return dyn_cast_or_null<T>(lookup(name));
85 }
86 const Decl *lookup(StringRef name) const {
87 return const_cast<DeclScope *>(this)->lookup(name);
88 }
89 template <typename T>
90 const T *lookup(StringRef name) const {
91 return dyn_cast_or_null<T>(lookup(name));
92 }
93
94private:
95 /// The parent scope, or null if this is a top-level scope.
96 DeclScope *parent;
97 /// The decls defined within this scope.
98 llvm::StringMap<Decl *> decls;
99};
100
101//===----------------------------------------------------------------------===//
102// Node
103//===----------------------------------------------------------------------===//
104
105/// This class represents a base AST node. All AST nodes are derived from this
106/// class, and it contains many of the base functionality for interacting with
107/// nodes.
108class Node {
109public:
110 /// This CRTP class provides several utilies when defining new AST nodes.
111 template <typename T, typename BaseT>
112 class NodeBase : public BaseT {
113 public:
115
116 /// Provide type casting support.
117 static bool classof(const Node *node) {
118 return node->getTypeID() == TypeID::get<T>();
119 }
120
121 protected:
122 template <typename... Args>
123 explicit NodeBase(SMRange loc, Args &&...args)
124 : BaseT(TypeID::get<T>(), loc, std::forward<Args>(args)...) {}
125 };
126
127 /// Return the type identifier of this node.
128 TypeID getTypeID() const { return typeID; }
129
130 /// Return the location of this node.
131 SMRange getLoc() const { return loc; }
132
133 /// Print this node to the given stream.
134 void print(raw_ostream &os) const;
135
136 /// Walk all of the nodes including, and nested under, this node in pre-order.
137 void walk(function_ref<void(const Node *)> walkFn) const;
138 template <typename WalkFnT, typename ArgT = typename llvm::function_traits<
139 WalkFnT>::template arg_t<0>>
140 std::enable_if_t<!std::is_convertible<const Node *, ArgT>::value>
141 walk(WalkFnT &&walkFn) const {
142 walk([&](const Node *node) {
143 if (const ArgT *derivedNode = dyn_cast<ArgT>(node))
144 walkFn(derivedNode);
145 });
146 }
147
148protected:
149 Node(TypeID typeID, SMRange loc) : typeID(typeID), loc(loc) {}
150
151private:
152 /// A unique type identifier for this node.
153 TypeID typeID;
154
155 /// The location of this node.
156 SMRange loc;
157};
158
159//===----------------------------------------------------------------------===//
160// Stmt
161//===----------------------------------------------------------------------===//
162
163/// This class represents a base AST Statement node.
164class Stmt : public Node {
165public:
166 using Node::Node;
167
168 /// Provide type casting support.
169 static bool classof(const Node *node);
170};
171
172//===----------------------------------------------------------------------===//
173// CompoundStmt
174//===----------------------------------------------------------------------===//
175
176/// This statement represents a compound statement, which contains a collection
177/// of other statements.
178class CompoundStmt final : public Node::NodeBase<CompoundStmt, Stmt>,
179 private llvm::TrailingObjects<CompoundStmt, Stmt *> {
180public:
181 static CompoundStmt *create(Context &ctx, SMRange location,
182 ArrayRef<Stmt *> children);
183
184 /// Return the children of this compound statement.
186 return getTrailingObjects(numChildren);
187 }
189 return getTrailingObjects(numChildren);
190 }
191 ArrayRef<Stmt *>::iterator begin() const { return getChildren().begin(); }
192 ArrayRef<Stmt *>::iterator end() const { return getChildren().end(); }
193
194private:
195 CompoundStmt(SMRange location, unsigned numChildren)
196 : Base(location), numChildren(numChildren) {}
197
198 /// The number of held children statements.
199 unsigned numChildren;
200
201 // Allow access to various privates.
202 friend class llvm::TrailingObjects<CompoundStmt, Stmt *>;
203};
204
205//===----------------------------------------------------------------------===//
206// LetStmt
207//===----------------------------------------------------------------------===//
208
209/// This statement represents a `let` statement in PDLL. This statement is used
210/// to define variables.
211class LetStmt final : public Node::NodeBase<LetStmt, Stmt> {
212public:
213 static LetStmt *create(Context &ctx, SMRange loc, VariableDecl *varDecl);
214
215 /// Return the variable defined by this statement.
216 VariableDecl *getVarDecl() const { return varDecl; }
217
218private:
219 LetStmt(SMRange loc, VariableDecl *varDecl) : Base(loc), varDecl(varDecl) {}
220
221 /// The variable defined by this statement.
222 VariableDecl *varDecl;
223};
224
225//===----------------------------------------------------------------------===//
226// OpRewriteStmt
227//===----------------------------------------------------------------------===//
228
229/// This class represents a base operation rewrite statement. Operation rewrite
230/// statements perform a set of transformations on a given root operation.
231class OpRewriteStmt : public Stmt {
232public:
233 /// Provide type casting support.
234 static bool classof(const Node *node);
235
236 /// Return the root operation of this rewrite.
237 Expr *getRootOpExpr() const { return rootOp; }
238
239protected:
240 OpRewriteStmt(TypeID typeID, SMRange loc, Expr *rootOp)
241 : Stmt(typeID, loc), rootOp(rootOp) {}
242
243protected:
244 /// The root operation being rewritten.
246};
247
248//===----------------------------------------------------------------------===//
249// EraseStmt
250//===----------------------------------------------------------------------===//
251
252/// This statement represents the `erase` statement in PDLL. This statement
253/// erases the given root operation, corresponding roughly to the
254/// PatternRewriter::eraseOp API.
255class EraseStmt final : public Node::NodeBase<EraseStmt, OpRewriteStmt> {
256public:
257 static EraseStmt *create(Context &ctx, SMRange loc, Expr *rootOp);
258
259private:
260 EraseStmt(SMRange loc, Expr *rootOp) : Base(loc, rootOp) {}
261};
262
263//===----------------------------------------------------------------------===//
264// ReplaceStmt
265//===----------------------------------------------------------------------===//
266
267/// This statement represents the `replace` statement in PDLL. This statement
268/// replace the given root operation with a set of values, corresponding roughly
269/// to the PatternRewriter::replaceOp API.
270class ReplaceStmt final : public Node::NodeBase<ReplaceStmt, OpRewriteStmt>,
271 private llvm::TrailingObjects<ReplaceStmt, Expr *> {
272public:
273 static ReplaceStmt *create(Context &ctx, SMRange loc, Expr *rootOp,
274 ArrayRef<Expr *> replExprs);
275
276 /// Return the replacement values of this statement.
278 return getTrailingObjects(numReplExprs);
279 }
281 return getTrailingObjects(numReplExprs);
282 }
283
284private:
285 ReplaceStmt(SMRange loc, Expr *rootOp, unsigned numReplExprs)
286 : Base(loc, rootOp), numReplExprs(numReplExprs) {}
287
288 /// The number of replacement values within this statement.
289 unsigned numReplExprs;
290
291 /// TrailingObject utilities.
292 friend class llvm::TrailingObjects<ReplaceStmt, Expr *>;
293};
294
295//===----------------------------------------------------------------------===//
296// RewriteStmt
297//===----------------------------------------------------------------------===//
298
299/// This statement represents an operation rewrite that contains a block of
300/// nested rewrite commands. This allows for building more complex operation
301/// rewrites that span across multiple statements, which may be unconnected.
302class RewriteStmt final : public Node::NodeBase<RewriteStmt, OpRewriteStmt> {
303public:
304 static RewriteStmt *create(Context &ctx, SMRange loc, Expr *rootOp,
305 CompoundStmt *rewriteBody);
306
307 /// Return the compound rewrite body.
308 CompoundStmt *getRewriteBody() const { return rewriteBody; }
309
310private:
311 RewriteStmt(SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
312 : Base(loc, rootOp), rewriteBody(rewriteBody) {}
313
314 /// The body of nested rewriters within this statement.
315 CompoundStmt *rewriteBody;
316};
317
318//===----------------------------------------------------------------------===//
319// ReturnStmt
320//===----------------------------------------------------------------------===//
321
322/// This statement represents a return from a "callable" like decl, e.g. a
323/// Constraint or a Rewrite.
324class ReturnStmt final : public Node::NodeBase<ReturnStmt, Stmt> {
325public:
326 static ReturnStmt *create(Context &ctx, SMRange loc, Expr *resultExpr);
327
328 /// Return the result expression of this statement.
329 Expr *getResultExpr() { return resultExpr; }
330 const Expr *getResultExpr() const { return resultExpr; }
331
332 /// Set the result expression of this statement.
333 void setResultExpr(Expr *expr) { resultExpr = expr; }
334
335private:
336 ReturnStmt(SMRange loc, Expr *resultExpr)
337 : Base(loc), resultExpr(resultExpr) {}
338
339 // The result expression of this statement.
340 Expr *resultExpr;
341};
342
343//===----------------------------------------------------------------------===//
344// Expr
345//===----------------------------------------------------------------------===//
346
347/// This class represents a base AST Expression node.
348class Expr : public Stmt {
349public:
350 /// Return the type of this expression.
351 Type getType() const { return type; }
352
353 /// Provide type casting support.
354 static bool classof(const Node *node);
355
356protected:
357 Expr(TypeID typeID, SMRange loc, Type type) : Stmt(typeID, loc), type(type) {}
358
359private:
360 /// The type of this expression.
361 Type type;
362};
363
364//===----------------------------------------------------------------------===//
365// AttributeExpr
366//===----------------------------------------------------------------------===//
367
368/// This expression represents a literal MLIR Attribute, and contains the
369/// textual assembly format of that attribute.
370class AttributeExpr : public Node::NodeBase<AttributeExpr, Expr> {
371public:
372 static AttributeExpr *create(Context &ctx, SMRange loc, StringRef value);
373
374 /// Get the raw value of this expression. This is the textual assembly format
375 /// of the MLIR Attribute.
376 StringRef getValue() const { return value; }
377
378private:
379 AttributeExpr(Context &ctx, SMRange loc, StringRef value)
380 : Base(loc, AttributeType::get(ctx)), value(value) {}
381
382 /// The value referenced by this expression.
383 StringRef value;
384};
385
386//===----------------------------------------------------------------------===//
387// CallExpr
388//===----------------------------------------------------------------------===//
389
390/// This expression represents a call to a decl, such as a
391/// UserConstraintDecl/UserRewriteDecl.
392class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
393 private llvm::TrailingObjects<CallExpr, Expr *> {
394public:
395 static CallExpr *create(Context &ctx, SMRange loc, Expr *callable,
396 ArrayRef<Expr *> arguments, Type resultType,
397 bool isNegated = false);
398
399 /// Return the callable of this call.
400 Expr *getCallableExpr() const { return callable; }
401
402 /// Return the arguments of this call.
403 MutableArrayRef<Expr *> getArguments() { return getTrailingObjects(numArgs); }
404 ArrayRef<Expr *> getArguments() const { return getTrailingObjects(numArgs); }
405
406 /// Returns whether the result of this call is to be negated.
407 bool getIsNegated() const { return isNegated; }
408
409private:
410 CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs,
411 bool isNegated)
412 : Base(loc, type), callable(callable), numArgs(numArgs),
413 isNegated(isNegated) {}
414
415 /// The callable of this call.
416 Expr *callable;
417
418 /// The number of arguments of the call.
419 unsigned numArgs;
420
421 /// TrailingObject utilities.
422 friend llvm::TrailingObjects<CallExpr, Expr *>;
423
424 // Is the result of this call to be negated.
425 bool isNegated;
426};
427
428//===----------------------------------------------------------------------===//
429// DeclRefExpr
430//===----------------------------------------------------------------------===//
431
432/// This expression represents a reference to a Decl node.
433class DeclRefExpr : public Node::NodeBase<DeclRefExpr, Expr> {
434public:
435 static DeclRefExpr *create(Context &ctx, SMRange loc, Decl *decl, Type type);
436
437 /// Get the decl referenced by this expression.
438 Decl *getDecl() const { return decl; }
439
440private:
441 DeclRefExpr(SMRange loc, Decl *decl, Type type)
442 : Base(loc, type), decl(decl) {}
443
444 /// The decl referenced by this expression.
445 Decl *decl;
446};
447
448//===----------------------------------------------------------------------===//
449// MemberAccessExpr
450//===----------------------------------------------------------------------===//
451
452/// This expression represents a named member or field access of a given parent
453/// expression.
454class MemberAccessExpr : public Node::NodeBase<MemberAccessExpr, Expr> {
455public:
456 static MemberAccessExpr *create(Context &ctx, SMRange loc,
457 const Expr *parentExpr, StringRef memberName,
458 Type type);
459
460 /// Get the parent expression of this access.
461 const Expr *getParentExpr() const { return parentExpr; }
462
463 /// Return the name of the member being accessed.
464 StringRef getMemberName() const { return memberName; }
465
466private:
467 MemberAccessExpr(SMRange loc, const Expr *parentExpr, StringRef memberName,
468 Type type)
469 : Base(loc, type), parentExpr(parentExpr), memberName(memberName) {}
470
471 /// The parent expression of this access.
472 const Expr *parentExpr;
473
474 /// The name of the member being accessed from the parent.
475 StringRef memberName;
476};
477
478//===----------------------------------------------------------------------===//
479// AllResultsMemberAccessExpr
480//===----------------------------------------------------------------------===//
481
482/// This class represents an instance of MemberAccessExpr that references all
483/// results of an operation.
484class AllResultsMemberAccessExpr : public MemberAccessExpr {
485public:
486 /// Return the member name used for the "all-results" access.
487 static StringRef getMemberName() { return "$results"; }
488
489 static AllResultsMemberAccessExpr *create(Context &ctx, SMRange loc,
490 const Expr *parentExpr, Type type) {
491 return cast<AllResultsMemberAccessExpr>(
492 MemberAccessExpr::create(ctx, loc, parentExpr, getMemberName(), type));
493 }
494
495 /// Provide type casting support.
496 static bool classof(const Node *node) {
497 const MemberAccessExpr *memAccess = dyn_cast<MemberAccessExpr>(node);
498 return memAccess && memAccess->getMemberName() == getMemberName();
499 }
500};
501
502//===----------------------------------------------------------------------===//
503// OperationExpr
504//===----------------------------------------------------------------------===//
505
506/// This expression represents the structural form of an MLIR Operation. It
507/// represents either an input operation to match, or an operation to create
508/// within a rewrite.
509class OperationExpr final
510 : public Node::NodeBase<OperationExpr, Expr>,
511 private llvm::TrailingObjects<OperationExpr, Expr *,
512 NamedAttributeDecl *> {
513public:
514 static OperationExpr *create(Context &ctx, SMRange loc,
515 const ods::Operation *odsOp,
516 const OpNameDecl *nameDecl,
517 ArrayRef<Expr *> operands,
518 ArrayRef<Expr *> resultTypes,
520
521 /// Return the name of the operation, or std::nullopt if there isn't one.
522 std::optional<StringRef> getName() const;
523
524 /// Return the declaration of the operation name.
525 const OpNameDecl *getNameDecl() const { return nameDecl; }
526
527 /// Return the location of the name of the operation expression, or an invalid
528 /// location if there isn't a name.
529 SMRange getNameLoc() const { return nameLoc; }
530
531 /// Return the operands of this operation.
533 return getTrailingObjects<Expr *>(numOperands);
534 }
536 return getTrailingObjects<Expr *>(numOperands);
537 }
538
539 /// Return the result types of this operation.
541 return {getTrailingObjects<Expr *>() + numOperands, numResultTypes};
542 }
544 return const_cast<OperationExpr *>(this)->getResultTypes();
545 }
546
547 /// Return the attributes of this operation.
549 return getTrailingObjects<NamedAttributeDecl *>(numAttributes);
550 }
552 return getTrailingObjects<NamedAttributeDecl *>(numAttributes);
553 }
554
555private:
556 OperationExpr(SMRange loc, Type type, const OpNameDecl *nameDecl,
557 unsigned numOperands, unsigned numResultTypes,
558 unsigned numAttributes, SMRange nameLoc)
559 : Base(loc, type), nameDecl(nameDecl), numOperands(numOperands),
560 numResultTypes(numResultTypes), numAttributes(numAttributes),
561 nameLoc(nameLoc) {}
562
563 /// The name decl of this expression.
564 const OpNameDecl *nameDecl;
565
566 /// The number of operands, result types, and attributes of the operation.
567 unsigned numOperands, numResultTypes, numAttributes;
568
569 /// The location of the operation name in the expression if it has a name.
570 SMRange nameLoc;
571
572 /// TrailingObject utilities.
573 friend llvm::TrailingObjects<OperationExpr, Expr *, NamedAttributeDecl *>;
574 size_t numTrailingObjects(OverloadToken<Expr *>) const {
575 return numOperands + numResultTypes;
576 }
577};
578
579//===----------------------------------------------------------------------===//
580// RangeExpr
581//===----------------------------------------------------------------------===//
582
583/// This expression builds a range from a set of element values (which may be
584/// ranges themselves).
585class RangeExpr final : public Node::NodeBase<RangeExpr, Expr>,
586 private llvm::TrailingObjects<RangeExpr, Expr *> {
587public:
588 static RangeExpr *create(Context &ctx, SMRange loc, ArrayRef<Expr *> elements,
589 RangeType type);
590
591 /// Return the element expressions of this range.
593 return getTrailingObjects(numElements);
594 }
596 return getTrailingObjects(numElements);
597 }
598
599 /// Return the range result type of this expression.
600 RangeType getType() const { return mlir::cast<RangeType>(Base::getType()); }
601
602private:
603 RangeExpr(SMRange loc, RangeType type, unsigned numElements)
604 : Base(loc, type), numElements(numElements) {}
605
606 /// The number of element values for this range.
607 unsigned numElements;
608
609 /// TrailingObject utilities.
610 friend class llvm::TrailingObjects<RangeExpr, Expr *>;
611};
612
613//===----------------------------------------------------------------------===//
614// TupleExpr
615//===----------------------------------------------------------------------===//
616
617/// This expression builds a tuple from a set of element values.
618class TupleExpr final : public Node::NodeBase<TupleExpr, Expr>,
619 private llvm::TrailingObjects<TupleExpr, Expr *> {
620public:
621 static TupleExpr *create(Context &ctx, SMRange loc, ArrayRef<Expr *> elements,
622 ArrayRef<StringRef> elementNames);
623
624 /// Return the element expressions of this tuple.
626 return getTrailingObjects(getType().size());
627 }
629 return getTrailingObjects(getType().size());
630 }
631
632 /// Return the tuple result type of this expression.
633 TupleType getType() const { return mlir::cast<TupleType>(Base::getType()); }
634
635private:
636 TupleExpr(SMRange loc, TupleType type) : Base(loc, type) {}
637
638 /// TrailingObject utilities.
639 friend class llvm::TrailingObjects<TupleExpr, Expr *>;
640};
641
642//===----------------------------------------------------------------------===//
643// TypeExpr
644//===----------------------------------------------------------------------===//
645
646/// This expression represents a literal MLIR Type, and contains the textual
647/// assembly format of that type.
648class TypeExpr : public Node::NodeBase<TypeExpr, Expr> {
649public:
650 static TypeExpr *create(Context &ctx, SMRange loc, StringRef value);
651
652 /// Get the raw value of this expression. This is the textual assembly format
653 /// of the MLIR Type.
654 StringRef getValue() const { return value; }
655
656private:
657 TypeExpr(Context &ctx, SMRange loc, StringRef value)
658 : Base(loc, TypeType::get(ctx)), value(value) {}
659
660 /// The value referenced by this expression.
661 StringRef value;
662};
663
664//===----------------------------------------------------------------------===//
665// Decl
666//===----------------------------------------------------------------------===//
667
668/// This class represents the base Decl node.
669class Decl : public Node {
670public:
671 /// Return the name of the decl, or nullptr if it doesn't have one.
672 const Name *getName() const { return name; }
673
674 /// Provide type casting support.
675 static bool classof(const Node *node);
676
677 /// Set the documentation comment for this decl.
678 void setDocComment(Context &ctx, StringRef comment);
679
680 /// Return the documentation comment attached to this decl if it has been set.
681 /// Otherwise, returns std::nullopt.
682 std::optional<StringRef> getDocComment() const { return docComment; }
683
684protected:
685 Decl(TypeID typeID, SMRange loc, const Name *name = nullptr)
686 : Node(typeID, loc), name(name) {}
687
688private:
689 /// The name of the decl. This is optional for some decls, such as
690 /// PatternDecl.
691 const Name *name;
692
693 /// The documentation comment attached to this decl. Defaults to std::nullopt
694 /// if the comment is unset/unknown.
695 std::optional<StringRef> docComment;
696};
697
698//===----------------------------------------------------------------------===//
699// ConstraintDecl
700//===----------------------------------------------------------------------===//
701
702/// This class represents the base of all AST Constraint decls. Constraints
703/// apply matcher conditions to, and define the type of PDLL variables.
704class ConstraintDecl : public Decl {
705public:
706 /// Provide type casting support.
707 static bool classof(const Node *node);
708
709protected:
710 ConstraintDecl(TypeID typeID, SMRange loc, const Name *name = nullptr)
711 : Decl(typeID, loc, name) {}
712};
713
714/// This class represents a reference to a constraint, and contains a constraint
715/// and the location of the reference.
725
726//===----------------------------------------------------------------------===//
727// CoreConstraintDecl
728//===----------------------------------------------------------------------===//
729
730/// This class represents the base of all "core" constraints. Core constraints
731/// are those that generally represent a concrete IR construct, such as
732/// `Type`s or `Value`s.
734public:
735 /// Provide type casting support.
736 static bool classof(const Node *node);
737
738protected:
739 CoreConstraintDecl(TypeID typeID, SMRange loc, const Name *name = nullptr)
740 : ConstraintDecl(typeID, loc, name) {}
741};
742
743//===----------------------------------------------------------------------===//
744// AttrConstraintDecl
745//===----------------------------------------------------------------------===//
746
747/// The class represents an Attribute constraint, and constrains a variable to
748/// be an Attribute.
750 : public Node::NodeBase<AttrConstraintDecl, CoreConstraintDecl> {
751public:
752 static AttrConstraintDecl *create(Context &ctx, SMRange loc,
753 Expr *typeExpr = nullptr);
754
755 /// Return the optional type the attribute is constrained to.
756 Expr *getTypeExpr() { return typeExpr; }
757 const Expr *getTypeExpr() const { return typeExpr; }
758
759protected:
761 : Base(loc), typeExpr(typeExpr) {}
762
763 /// An optional type that the attribute is constrained to.
765};
766
767//===----------------------------------------------------------------------===//
768// OpConstraintDecl
769//===----------------------------------------------------------------------===//
770
771/// The class represents an Operation constraint, and constrains a variable to
772/// be an Operation.
774 : public Node::NodeBase<OpConstraintDecl, CoreConstraintDecl> {
775public:
776 static OpConstraintDecl *create(Context &ctx, SMRange loc,
777 const OpNameDecl *nameDecl = nullptr);
778
779 /// Return the name of the operation, or std::nullopt if there isn't one.
780 std::optional<StringRef> getName() const;
781
782 /// Return the declaration of the operation name.
783 const OpNameDecl *getNameDecl() const { return nameDecl; }
784
785protected:
786 explicit OpConstraintDecl(SMRange loc, const OpNameDecl *nameDecl)
787 : Base(loc), nameDecl(nameDecl) {}
788
789 /// The operation name of this constraint.
791};
792
793//===----------------------------------------------------------------------===//
794// TypeConstraintDecl
795//===----------------------------------------------------------------------===//
796
797/// The class represents a Type constraint, and constrains a variable to be a
798/// Type.
800 : public Node::NodeBase<TypeConstraintDecl, CoreConstraintDecl> {
801public:
802 static TypeConstraintDecl *create(Context &ctx, SMRange loc);
803
804protected:
805 using Base::Base;
806};
807
808//===----------------------------------------------------------------------===//
809// TypeRangeConstraintDecl
810//===----------------------------------------------------------------------===//
811
812/// The class represents a TypeRange constraint, and constrains a variable to be
813/// a TypeRange.
815 : public Node::NodeBase<TypeRangeConstraintDecl, CoreConstraintDecl> {
816public:
817 static TypeRangeConstraintDecl *create(Context &ctx, SMRange loc);
818
819protected:
820 using Base::Base;
821};
822
823//===----------------------------------------------------------------------===//
824// ValueConstraintDecl
825//===----------------------------------------------------------------------===//
826
827/// The class represents a Value constraint, and constrains a variable to be a
828/// Value.
830 : public Node::NodeBase<ValueConstraintDecl, CoreConstraintDecl> {
831public:
832 static ValueConstraintDecl *create(Context &ctx, SMRange loc, Expr *typeExpr);
833
834 /// Return the optional type the value is constrained to.
835 Expr *getTypeExpr() { return typeExpr; }
836 const Expr *getTypeExpr() const { return typeExpr; }
837
838protected:
840 : Base(loc), typeExpr(typeExpr) {}
841
842 /// An optional type that the value is constrained to.
844};
845
846//===----------------------------------------------------------------------===//
847// ValueRangeConstraintDecl
848//===----------------------------------------------------------------------===//
849
850/// The class represents a ValueRange constraint, and constrains a variable to
851/// be a ValueRange.
853 : public Node::NodeBase<ValueRangeConstraintDecl, CoreConstraintDecl> {
854public:
855 static ValueRangeConstraintDecl *create(Context &ctx, SMRange loc,
856 Expr *typeExpr = nullptr);
857
858 /// Return the optional type the value range is constrained to.
859 Expr *getTypeExpr() { return typeExpr; }
860 const Expr *getTypeExpr() const { return typeExpr; }
861
862protected:
865
866 /// An optional type that the value range is constrained to.
868};
869
870//===----------------------------------------------------------------------===//
871// UserConstraintDecl
872//===----------------------------------------------------------------------===//
873
874/// This decl represents a user defined constraint. This is either:
875/// * an imported native constraint
876/// - Similar to an external function declaration. This is a native
877/// constraint defined externally, and imported into PDLL via a
878/// declaration.
879/// * a native constraint defined in PDLL
880/// - This is a native constraint, i.e. a constraint whose implementation is
881/// defined in C++(or potentially some other non-PDLL language). The
882/// implementation of this constraint is specified as a string code block
883/// in PDLL.
884/// * a PDLL constraint
885/// - This is a constraint which is defined using only PDLL constructs.
886class UserConstraintDecl final
887 : public Node::NodeBase<UserConstraintDecl, ConstraintDecl>,
888 llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef> {
889public:
890 /// Create a native constraint with the given optional code block.
891 static UserConstraintDecl *
894 std::optional<StringRef> codeBlock, Type resultType,
895 ArrayRef<StringRef> nativeInputTypes = {}) {
896 return createImpl(ctx, name, inputs, nativeInputTypes, results, codeBlock,
897 /*body=*/nullptr, resultType);
898 }
899
900 /// Create a PDLL constraint with the given body.
901 static UserConstraintDecl *createPDLL(Context &ctx, const Name &name,
904 const CompoundStmt *body,
905 Type resultType) {
906 return createImpl(ctx, name, inputs, /*nativeInputTypes=*/{}, results,
907 /*codeBlock=*/std::nullopt, body, resultType);
908 }
909
910 /// Return the name of the constraint.
911 const Name &getName() const { return *Decl::getName(); }
912
913 /// Return the input arguments of this constraint.
915 return getTrailingObjects<VariableDecl *>(numInputs);
916 }
918 return getTrailingObjects<VariableDecl *>(numInputs);
919 }
920
921 /// Return the explicit native type to use for the given input. Returns
922 /// std::nullopt if no explicit type was set.
923 std::optional<StringRef> getNativeInputType(unsigned index) const;
924
925 /// Return the explicit results of the constraint declaration. May be empty,
926 /// even if the constraint has results (e.g. in the case of inferred results).
928 return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
929 }
931 return const_cast<UserConstraintDecl *>(this)->getResults();
932 }
933
934 /// Return the optional code block of this constraint, if this is a native
935 /// constraint with a provided implementation.
936 std::optional<StringRef> getCodeBlock() const { return codeBlock; }
937
938 /// Return the body of this constraint if this constraint is a PDLL
939 /// constraint, otherwise returns nullptr.
940 const CompoundStmt *getBody() const { return constraintBody; }
941
942 /// Return the result type of this constraint.
943 Type getResultType() const { return resultType; }
944
945 /// Returns true if this constraint is external.
946 bool isExternal() const { return !constraintBody && !codeBlock; }
947
948private:
949 /// Create either a PDLL constraint or a native constraint with the given
950 /// components.
951 static UserConstraintDecl *createImpl(Context &ctx, const Name &name,
953 ArrayRef<StringRef> nativeInputTypes,
955 std::optional<StringRef> codeBlock,
956 const CompoundStmt *body,
957 Type resultType);
958
959 UserConstraintDecl(const Name &name, unsigned numInputs,
960 bool hasNativeInputTypes, unsigned numResults,
961 std::optional<StringRef> codeBlock,
962 const CompoundStmt *body, Type resultType)
963 : Base(name.getLoc(), &name), numInputs(numInputs),
964 numResults(numResults), codeBlock(codeBlock), constraintBody(body),
965 resultType(resultType), hasNativeInputTypes(hasNativeInputTypes) {}
966
967 /// The number of inputs to this constraint.
968 unsigned numInputs;
969
970 /// The number of explicit results to this constraint.
971 unsigned numResults;
972
973 /// The optional code block of this constraint.
974 std::optional<StringRef> codeBlock;
975
976 /// The optional body of this constraint.
977 const CompoundStmt *constraintBody;
978
979 /// The result type of the constraint.
980 Type resultType;
981
982 /// Flag indicating if this constraint has explicit native input types.
983 bool hasNativeInputTypes;
984
985 /// Allow access to various internals.
986 friend llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef>;
987 size_t numTrailingObjects(OverloadToken<VariableDecl *>) const {
988 return numInputs + numResults;
989 }
990};
991
992//===----------------------------------------------------------------------===//
993// NamedAttributeDecl
994//===----------------------------------------------------------------------===//
995
996/// This Decl represents a NamedAttribute, and contains a string name and
997/// attribute value.
998class NamedAttributeDecl : public Node::NodeBase<NamedAttributeDecl, Decl> {
999public:
1000 static NamedAttributeDecl *create(Context &ctx, const Name &name,
1001 Expr *value);
1002
1003 /// Return the name of the attribute.
1004 const Name &getName() const { return *Decl::getName(); }
1005
1006 /// Return value of the attribute.
1007 Expr *getValue() const { return value; }
1008
1009private:
1010 NamedAttributeDecl(const Name &name, Expr *value)
1011 : Base(name.getLoc(), &name), value(value) {}
1012
1013 /// The value of the attribute.
1014 Expr *value;
1015};
1016
1017//===----------------------------------------------------------------------===//
1018// OpNameDecl
1019//===----------------------------------------------------------------------===//
1020
1021/// This Decl represents an OperationName.
1022class OpNameDecl : public Node::NodeBase<OpNameDecl, Decl> {
1023public:
1024 static OpNameDecl *create(Context &ctx, const Name &name);
1025 static OpNameDecl *create(Context &ctx, SMRange loc);
1026
1027 /// Return the name of this operation, or std::nullopt if the name is unknown.
1028 std::optional<StringRef> getName() const {
1029 const Name *name = Decl::getName();
1030 return name ? std::optional<StringRef>(name->getName()) : std::nullopt;
1031 }
1032
1033private:
1034 explicit OpNameDecl(const Name &name) : Base(name.getLoc(), &name) {}
1035 explicit OpNameDecl(SMRange loc) : Base(loc) {}
1036};
1037
1038//===----------------------------------------------------------------------===//
1039// PatternDecl
1040//===----------------------------------------------------------------------===//
1041
1042/// This Decl represents a single Pattern.
1043class PatternDecl : public Node::NodeBase<PatternDecl, Decl> {
1044public:
1045 static PatternDecl *create(Context &ctx, SMRange location, const Name *name,
1046 std::optional<uint16_t> benefit,
1047 bool hasBoundedRecursion,
1048 const CompoundStmt *body);
1049
1050 /// Return the benefit of this pattern if specified, or std::nullopt.
1051 std::optional<uint16_t> getBenefit() const { return benefit; }
1052
1053 /// Return if this pattern has bounded rewrite recursion.
1054 bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; }
1055
1056 /// Return the body of this pattern.
1057 const CompoundStmt *getBody() const { return patternBody; }
1058
1059 /// Return the root rewrite statement of this pattern.
1061 return cast<OpRewriteStmt>(patternBody->getChildren().back());
1062 }
1063
1064private:
1065 PatternDecl(SMRange loc, const Name *name, std::optional<uint16_t> benefit,
1066 bool hasBoundedRecursion, const CompoundStmt *body)
1067 : Base(loc, name), benefit(benefit),
1068 hasBoundedRecursion(hasBoundedRecursion), patternBody(body) {}
1069
1070 /// The benefit of the pattern if it was explicitly specified, std::nullopt
1071 /// otherwise.
1072 std::optional<uint16_t> benefit;
1073
1074 /// If the pattern has properly bounded rewrite recursion or not.
1075 bool hasBoundedRecursion;
1076
1077 /// The compound statement representing the body of the pattern.
1078 const CompoundStmt *patternBody;
1079};
1080
1081//===----------------------------------------------------------------------===//
1082// UserRewriteDecl
1083//===----------------------------------------------------------------------===//
1084
1085/// This decl represents a user defined rewrite. This is either:
1086/// * an imported native rewrite
1087/// - Similar to an external function declaration. This is a native
1088/// rewrite defined externally, and imported into PDLL via a declaration.
1089/// * a native rewrite defined in PDLL
1090/// - This is a native rewrite, i.e. a rewrite whose implementation is
1091/// defined in C++(or potentially some other non-PDLL language). The
1092/// implementation of this rewrite is specified as a string code block
1093/// in PDLL.
1094/// * a PDLL rewrite
1095/// - This is a rewrite which is defined using only PDLL constructs.
1096class UserRewriteDecl final
1097 : public Node::NodeBase<UserRewriteDecl, Decl>,
1098 llvm::TrailingObjects<UserRewriteDecl, VariableDecl *> {
1099public:
1100 /// Create a native rewrite with the given optional code block.
1101 static UserRewriteDecl *createNative(Context &ctx, const Name &name,
1104 std::optional<StringRef> codeBlock,
1105 Type resultType) {
1106 return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/nullptr,
1107 resultType);
1108 }
1109
1110 /// Create a PDLL rewrite with the given body.
1111 static UserRewriteDecl *createPDLL(Context &ctx, const Name &name,
1114 const CompoundStmt *body,
1115 Type resultType) {
1116 return createImpl(ctx, name, inputs, results, /*codeBlock=*/std::nullopt,
1117 body, resultType);
1118 }
1119
1120 /// Return the name of the rewrite.
1121 const Name &getName() const { return *Decl::getName(); }
1122
1123 /// Return the input arguments of this rewrite.
1125 return getTrailingObjects(numInputs);
1126 }
1128 return getTrailingObjects(numInputs);
1129 }
1130
1131 /// Return the explicit results of the rewrite declaration. May be empty,
1132 /// even if the rewrite has results (e.g. in the case of inferred results).
1134 return {getTrailingObjects() + numInputs, numResults};
1135 }
1137 return const_cast<UserRewriteDecl *>(this)->getResults();
1138 }
1139
1140 /// Return the optional code block of this rewrite, if this is a native
1141 /// rewrite with a provided implementation.
1142 std::optional<StringRef> getCodeBlock() const { return codeBlock; }
1143
1144 /// Return the body of this rewrite if this rewrite is a PDLL rewrite,
1145 /// otherwise returns nullptr.
1146 const CompoundStmt *getBody() const { return rewriteBody; }
1147
1148 /// Return the result type of this rewrite.
1149 Type getResultType() const { return resultType; }
1150
1151 /// Returns true if this rewrite is external.
1152 bool isExternal() const { return !rewriteBody && !codeBlock; }
1153
1154private:
1155 /// Create either a PDLL rewrite or a native rewrite with the given
1156 /// components.
1157 static UserRewriteDecl *createImpl(Context &ctx, const Name &name,
1160 std::optional<StringRef> codeBlock,
1161 const CompoundStmt *body, Type resultType);
1162
1163 UserRewriteDecl(const Name &name, unsigned numInputs, unsigned numResults,
1164 std::optional<StringRef> codeBlock, const CompoundStmt *body,
1165 Type resultType)
1166 : Base(name.getLoc(), &name), numInputs(numInputs),
1167 numResults(numResults), codeBlock(codeBlock), rewriteBody(body),
1168 resultType(resultType) {}
1169
1170 /// The number of inputs to this rewrite.
1171 unsigned numInputs;
1172
1173 /// The number of explicit results to this rewrite.
1174 unsigned numResults;
1175
1176 /// The optional code block of this rewrite.
1177 std::optional<StringRef> codeBlock;
1178
1179 /// The optional body of this rewrite.
1180 const CompoundStmt *rewriteBody;
1181
1182 /// The result type of the rewrite.
1183 Type resultType;
1184
1185 /// Allow access to various internals.
1186 friend llvm::TrailingObjects<UserRewriteDecl, VariableDecl *>;
1187};
1188
1189//===----------------------------------------------------------------------===//
1190// CallableDecl
1191//===----------------------------------------------------------------------===//
1192
1193/// This decl represents a shared interface for all callable decls.
1194class CallableDecl : public Decl {
1195public:
1196 /// Return the callable type of this decl.
1197 StringRef getCallableType() const {
1198 if (isa<UserConstraintDecl>(this))
1199 return "constraint";
1200 assert(isa<UserRewriteDecl>(this) && "unknown callable type");
1201 return "rewrite";
1202 }
1203
1204 /// Return the inputs of this decl.
1206 if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1207 return cst->getInputs();
1208 return cast<UserRewriteDecl>(this)->getInputs();
1209 }
1210
1211 /// Return the result type of this decl.
1213 if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1214 return cst->getResultType();
1215 return cast<UserRewriteDecl>(this)->getResultType();
1216 }
1217
1218 /// Return the explicit results of the declaration. Note that these may be
1219 /// empty, even if the callable has results (e.g. in the case of inferred
1220 /// results).
1222 if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1223 return cst->getResults();
1224 return cast<UserRewriteDecl>(this)->getResults();
1225 }
1226
1227 /// Return the optional code block of this callable, if this is a native
1228 /// callable with a provided implementation.
1229 std::optional<StringRef> getCodeBlock() const {
1230 if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1231 return cst->getCodeBlock();
1232 return cast<UserRewriteDecl>(this)->getCodeBlock();
1233 }
1234
1235 /// Support LLVM type casting facilities.
1236 static bool classof(const Node *decl) {
1237 return isa<UserConstraintDecl, UserRewriteDecl>(decl);
1238 }
1239};
1240
1241//===----------------------------------------------------------------------===//
1242// VariableDecl
1243//===----------------------------------------------------------------------===//
1244
1245/// This Decl represents the definition of a PDLL variable.
1246class VariableDecl final
1247 : public Node::NodeBase<VariableDecl, Decl>,
1248 private llvm::TrailingObjects<VariableDecl, ConstraintRef> {
1249public:
1250 static VariableDecl *create(Context &ctx, const Name &name, Type type,
1251 Expr *initExpr,
1252 ArrayRef<ConstraintRef> constraints);
1253
1254 /// Return the constraints of this variable.
1256 return getTrailingObjects(numConstraints);
1257 }
1259 return getTrailingObjects(numConstraints);
1260 }
1261
1262 /// Return the initializer expression of this statement, or nullptr if there
1263 /// was no initializer.
1264 Expr *getInitExpr() const { return initExpr; }
1265
1266 /// Return the name of the decl.
1267 const Name &getName() const { return *Decl::getName(); }
1268
1269 /// Return the type of the decl.
1270 Type getType() const { return type; }
1271
1272private:
1273 VariableDecl(const Name &name, Type type, Expr *initExpr,
1274 unsigned numConstraints)
1275 : Base(name.getLoc(), &name), type(type), initExpr(initExpr),
1276 numConstraints(numConstraints) {}
1277
1278 /// The type of the variable.
1279 Type type;
1280
1281 /// The optional initializer expression of this statement.
1282 Expr *initExpr;
1283
1284 /// The number of constraints attached to this variable.
1285 unsigned numConstraints;
1286
1287 /// Allow access to various internals.
1288 friend llvm::TrailingObjects<VariableDecl, ConstraintRef>;
1289};
1290
1291//===----------------------------------------------------------------------===//
1292// Module
1293//===----------------------------------------------------------------------===//
1294
1295/// This class represents a top-level AST module.
1296class Module final : public Node::NodeBase<Module, Node>,
1297 private llvm::TrailingObjects<Module, Decl *> {
1298public:
1299 static Module *create(Context &ctx, SMLoc loc, ArrayRef<Decl *> children);
1300
1301 /// Return the children of this module.
1303 return getTrailingObjects(numChildren);
1304 }
1306 return getTrailingObjects(numChildren);
1307 }
1308
1309private:
1310 Module(SMLoc loc, unsigned numChildren)
1311 : Base(SMRange{loc, loc}), numChildren(numChildren) {}
1312
1313 /// The number of decls held by this module.
1314 unsigned numChildren;
1315
1316 /// Allow access to various internals.
1317 friend llvm::TrailingObjects<Module, Decl *>;
1318};
1319
1320//===----------------------------------------------------------------------===//
1321// Defered Method Definitions
1322//===----------------------------------------------------------------------===//
1323
1324inline bool Decl::classof(const Node *node) {
1327}
1328
1329inline bool ConstraintDecl::classof(const Node *node) {
1330 return isa<CoreConstraintDecl, UserConstraintDecl>(node);
1331}
1332
1338
1339inline bool Expr::classof(const Node *node) {
1342}
1343
1344inline bool OpRewriteStmt::classof(const Node *node) {
1345 return isa<EraseStmt, ReplaceStmt, RewriteStmt>(node);
1346}
1347
1348inline bool Stmt::classof(const Node *node) {
1349 return isa<CompoundStmt, LetStmt, OpRewriteStmt, Expr>(node);
1350}
1351
1352} // namespace ast
1353} // namespace pdll
1354} // namespace mlir
1355
1356#endif // MLIR_TOOLS_PDLL_AST_NODES_H_
#define add(a, b)
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
static TypeID get()
Construct a type info object for the given type T.
Definition TypeID.h:245
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of MemberAccessExpr that references all results of an operation.
Definition Nodes.h:484
static StringRef getMemberName()
Return the member name used for the "all-results" access.
Definition Nodes.h:487
static bool classof(const Node *node)
Provide type casting support.
Definition Nodes.h:496
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
Definition Nodes.h:489
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
Definition Nodes.h:750
Expr * getTypeExpr()
Return the optional type the attribute is constrained to.
Definition Nodes.h:756
Expr * typeExpr
An optional type that the attribute is constrained to.
Definition Nodes.h:764
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition Nodes.cpp:385
const Expr * getTypeExpr() const
Definition Nodes.h:757
AttrConstraintDecl(SMRange loc, Expr *typeExpr)
Definition Nodes.h:760
This expression represents a literal MLIR Attribute, and contains the textual assembly format of that...
Definition Nodes.h:370
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition Nodes.cpp:259
StringRef getValue() const
Get the raw value of this expression.
Definition Nodes.h:376
This class represents a PDLL type that corresponds to an mlir::Attribute.
Definition Types.h:107
This expression represents a call to a decl, such as a UserConstraintDecl/UserRewriteDecl.
Definition Nodes.h:393
ArrayRef< Expr * > getArguments() const
Definition Nodes.h:404
Expr * getCallableExpr() const
Return the callable of this call.
Definition Nodes.h:400
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
Definition Nodes.h:403
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr * > arguments, Type resultType, bool isNegated=false)
Definition Nodes.cpp:269
bool getIsNegated() const
Returns whether the result of this call is to be negated.
Definition Nodes.h:407
This decl represents a shared interface for all callable decls.
Definition Nodes.h:1194
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this callable, if this is a native callable with a provided implementation.
Definition Nodes.h:1229
Type getResultType() const
Return the result type of this decl.
Definition Nodes.h:1212
ArrayRef< VariableDecl * > getResults() const
Return the explicit results of the declaration.
Definition Nodes.h:1221
StringRef getCallableType() const
Return the callable type of this decl.
Definition Nodes.h:1197
static bool classof(const Node *decl)
Support LLVM type casting facilities.
Definition Nodes.h:1236
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
Definition Nodes.h:1205
This statement represents a compound statement, which contains a collection of other statements.
Definition Nodes.h:179
ArrayRef< Stmt * > getChildren() const
Definition Nodes.h:188
ArrayRef< Stmt * >::iterator end() const
Definition Nodes.h:192
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
Definition Nodes.h:185
ArrayRef< Stmt * >::iterator begin() const
Definition Nodes.h:191
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt * > children)
Definition Nodes.cpp:192
This class represents the base of all AST Constraint decls.
Definition Nodes.h:704
static bool classof(const Node *node)
Provide type casting support.
Definition Nodes.h:1329
ConstraintDecl(TypeID typeID, SMRange loc, const Name *name=nullptr)
Definition Nodes.h:710
This class represents the main context of the PDLL AST.
Definition Context.h:25
static bool classof(const Node *node)
Provide type casting support.
Definition Nodes.h:1333
CoreConstraintDecl(TypeID typeID, SMRange loc, const Name *name=nullptr)
Definition Nodes.h:739
This expression represents a reference to a Decl node.
Definition Nodes.h:433
Decl * getDecl() const
Get the decl referenced by this expression.
Definition Nodes.h:438
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
Definition Nodes.cpp:285
This class represents a scope for named AST decls.
Definition Nodes.h:64
const DeclScope * getParentScope() const
Definition Nodes.h:71
auto getDecls() const
Return all of the decls within this scope.
Definition Nodes.h:74
const T * lookup(StringRef name) const
Definition Nodes.h:90
const Decl * lookup(StringRef name) const
Definition Nodes.h:86
Decl * lookup(StringRef name)
Lookup a decl with the given name starting from this scope.
Definition Nodes.cpp:182
DeclScope(DeclScope *parent=nullptr)
Create a new scope with an optional parent scope.
Definition Nodes.h:67
T * lookup(StringRef name)
Definition Nodes.h:83
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
Definition Nodes.h:70
This class represents the base Decl node.
Definition Nodes.h:669
std::optional< StringRef > getDocComment() const
Return the documentation comment attached to this decl if it has been set.
Definition Nodes.h:682
Decl(TypeID typeID, SMRange loc, const Name *name=nullptr)
Definition Nodes.h:685
static bool classof(const Node *node)
Provide type casting support.
Definition Nodes.h:1324
void setDocComment(Context &ctx, StringRef comment)
Set the documentation comment for this decl.
Definition Nodes.cpp:377
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition Nodes.h:672
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
Definition Nodes.cpp:218
This class represents a base AST Expression node.
Definition Nodes.h:348
Expr(TypeID typeID, SMRange loc, Type type)
Definition Nodes.h:357
static bool classof(const Node *node)
Provide type casting support.
Definition Nodes.h:1339
Type getType() const
Return the type of this expression.
Definition Nodes.h:351
This statement represents a let statement in PDLL.
Definition Nodes.h:211
VariableDecl * getVarDecl() const
Return the variable defined by this statement.
Definition Nodes.h:216
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
Definition Nodes.cpp:206
This expression represents a named member or field access of a given parent expression.
Definition Nodes.h:454
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
Definition Nodes.cpp:295
const Expr * getParentExpr() const
Get the parent expression of this access.
Definition Nodes.h:461
StringRef getMemberName() const
Return the name of the member being accessed.
Definition Nodes.h:464
This class represents a top-level AST module.
Definition Nodes.h:1297
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl * > children)
Definition Nodes.cpp:566
ArrayRef< Decl * > getChildren() const
Definition Nodes.h:1305
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
Definition Nodes.h:1302
This Decl represents a NamedAttribute, and contains a string name and attribute value.
Definition Nodes.h:998
const Name & getName() const
Return the name of the attribute.
Definition Nodes.h:1004
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
Definition Nodes.cpp:492
Expr * getValue() const
Return value of the attribute.
Definition Nodes.h:1007
static bool classof(const Node *node)
Provide type casting support.
Definition Nodes.h:117
NodeBase< T, BaseT > Base
Definition Nodes.h:114
NodeBase(SMRange loc, Args &&...args)
Definition Nodes.h:123
This class represents a base AST node.
Definition Nodes.h:108
Node(TypeID typeID, SMRange loc)
Definition Nodes.h:149
void walk(function_ref< void(const Node *)> walkFn) const
Walk all of the nodes including, and nested under, this node in pre-order.
Definition Nodes.cpp:167
std::enable_if_t<!std::is_convertible< const Node *, ArgT >::value > walk(WalkFnT &&walkFn) const
Definition Nodes.h:141
SMRange getLoc() const
Return the location of this node.
Definition Nodes.h:131
void print(raw_ostream &os) const
Print this node to the given stream.
TypeID getTypeID() const
Return the type identifier of this node.
Definition Nodes.h:128
The class represents an Operation constraint, and constrains a variable to be an Operation.
Definition Nodes.h:774
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
Definition Nodes.cpp:395
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
Definition Nodes.cpp:404
OpConstraintDecl(SMRange loc, const OpNameDecl *nameDecl)
Definition Nodes.h:786
const OpNameDecl * getNameDecl() const
Return the declaration of the operation name.
Definition Nodes.h:783
const OpNameDecl * nameDecl
The operation name of this constraint.
Definition Nodes.h:790
This Decl represents an OperationName.
Definition Nodes.h:1022
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
Definition Nodes.h:1028
static OpNameDecl * create(Context &ctx, const Name &name)
Definition Nodes.cpp:502
This class represents a base operation rewrite statement.
Definition Nodes.h:231
static bool classof(const Node *node)
Provide type casting support.
Definition Nodes.h:1344
OpRewriteStmt(TypeID typeID, SMRange loc, Expr *rootOp)
Definition Nodes.h:240
Expr * rootOp
The root operation being rewritten.
Definition Nodes.h:245
Expr * getRootOpExpr() const
Return the root operation of this rewrite.
Definition Nodes.h:237
This expression represents the structural form of an MLIR Operation.
Definition Nodes.h:512
ArrayRef< NamedAttributeDecl * > getAttributes() const
Definition Nodes.h:551
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
Definition Nodes.h:532
ArrayRef< Expr * > getOperands() const
Definition Nodes.h:535
SMRange getNameLoc() const
Return the location of the name of the operation expression, or an invalid location if there isn't a ...
Definition Nodes.h:529
MutableArrayRef< NamedAttributeDecl * > getAttributes()
Return the attributes of this operation.
Definition Nodes.h:548
const OpNameDecl * getNameDecl() const
Return the declaration of the operation name.
Definition Nodes.h:525
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
Definition Nodes.h:540
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr * > operands, ArrayRef< Expr * > resultTypes, ArrayRef< NamedAttributeDecl * > attributes)
Definition Nodes.cpp:307
MutableArrayRef< Expr * > getResultTypes() const
Definition Nodes.h:543
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
Definition Nodes.cpp:327
This Decl represents a single Pattern.
Definition Nodes.h:1043
const OpRewriteStmt * getRootRewriteStmt() const
Return the root rewrite statement of this pattern.
Definition Nodes.h:1060
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
Definition Nodes.cpp:513
const CompoundStmt * getBody() const
Return the body of this pattern.
Definition Nodes.h:1057
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
Definition Nodes.h:1051
bool hasBoundedRewriteRecursion() const
Return if this pattern has bounded rewrite recursion.
Definition Nodes.h:1054
This expression builds a range from a set of element values (which may be ranges themselves).
Definition Nodes.h:586
static RangeExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, RangeType type)
Definition Nodes.cpp:335
ArrayRef< Expr * > getElements() const
Definition Nodes.h:595
RangeType getType() const
Return the range result type of this expression.
Definition Nodes.h:600
MutableArrayRef< Expr * > getElements()
Return the element expressions of this range.
Definition Nodes.h:592
This class represents a PDLL type that corresponds to a range of elements with a given element type.
Definition Types.h:159
This statement represents the replace statement in PDLL.
Definition Nodes.h:271
MutableArrayRef< Expr * > getReplExprs()
Return the replacement values of this statement.
Definition Nodes.h:277
ArrayRef< Expr * > getReplExprs() const
Definition Nodes.h:280
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr * > replExprs)
Definition Nodes.cpp:226
This statement represents a return from a "callable" like decl, e.g.
Definition Nodes.h:324
const Expr * getResultExpr() const
Definition Nodes.h:330
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
Definition Nodes.cpp:250
void setResultExpr(Expr *expr)
Set the result expression of this statement.
Definition Nodes.h:333
Expr * getResultExpr()
Return the result expression of this statement.
Definition Nodes.h:329
This statement represents an operation rewrite that contains a block of nested rewrite commands.
Definition Nodes.h:302
CompoundStmt * getRewriteBody() const
Return the compound rewrite body.
Definition Nodes.h:308
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
Definition Nodes.cpp:240
This class represents a base AST Statement node.
Definition Nodes.h:164
static bool classof(const Node *node)
Provide type casting support.
Definition Nodes.h:1348
Node(TypeID typeID, SMRange loc)
Definition Nodes.h:149
This expression builds a tuple from a set of element values.
Definition Nodes.h:619
TupleType getType() const
Return the tuple result type of this expression.
Definition Nodes.h:633
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, ArrayRef< StringRef > elementNames)
Definition Nodes.cpp:349
ArrayRef< Expr * > getElements() const
Definition Nodes.h:628
MutableArrayRef< Expr * > getElements()
Return the element expressions of this tuple.
Definition Nodes.h:625
This class represents a PDLL tuple type, i.e.
Definition Types.h:222
The class represents a Type constraint, and constrains a variable to be a Type.
Definition Nodes.h:800
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
Definition Nodes.cpp:412
This expression represents a literal MLIR Type, and contains the textual assembly format of that type...
Definition Nodes.h:648
StringRef getValue() const
Get the raw value of this expression.
Definition Nodes.h:654
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
Definition Nodes.cpp:368
The class represents a TypeRange constraint, and constrains a variable to be a TypeRange.
Definition Nodes.h:815
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
Definition Nodes.cpp:421
This class represents a PDLL type that corresponds to an mlir::Type.
Definition Types.h:249
This decl represents a user defined constraint.
Definition Nodes.h:888
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
Definition Nodes.h:892
const Name & getName() const
Return the name of the constraint.
Definition Nodes.h:911
ArrayRef< VariableDecl * > getInputs() const
Definition Nodes.h:917
const CompoundStmt * getBody() const
Return the body of this constraint if this constraint is a PDLL constraint, otherwise returns nullptr.
Definition Nodes.h:940
bool isExternal() const
Returns true if this constraint is external.
Definition Nodes.h:946
std::optional< StringRef > getNativeInputType(unsigned index) const
Return the explicit native type to use for the given input.
Definition Nodes.cpp:452
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the constraint declaration.
Definition Nodes.h:927
ArrayRef< VariableDecl * > getResults() const
Definition Nodes.h:930
static UserConstraintDecl * createPDLL(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, const CompoundStmt *body, Type resultType)
Create a PDLL constraint with the given body.
Definition Nodes.h:901
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this constraint.
Definition Nodes.h:914
Type getResultType() const
Return the result type of this constraint.
Definition Nodes.h:943
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this constraint, if this is a native constraint with a provided implementation.
Definition Nodes.h:936
This decl represents a user defined rewrite.
Definition Nodes.h:1098
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this rewrite, if this is a native rewrite with a provided implementation.
Definition Nodes.h:1142
const Name & getName() const
Return the name of the rewrite.
Definition Nodes.h:1121
ArrayRef< VariableDecl * > getResults() const
Definition Nodes.h:1136
const CompoundStmt * getBody() const
Return the body of this rewrite if this rewrite is a PDLL rewrite, otherwise returns nullptr.
Definition Nodes.h:1146
static UserRewriteDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType)
Create a native rewrite with the given optional code block.
Definition Nodes.h:1101
Type getResultType() const
Return the result type of this rewrite.
Definition Nodes.h:1149
ArrayRef< VariableDecl * > getInputs() const
Definition Nodes.h:1127
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the rewrite declaration.
Definition Nodes.h:1133
bool isExternal() const
Returns true if this rewrite is external.
Definition Nodes.h:1152
static UserRewriteDecl * createPDLL(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, const CompoundStmt *body, Type resultType)
Create a PDLL rewrite with the given body.
Definition Nodes.h:1111
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this rewrite.
Definition Nodes.h:1124
The class represents a Value constraint, and constrains a variable to be a Value.
Definition Nodes.h:830
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
Definition Nodes.cpp:431
Expr * typeExpr
An optional type that the value is constrained to.
Definition Nodes.h:843
ValueConstraintDecl(SMRange loc, Expr *typeExpr)
Definition Nodes.h:839
const Expr * getTypeExpr() const
Definition Nodes.h:836
Expr * getTypeExpr()
Return the optional type the value is constrained to.
Definition Nodes.h:835
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
Definition Nodes.h:853
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
Definition Nodes.cpp:442
ValueRangeConstraintDecl(SMRange loc, Expr *typeExpr)
Definition Nodes.h:863
Expr * typeExpr
An optional type that the value range is constrained to.
Definition Nodes.h:867
Expr * getTypeExpr()
Return the optional type the value range is constrained to.
Definition Nodes.h:859
This Decl represents the definition of a PDLL variable.
Definition Nodes.h:1248
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Definition Nodes.h:1255
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
Definition Nodes.cpp:549
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
Definition Nodes.h:1264
const Name & getName() const
Return the name of the decl.
Definition Nodes.h:1267
ArrayRef< ConstraintRef > getConstraints() const
Definition Nodes.h:1258
Type getType() const
Return the type of the decl.
Definition Nodes.h:1270
This class provides an ODS representation of a specific operation.
Definition Operation.h:125
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
ConstraintRef(const ConstraintDecl *constraint, SMRange refLoc)
Definition Nodes.h:717
const ConstraintDecl * constraint
Definition Nodes.h:722
ConstraintRef(const ConstraintDecl *constraint)
Definition Nodes.h:719
This class provides a convenient API for interacting with source names.
Definition Nodes.h:37
StringRef getName() const
Return the raw string name.
Definition Nodes.h:41
SMRange getLoc() const
Get the location of this name.
Definition Nodes.h:44
static const Name & create(Context &ctx, StringRef name, SMRange location)
Definition Nodes.cpp:33