9#ifndef MLIR_TOOLS_PDLL_AST_NODES_H_ 
   10#define MLIR_TOOLS_PDLL_AST_NODES_H_ 
   14#include "llvm/ADT/StringMap.h" 
   15#include "llvm/ADT/StringRef.h" 
   16#include "llvm/Support/SMLoc.h" 
   17#include "llvm/Support/SourceMgr.h" 
   18#include "llvm/Support/TrailingObjects.h" 
   38  static const Name &
create(
Context &ctx, StringRef name, SMRange location);
 
   41  StringRef 
getName()
 const { 
return name; }
 
   44  SMRange 
getLoc()
 const { 
return location; }
 
   49  Name &operator=(
const Name &) = 
delete;
 
   50  Name(StringRef name, SMRange location) : name(name), location(location) {}
 
 
   74  auto getDecls()
 const { 
return llvm::make_second_range(decls); }
 
   84    return dyn_cast_or_null<T>(
lookup(name));
 
 
   90  const T *
lookup(StringRef name)
 const {
 
   91    return dyn_cast_or_null<T>(
lookup(name));
 
 
   98  llvm::StringMap<Decl *> decls;
 
 
  111  template <
typename T, 
typename BaseT>
 
  122    template <
typename... Args>
 
  124        : BaseT(
TypeID::
get<T>(), loc, std::forward<Args>(args)...) {}
 
 
 
  138  template <
typename WalkFnT, 
typename ArgT = 
typename llvm::function_traits<
 
  139                                  WalkFnT>::template arg_t<0>>
 
  140  std::enable_if_t<!std::is_convertible<const Node *, ArgT>::value>
 
  143      if (
const ArgT *derivedNode = dyn_cast<ArgT>(node))
 
 
  149  Node(
TypeID typeID, SMRange loc) : typeID(typeID), loc(loc) {}
 
 
  178class CompoundStmt final : 
public Node::NodeBase<CompoundStmt, Stmt>,
 
  179                           private llvm::TrailingObjects<CompoundStmt, Stmt *> {
 
  186    return getTrailingObjects(numChildren);
 
 
  189    return getTrailingObjects(numChildren);
 
 
  196      : 
Base(location), numChildren(numChildren) {}
 
  199  unsigned numChildren;
 
  202  friend class llvm::TrailingObjects<CompoundStmt, 
Stmt *>;
 
 
  211class LetStmt final : 
public Node::NodeBase<LetStmt, Stmt> {
 
  222  VariableDecl *varDecl;
 
 
  255class EraseStmt final : 
public Node::NodeBase<EraseStmt, OpRewriteStmt> {
 
 
  270class ReplaceStmt final : 
public Node::NodeBase<ReplaceStmt, OpRewriteStmt>,
 
  271                          private llvm::TrailingObjects<ReplaceStmt, Expr *> {
 
  278    return getTrailingObjects(numReplExprs);
 
 
  281    return getTrailingObjects(numReplExprs);
 
 
  286      : 
Base(loc, 
rootOp), numReplExprs(numReplExprs) {}
 
  289  unsigned numReplExprs;
 
  292  friend class llvm::TrailingObjects<ReplaceStmt, 
Expr *>;
 
 
  302class RewriteStmt final : 
public Node::NodeBase<RewriteStmt, OpRewriteStmt> {
 
  312      : 
Base(loc, 
rootOp), rewriteBody(rewriteBody) {}
 
  315  CompoundStmt *rewriteBody;
 
 
  324class ReturnStmt final : 
public Node::NodeBase<ReturnStmt, Stmt> {
 
  337      : 
Base(loc), resultExpr(resultExpr) {}
 
 
  370class AttributeExpr : 
public Node::NodeBase<AttributeExpr, Expr> {
 
  372  static AttributeExpr *
create(
Context &ctx, SMRange loc, StringRef value);
 
 
  392class CallExpr final : 
public Node::NodeBase<CallExpr, Expr>,
 
  393                       private llvm::TrailingObjects<CallExpr, Expr *> {
 
  397                          bool isNegated = 
false);
 
  412      : 
Base(loc, type), callable(callable), numArgs(numArgs),
 
  413        isNegated(isNegated) {}
 
  422  friend llvm::TrailingObjects<CallExpr, Expr *>;
 
 
  433class DeclRefExpr : 
public Node::NodeBase<DeclRefExpr, Expr> {
 
  442      : 
Base(loc, type), decl(decl) {}
 
 
  454class MemberAccessExpr : 
public Node::NodeBase<MemberAccessExpr, Expr> {
 
  457                                  const Expr *parentExpr, StringRef memberName,
 
  469      : 
Base(loc, type), parentExpr(parentExpr), memberName(memberName) {}
 
  472  const Expr *parentExpr;
 
  475  StringRef memberName;
 
 
  490                                            const Expr *parentExpr, 
Type type) {
 
  491    return cast<AllResultsMemberAccessExpr>(
 
 
  497    const MemberAccessExpr *memAccess = dyn_cast<MemberAccessExpr>(node);
 
 
 
  509class OperationExpr final
 
  510    : 
public Node::NodeBase<OperationExpr, Expr>,
 
  511      private llvm::TrailingObjects<OperationExpr, Expr *,
 
  512                                    NamedAttributeDecl *> {
 
  522  std::optional<StringRef> 
getName() 
const;
 
  533    return getTrailingObjects<Expr *>(numOperands);
 
 
  536    return getTrailingObjects<Expr *>(numOperands);
 
 
  541    return {getTrailingObjects<Expr *>() + numOperands, numResultTypes};
 
 
  549    return getTrailingObjects<NamedAttributeDecl *>(numAttributes);
 
 
  552    return getTrailingObjects<NamedAttributeDecl *>(numAttributes);
 
 
  557                unsigned numOperands, 
unsigned numResultTypes,
 
  558                unsigned numAttributes, SMRange nameLoc)
 
  559      : 
Base(loc, type), nameDecl(nameDecl), numOperands(numOperands),
 
  560        numResultTypes(numResultTypes), numAttributes(numAttributes),
 
  564  const OpNameDecl *nameDecl;
 
  567  unsigned numOperands, numResultTypes, numAttributes;
 
  573  friend llvm::TrailingObjects<OperationExpr, Expr *, NamedAttributeDecl *>;
 
  574  size_t numTrailingObjects(OverloadToken<Expr *>)
 const {
 
  575    return numOperands + numResultTypes;
 
 
  585class RangeExpr final : 
public Node::NodeBase<RangeExpr, Expr>,
 
  586                        private llvm::TrailingObjects<RangeExpr, Expr *> {
 
  593    return getTrailingObjects(numElements);
 
 
  596    return getTrailingObjects(numElements);
 
 
  604      : 
Base(loc, type), numElements(numElements) {}
 
  607  unsigned numElements;
 
  610  friend class llvm::TrailingObjects<RangeExpr, 
Expr *>;
 
 
  618class TupleExpr final : 
public Node::NodeBase<TupleExpr, Expr>,
 
  619                        private llvm::TrailingObjects<TupleExpr, Expr *> {
 
  626    return getTrailingObjects(
getType().size());
 
 
  629    return getTrailingObjects(
getType().size());
 
 
 
  648class TypeExpr : 
public Node::NodeBase<TypeExpr, Expr> {
 
  650  static TypeExpr *
create(
Context &ctx, SMRange loc, StringRef value);
 
 
  686      : 
Node(typeID, loc), name(name) {}
 
 
  695  std::optional<StringRef> docComment;
 
 
  711      : 
Decl(typeID, loc, name) {}
 
 
 
  750    : 
public Node::NodeBase<AttrConstraintDecl, CoreConstraintDecl> {
 
 
  774    : 
public Node::NodeBase<OpConstraintDecl, CoreConstraintDecl> {
 
  780  std::optional<StringRef> 
getName() 
const;
 
 
  800    : 
public Node::NodeBase<TypeConstraintDecl, CoreConstraintDecl> {
 
 
  815    : 
public Node::NodeBase<TypeRangeConstraintDecl, CoreConstraintDecl> {
 
 
  830    : 
public Node::NodeBase<ValueConstraintDecl, CoreConstraintDecl> {
 
 
  853    : 
public Node::NodeBase<ValueRangeConstraintDecl, CoreConstraintDecl> {
 
 
  886class UserConstraintDecl final
 
  887    : 
public Node::NodeBase<UserConstraintDecl, ConstraintDecl>,
 
  888      llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef> {
 
  891  static UserConstraintDecl *
 
  894               std::optional<StringRef> codeBlock, 
Type resultType,
 
  896    return createImpl(ctx, name, inputs, nativeInputTypes, results, codeBlock,
 
  897                      nullptr, resultType);
 
 
  906    return createImpl(ctx, name, inputs, {}, results,
 
  907                      std::nullopt, body, resultType);
 
 
  915    return getTrailingObjects<VariableDecl *>(numInputs);
 
 
  918    return getTrailingObjects<VariableDecl *>(numInputs);
 
 
  928    return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
 
 
  931    return const_cast<UserConstraintDecl *
>(
this)->
getResults();
 
 
  946  bool isExternal()
 const { 
return !constraintBody && !codeBlock; }
 
  955                                        std::optional<StringRef> codeBlock,
 
  960                     bool hasNativeInputTypes, 
unsigned numResults,
 
  961                     std::optional<StringRef> codeBlock,
 
  963      : 
Base(name.
getLoc(), &name), numInputs(numInputs),
 
  964        numResults(numResults), codeBlock(codeBlock), constraintBody(body),
 
  965        resultType(resultType), hasNativeInputTypes(hasNativeInputTypes) {}
 
  974  std::optional<StringRef> codeBlock;
 
  977  const CompoundStmt *constraintBody;
 
  983  bool hasNativeInputTypes;
 
  986  friend llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef>;
 
  987  size_t numTrailingObjects(OverloadToken<VariableDecl *>)
 const {
 
  988    return numInputs + numResults;
 
 
  998class NamedAttributeDecl : 
public Node::NodeBase<NamedAttributeDecl, Decl> {
 
 1011      : 
Base(name.
getLoc(), &name), value(value) {}
 
 
 1022class OpNameDecl : 
public Node::NodeBase<OpNameDecl, Decl> {
 
 1030    return name ? std::optional<StringRef>(name->getName()) : std::nullopt;
 
 
 1035  explicit OpNameDecl(SMRange loc) : 
Base(loc) {}
 
 
 1043class PatternDecl : 
public Node::NodeBase<PatternDecl, Decl> {
 
 1046                             std::optional<uint16_t> benefit,
 
 1047                             bool hasBoundedRecursion,
 
 1061    return cast<OpRewriteStmt>(patternBody->getChildren().back());
 
 
 1065  PatternDecl(SMRange loc, 
const Name *name, std::optional<uint16_t> benefit,
 
 1067      : 
Base(loc, name), benefit(benefit),
 
 1068        hasBoundedRecursion(hasBoundedRecursion), patternBody(body) {}
 
 1072  std::optional<uint16_t> benefit;
 
 1075  bool hasBoundedRecursion;
 
 1078  const CompoundStmt *patternBody;
 
 
 1096class UserRewriteDecl final
 
 1097    : 
public Node::NodeBase<UserRewriteDecl, Decl>,
 
 1098      llvm::TrailingObjects<UserRewriteDecl, VariableDecl *> {
 
 1104                                       std::optional<StringRef> codeBlock,
 
 1106    return createImpl(ctx, name, inputs, results, codeBlock, 
nullptr,
 
 
 1116    return createImpl(ctx, name, inputs, results, std::nullopt,
 
 
 1125    return getTrailingObjects(numInputs);
 
 
 1128    return getTrailingObjects(numInputs);
 
 
 1134    return {getTrailingObjects() + numInputs, numResults};
 
 
 1137    return const_cast<UserRewriteDecl *
>(
this)->
getResults();
 
 
 1160                                     std::optional<StringRef> codeBlock,
 
 1164                  std::optional<StringRef> codeBlock, 
const CompoundStmt *body,
 
 1166      : 
Base(name.
getLoc(), &name), numInputs(numInputs),
 
 1167        numResults(numResults), codeBlock(codeBlock), rewriteBody(body),
 
 1168        resultType(resultType) {}
 
 1174  unsigned numResults;
 
 1177  std::optional<StringRef> codeBlock;
 
 1180  const CompoundStmt *rewriteBody;
 
 1186  friend llvm::TrailingObjects<UserRewriteDecl, VariableDecl *>;
 
 
 1198    if (isa<UserConstraintDecl>(
this))
 
 1199      return "constraint";
 
 1200    assert(isa<UserRewriteDecl>(
this) && 
"unknown callable type");
 
 
 1206    if (
const auto *cst = dyn_cast<UserConstraintDecl>(
this))
 
 1207      return cst->getInputs();
 
 1208    return cast<UserRewriteDecl>(
this)->getInputs();
 
 
 1213    if (
const auto *cst = dyn_cast<UserConstraintDecl>(
this))
 
 1214      return cst->getResultType();
 
 1215    return cast<UserRewriteDecl>(
this)->getResultType();
 
 
 1222    if (
const auto *cst = dyn_cast<UserConstraintDecl>(
this))
 
 1223      return cst->getResults();
 
 1224    return cast<UserRewriteDecl>(
this)->getResults();
 
 
 1230    if (
const auto *cst = dyn_cast<UserConstraintDecl>(
this))
 
 1231      return cst->getCodeBlock();
 
 1232    return cast<UserRewriteDecl>(
this)->getCodeBlock();
 
 
 1237    return isa<UserConstraintDecl, UserRewriteDecl>(decl);
 
 
 
 1246class VariableDecl final
 
 1247    : 
public Node::NodeBase<VariableDecl, Decl>,
 
 1248      private llvm::TrailingObjects<VariableDecl, ConstraintRef> {
 
 1256    return getTrailingObjects(numConstraints);
 
 
 1259    return getTrailingObjects(numConstraints);
 
 
 1274               unsigned numConstraints)
 
 1275      : 
Base(name.
getLoc(), &name), type(type), initExpr(initExpr),
 
 1276        numConstraints(numConstraints) {}
 
 1285  unsigned numConstraints;
 
 1288  friend llvm::TrailingObjects<VariableDecl, ConstraintRef>;
 
 
 1296class Module final : 
public Node::NodeBase<Module, Node>,
 
 1297                     private llvm::TrailingObjects<Module, Decl *> {
 
 1303    return getTrailingObjects(numChildren);
 
 
 1306    return getTrailingObjects(numChildren);
 
 
 1310  Module(SMLoc loc, 
unsigned numChildren)
 
 1311      : 
Base(SMRange{loc, loc}), numChildren(numChildren) {}
 
 1314  unsigned numChildren;
 
 1317  friend llvm::TrailingObjects<Module, Decl *>;
 
 
 1330  return isa<CoreConstraintDecl, UserConstraintDecl>(node);
 
 
 1345  return isa<EraseStmt, ReplaceStmt, RewriteStmt>(node);
 
 
 1349  return isa<CompoundStmt, LetStmt, OpRewriteStmt, Expr>(node);
 
 
This class provides an efficient unique identifier for a specific C++ type.
 
static TypeID get()
Construct a type info object for the given type T.
 
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
 
This class represents an instance of MemberAccessExpr that references all results of an operation.
 
static StringRef getMemberName()
Return the member name used for the "all-results" access.
 
static bool classof(const Node *node)
Provide type casting support.
 
static AllResultsMemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, Type type)
 
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
 
Expr * getTypeExpr()
Return the optional type the attribute is constrained to.
 
Expr * typeExpr
An optional type that the attribute is constrained to.
 
static AttrConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
 
const Expr * getTypeExpr() const
 
AttrConstraintDecl(SMRange loc, Expr *typeExpr)
 
This expression represents a literal MLIR Attribute, and contains the textual assembly format of that...
 
static AttributeExpr * create(Context &ctx, SMRange loc, StringRef value)
 
StringRef getValue() const
Get the raw value of this expression.
 
This class represents a PDLL type that corresponds to an mlir::Attribute.
 
This expression represents a call to a decl, such as a UserConstraintDecl/UserRewriteDecl.
 
ArrayRef< Expr * > getArguments() const
 
Expr * getCallableExpr() const
Return the callable of this call.
 
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
 
static CallExpr * create(Context &ctx, SMRange loc, Expr *callable, ArrayRef< Expr * > arguments, Type resultType, bool isNegated=false)
 
bool getIsNegated() const
Returns whether the result of this call is to be negated.
 
This decl represents a shared interface for all callable decls.
 
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this callable, if this is a native callable with a provided impleme...
 
Type getResultType() const
Return the result type of this decl.
 
ArrayRef< VariableDecl * > getResults() const
Return the explicit results of the declaration.
 
StringRef getCallableType() const
Return the callable type of this decl.
 
static bool classof(const Node *decl)
Support LLVM type casting facilities.
 
ArrayRef< VariableDecl * > getInputs() const
Return the inputs of this decl.
 
This statement represents a compound statement, which contains a collection of other statements.
 
ArrayRef< Stmt * > getChildren() const
 
ArrayRef< Stmt * >::iterator end() const
 
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
 
ArrayRef< Stmt * >::iterator begin() const
 
static CompoundStmt * create(Context &ctx, SMRange location, ArrayRef< Stmt * > children)
 
This class represents the base of all AST Constraint decls.
 
static bool classof(const Node *node)
Provide type casting support.
 
ConstraintDecl(TypeID typeID, SMRange loc, const Name *name=nullptr)
 
This class represents the main context of the PDLL AST.
 
static bool classof(const Node *node)
Provide type casting support.
 
CoreConstraintDecl(TypeID typeID, SMRange loc, const Name *name=nullptr)
 
This expression represents a reference to a Decl node.
 
Decl * getDecl() const
Get the decl referenced by this expression.
 
static DeclRefExpr * create(Context &ctx, SMRange loc, Decl *decl, Type type)
 
This class represents a scope for named AST decls.
 
const DeclScope * getParentScope() const
 
auto getDecls() const
Return all of the decls within this scope.
 
const T * lookup(StringRef name) const
 
const Decl * lookup(StringRef name) const
 
Decl * lookup(StringRef name)
Lookup a decl with the given name starting from this scope.
 
DeclScope(DeclScope *parent=nullptr)
Create a new scope with an optional parent scope.
 
T * lookup(StringRef name)
 
DeclScope * getParentScope()
Return the parent scope of this scope, or nullptr if there is no parent.
 
This class represents the base Decl node.
 
std::optional< StringRef > getDocComment() const
Return the documentation comment attached to this decl if it has been set.
 
Decl(TypeID typeID, SMRange loc, const Name *name=nullptr)
 
static bool classof(const Node *node)
Provide type casting support.
 
void setDocComment(Context &ctx, StringRef comment)
Set the documentation comment for this decl.
 
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
 
static EraseStmt * create(Context &ctx, SMRange loc, Expr *rootOp)
 
This class represents a base AST Expression node.
 
Expr(TypeID typeID, SMRange loc, Type type)
 
static bool classof(const Node *node)
Provide type casting support.
 
Type getType() const
Return the type of this expression.
 
This statement represents a let statement in PDLL.
 
VariableDecl * getVarDecl() const
Return the variable defined by this statement.
 
static LetStmt * create(Context &ctx, SMRange loc, VariableDecl *varDecl)
 
This expression represents a named member or field access of a given parent expression.
 
static MemberAccessExpr * create(Context &ctx, SMRange loc, const Expr *parentExpr, StringRef memberName, Type type)
 
const Expr * getParentExpr() const
Get the parent expression of this access.
 
StringRef getMemberName() const
Return the name of the member being accessed.
 
This class represents a top-level AST module.
 
static Module * create(Context &ctx, SMLoc loc, ArrayRef< Decl * > children)
 
ArrayRef< Decl * > getChildren() const
 
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
 
This Decl represents a NamedAttribute, and contains a string name and attribute value.
 
const Name & getName() const
Return the name of the attribute.
 
static NamedAttributeDecl * create(Context &ctx, const Name &name, Expr *value)
 
Expr * getValue() const
Return value of the attribute.
 
static bool classof(const Node *node)
Provide type casting support.
 
NodeBase< T, BaseT > Base
 
NodeBase(SMRange loc, Args &&...args)
 
This class represents a base AST node.
 
Node(TypeID typeID, SMRange loc)
 
void walk(function_ref< void(const Node *)> walkFn) const
Walk all of the nodes including, and nested under, this node in pre-order.
 
std::enable_if_t<!std::is_convertible< const Node *, ArgT >::value > walk(WalkFnT &&walkFn) const
 
SMRange getLoc() const
Return the location of this node.
 
void print(raw_ostream &os) const
Print this node to the given stream.
 
TypeID getTypeID() const
Return the type identifier of this node.
 
The class represents an Operation constraint, and constrains a variable to be an Operation.
 
static OpConstraintDecl * create(Context &ctx, SMRange loc, const OpNameDecl *nameDecl=nullptr)
 
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
 
OpConstraintDecl(SMRange loc, const OpNameDecl *nameDecl)
 
const OpNameDecl * getNameDecl() const
Return the declaration of the operation name.
 
const OpNameDecl * nameDecl
The operation name of this constraint.
 
This Decl represents an OperationName.
 
std::optional< StringRef > getName() const
Return the name of this operation, or std::nullopt if the name is unknown.
 
static OpNameDecl * create(Context &ctx, const Name &name)
 
This class represents a base operation rewrite statement.
 
static bool classof(const Node *node)
Provide type casting support.
 
OpRewriteStmt(TypeID typeID, SMRange loc, Expr *rootOp)
 
Expr * rootOp
The root operation being rewritten.
 
Expr * getRootOpExpr() const
Return the root operation of this rewrite.
 
This expression represents the structural form of an MLIR Operation.
 
ArrayRef< NamedAttributeDecl * > getAttributes() const
 
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
 
ArrayRef< Expr * > getOperands() const
 
SMRange getNameLoc() const
Return the location of the name of the operation expression, or an invalid location if there isn't a ...
 
MutableArrayRef< NamedAttributeDecl * > getAttributes()
Return the attributes of this operation.
 
const OpNameDecl * getNameDecl() const
Return the declaration of the operation name.
 
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
 
static OperationExpr * create(Context &ctx, SMRange loc, const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef< Expr * > operands, ArrayRef< Expr * > resultTypes, ArrayRef< NamedAttributeDecl * > attributes)
 
MutableArrayRef< Expr * > getResultTypes() const
 
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
 
This Decl represents a single Pattern.
 
const OpRewriteStmt * getRootRewriteStmt() const
Return the root rewrite statement of this pattern.
 
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
 
const CompoundStmt * getBody() const
Return the body of this pattern.
 
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
 
bool hasBoundedRewriteRecursion() const
Return if this pattern has bounded rewrite recursion.
 
This expression builds a range from a set of element values (which may be ranges themselves).
 
static RangeExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, RangeType type)
 
ArrayRef< Expr * > getElements() const
 
RangeType getType() const
Return the range result type of this expression.
 
MutableArrayRef< Expr * > getElements()
Return the element expressions of this range.
 
This class represents a PDLL type that corresponds to a range of elements with a given element type.
 
This statement represents the replace statement in PDLL.
 
MutableArrayRef< Expr * > getReplExprs()
Return the replacement values of this statement.
 
ArrayRef< Expr * > getReplExprs() const
 
static ReplaceStmt * create(Context &ctx, SMRange loc, Expr *rootOp, ArrayRef< Expr * > replExprs)
 
This statement represents a return from a "callable" like decl, e.g.
 
const Expr * getResultExpr() const
 
static ReturnStmt * create(Context &ctx, SMRange loc, Expr *resultExpr)
 
void setResultExpr(Expr *expr)
Set the result expression of this statement.
 
Expr * getResultExpr()
Return the result expression of this statement.
 
This statement represents an operation rewrite that contains a block of nested rewrite commands.
 
CompoundStmt * getRewriteBody() const
Return the compound rewrite body.
 
static RewriteStmt * create(Context &ctx, SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
 
This class represents a base AST Statement node.
 
static bool classof(const Node *node)
Provide type casting support.
 
Node(TypeID typeID, SMRange loc)
 
This expression builds a tuple from a set of element values.
 
TupleType getType() const
Return the tuple result type of this expression.
 
static TupleExpr * create(Context &ctx, SMRange loc, ArrayRef< Expr * > elements, ArrayRef< StringRef > elementNames)
 
ArrayRef< Expr * > getElements() const
 
MutableArrayRef< Expr * > getElements()
Return the element expressions of this tuple.
 
This class represents a PDLL tuple type, i.e.
 
The class represents a Type constraint, and constrains a variable to be a Type.
 
static TypeConstraintDecl * create(Context &ctx, SMRange loc)
 
This expression represents a literal MLIR Type, and contains the textual assembly format of that type...
 
StringRef getValue() const
Get the raw value of this expression.
 
static TypeExpr * create(Context &ctx, SMRange loc, StringRef value)
 
The class represents a TypeRange constraint, and constrains a variable to be a TypeRange.
 
static TypeRangeConstraintDecl * create(Context &ctx, SMRange loc)
 
This class represents a PDLL type that corresponds to an mlir::Type.
 
This decl represents a user defined constraint.
 
static UserConstraintDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType, ArrayRef< StringRef > nativeInputTypes={})
Create a native constraint with the given optional code block.
 
const Name & getName() const
Return the name of the constraint.
 
ArrayRef< VariableDecl * > getInputs() const
 
const CompoundStmt * getBody() const
Return the body of this constraint if this constraint is a PDLL constraint, otherwise returns nullptr...
 
bool isExternal() const
Returns true if this constraint is external.
 
std::optional< StringRef > getNativeInputType(unsigned index) const
Return the explicit native type to use for the given input.
 
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the constraint declaration.
 
ArrayRef< VariableDecl * > getResults() const
 
static UserConstraintDecl * createPDLL(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, const CompoundStmt *body, Type resultType)
Create a PDLL constraint with the given body.
 
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this constraint.
 
Type getResultType() const
Return the result type of this constraint.
 
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this constraint, if this is a native constraint with a provided imp...
 
This decl represents a user defined rewrite.
 
std::optional< StringRef > getCodeBlock() const
Return the optional code block of this rewrite, if this is a native rewrite with a provided implement...
 
const Name & getName() const
Return the name of the rewrite.
 
ArrayRef< VariableDecl * > getResults() const
 
const CompoundStmt * getBody() const
Return the body of this rewrite if this rewrite is a PDLL rewrite, otherwise returns nullptr.
 
static UserRewriteDecl * createNative(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, std::optional< StringRef > codeBlock, Type resultType)
Create a native rewrite with the given optional code block.
 
Type getResultType() const
Return the result type of this rewrite.
 
ArrayRef< VariableDecl * > getInputs() const
 
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the rewrite declaration.
 
bool isExternal() const
Returns true if this rewrite is external.
 
static UserRewriteDecl * createPDLL(Context &ctx, const Name &name, ArrayRef< VariableDecl * > inputs, ArrayRef< VariableDecl * > results, const CompoundStmt *body, Type resultType)
Create a PDLL rewrite with the given body.
 
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this rewrite.
 
The class represents a Value constraint, and constrains a variable to be a Value.
 
static ValueConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr)
 
Expr * typeExpr
An optional type that the value is constrained to.
 
ValueConstraintDecl(SMRange loc, Expr *typeExpr)
 
const Expr * getTypeExpr() const
 
Expr * getTypeExpr()
Return the optional type the value is constrained to.
 
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
 
static ValueRangeConstraintDecl * create(Context &ctx, SMRange loc, Expr *typeExpr=nullptr)
 
const Expr * getTypeExpr() const
 
ValueRangeConstraintDecl(SMRange loc, Expr *typeExpr)
 
Expr * typeExpr
An optional type that the value range is constrained to.
 
Expr * getTypeExpr()
Return the optional type the value range is constrained to.
 
This Decl represents the definition of a PDLL variable.
 
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
 
static VariableDecl * create(Context &ctx, const Name &name, Type type, Expr *initExpr, ArrayRef< ConstraintRef > constraints)
 
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
 
const Name & getName() const
Return the name of the decl.
 
ArrayRef< ConstraintRef > getConstraints() const
 
Type getType() const
Return the type of the decl.
 
This class provides an ODS representation of a specific operation.
 
Include the generated interface declarations.
 
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
 
llvm::function_ref< Fn > function_ref
 
ConstraintRef(const ConstraintDecl *constraint, SMRange refLoc)
 
const ConstraintDecl * constraint
 
ConstraintRef(const ConstraintDecl *constraint)
 
This class provides a convenient API for interacting with source names.
 
StringRef getName() const
Return the raw string name.
 
SMRange getLoc() const
Get the location of this name.
 
static const Name & create(Context &ctx, StringRef name, SMRange location)