32 #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
33 #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
41 namespace pdl_to_pdl_interp {
42 namespace Predicates {
82 template <
typename ConcreteT,
typename BaseT,
typename Key,
89 template <
typename KeyT>
94 template <
typename... Args>
96 return uniquer.
get<ConcreteT>({}, std::forward<Args>(args)...);
100 template <
typename KeyT>
103 return new (alloc.
allocate<ConcreteT>()) ConcreteT(std::forward<KeyT>(
key));
118 template <
typename ConcreteT,
typename BaseT, Predicates::Kind Kind>
126 return uniquer.
get<ConcreteT>();
135 struct OperationPosition;
173 std::pair<OperationPosition *, StringAttr>,
174 Predicates::AttributePos> {
186 :
public PredicateBase<AttributeLiteralPosition, Position, Attribute,
187 Predicates::AttributeLiteralPos> {
196 std::pair<Position *, unsigned>,
197 Predicates::ForEachPos> {
211 std::pair<OperationPosition *, unsigned>,
212 Predicates::OperandPos> {
225 OperandGroupPosition, Position,
226 std::tuple<OperationPosition *, std::optional<unsigned>, bool>,
227 Predicates::OperandGroupPos> {
238 return std::get<1>(
key);
252 std::pair<Position *, unsigned>,
253 Predicates::OperationPos> {
293 std::pair<ConstraintQuestion *, unsigned>,
294 Predicates::ConstraintResultPos> {
311 std::pair<OperationPosition *, unsigned>,
312 Predicates::ResultPos> {
325 ResultGroupPosition, Position,
326 std::tuple<OperationPosition *, std::optional<unsigned>, bool>,
327 Predicates::ResultGroupPos> {
340 return std::get<1>(
key);
354 Predicates::TypePos> {
358 "expected parent to be an attribute, operand, or result");
369 :
public PredicateBase<TypeLiteralPosition, Position, Attribute,
370 Predicates::TypeLiteralPos> {
381 :
public PredicateBase<UsersPosition, Position, std::pair<Position *, bool>,
382 Predicates::UsersPos> {
427 :
public PredicateBase<AttributeAnswer, Qualifier, Attribute,
428 Predicates::AttributeAnswer> {
434 :
public PredicateBase<OperationNameAnswer, Qualifier, OperationName,
435 Predicates::OperationNameAnswer> {
441 :
PredicateBase<TrueAnswer, Qualifier, void, Predicates::TrueAnswer> {
447 :
PredicateBase<FalseAnswer, Qualifier, void, Predicates::FalseAnswer> {
454 Predicates::TypeAnswer> {
461 Predicates::UnsignedAnswer> {
471 Predicates::AttributeQuestion> {};
477 ConstraintQuestion, Qualifier,
478 std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
479 Predicates::ConstraintQuestion> {
511 :
public PredicateBase<EqualToQuestion, Qualifier, Position *,
512 Predicates::EqualToQuestion> {
519 Predicates::IsNotNullQuestion> {};
523 :
public PredicateBase<OperandCountQuestion, Qualifier, void,
524 Predicates::OperandCountQuestion> {};
526 :
public PredicateBase<OperandCountAtLeastQuestion, Qualifier, void,
527 Predicates::OperandCountAtLeastQuestion> {};
531 :
public PredicateBase<OperationNameQuestion, Qualifier, void,
532 Predicates::OperationNameQuestion> {};
537 Predicates::ResultCountQuestion> {};
539 :
public PredicateBase<ResultCountAtLeastQuestion, Qualifier, void,
540 Predicates::ResultCountAtLeastQuestion> {};
544 Predicates::TypeQuestion> {};
556 registerParametricStorageType<AttributePosition>();
557 registerParametricStorageType<AttributeLiteralPosition>();
558 registerParametricStorageType<ConstraintPosition>();
559 registerParametricStorageType<ForEachPosition>();
560 registerParametricStorageType<OperandPosition>();
561 registerParametricStorageType<OperandGroupPosition>();
562 registerParametricStorageType<OperationPosition>();
563 registerParametricStorageType<ResultPosition>();
564 registerParametricStorageType<ResultGroupPosition>();
565 registerParametricStorageType<TypePosition>();
566 registerParametricStorageType<TypeLiteralPosition>();
567 registerParametricStorageType<UsersPosition>();
570 registerParametricStorageType<AttributeAnswer>();
571 registerParametricStorageType<OperationNameAnswer>();
572 registerParametricStorageType<TypeAnswer>();
573 registerParametricStorageType<UnsignedAnswer>();
574 registerSingletonStorageType<FalseAnswer>();
575 registerSingletonStorageType<TrueAnswer>();
578 registerParametricStorageType<ConstraintQuestion>();
579 registerParametricStorageType<EqualToQuestion>();
580 registerSingletonStorageType<AttributeQuestion>();
581 registerSingletonStorageType<IsNotNullQuestion>();
582 registerSingletonStorageType<OperandCountQuestion>();
583 registerSingletonStorageType<OperandCountAtLeastQuestion>();
584 registerSingletonStorageType<OperationNameQuestion>();
585 registerSingletonStorageType<ResultCountQuestion>();
586 registerSingletonStorageType<ResultCountAtLeastQuestion>();
587 registerSingletonStorageType<TypeQuestion>();
599 : uniquer(uniquer), ctx(ctx) {}
610 assert((isa<OperandPosition, OperandGroupPosition>(p)) &&
611 "expected operand position");
617 assert((isa<ForEachPosition>(p)) &&
"expected users position");
682 "expected result position");
715 uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
Attributes are known-constant values of operations.
MLIRContext is the top-level object for a collection of MLIR operations.
This class acts as the base storage that all storage classes must derive from.
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
A utility class to get or create instances of "storage classes".
Storage * get(function_ref< void(Storage *)> initFn, TypeID id, Args &&...args)
Gets a uniqued instance of 'Storage'.
A position describes a value on the input IR on which a predicate may be applied, such as an operation or attribute.
Position(Predicates::Kind kind)
unsigned getOperationDepth() const
Returns the depth of the first ancestor operation position.
Position * getParent() const
Returns the parent position. The root operation position has no parent.
Predicates::Kind getKind() const
Returns the kind of this position.
Position * parent
Link to the parent position.
Base storage for simple predicates that only unique with the kind.
static bool classof(const BaseT *pred)
static ConcreteT * get(StorageUniquer &uniquer)
Base class for all predicates, used to allow efficient pointer comparison.
bool operator==(const KeyTy &key) const
Utility methods required by the storage allocator.
static ConcreteT * construct(StorageUniquer::StorageAllocator &alloc, KeyT &&key)
Construct an instance with the given storage allocator.
static ConcreteT * get(StorageUniquer &uniquer, Args &&...args)
Get an instance of this position.
PredicateBase(KeyT &&key)
static bool classof(const BaseT *pred)
const KeyTy & getValue() const
Return the key value of this predicate.
PredicateBase< ConcreteT, BaseT, Key, Kind > Base
This class provides utilities for constructing predicates.
ConstraintPosition * getConstraintPosition(ConstraintQuestion *q, unsigned index)
Position * getTypeLiteral(Attribute attr)
Returns a type position for the given type value.
Predicate getOperandCount(unsigned count)
Create a predicate comparing the number of operands of an operation to a known value.
OperationPosition * getPassthroughOp(Position *p)
Returns the operation position equivalent to the given position.
Predicate getIsNotNull()
Create a predicate comparing a value with null.
Predicate getOperandCountAtLeast(unsigned count)
Predicate getResultCountAtLeast(unsigned count)
Position * getType(Position *p)
Returns a type position for the given entity.
Position * getAttribute(OperationPosition *p, StringRef name)
Returns an attribute position for an attribute of the given operation.
Position * getOperandGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of operands of the given operation.
Position * getForEach(Position *p, unsigned id)
Position * getOperand(OperationPosition *p, unsigned operand)
Returns an operand position for an operand of the given operation.
Position * getResult(OperationPosition *p, unsigned result)
Returns a result position for a result of the given operation.
Position * getRoot()
Returns the root operation position.
Predicate getAttributeConstraint(Attribute attr)
Create a predicate comparing an attribute to a known value.
Position * getResultGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of results of the given operation.
Position * getAllResults(OperationPosition *p)
UsersPosition * getUsers(Position *p, bool useRepresentative)
Returns the users of a position using the value at the given operand.
Predicate getTypeConstraint(Attribute type)
Create a predicate comparing the type of an attribute or value to a known type.
OperationPosition * getOperandDefiningOp(Position *p)
Returns the parent position defining the value held by the given operand.
Predicate getResultCount(unsigned count)
Create a predicate comparing the number of results of an operation to a known value.
std::pair< Qualifier *, Qualifier * > Predicate
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to ordinal values).
Predicate getEqualTo(Position *pos)
Create a predicate checking if two values are equal.
Position * getAllOperands(OperationPosition *p)
PredicateBuilder(PredicateUniquer &uniquer, MLIRContext *ctx)
Position * getAttributeLiteral(Attribute attr)
Returns an attribute position for the given attribute.
Predicate getConstraint(StringRef name, ArrayRef< Position * > args, ArrayRef< Type > resultTypes, bool isNegated)
Create a predicate that applies a generic constraint.
Predicate getNotEqualTo(Position *pos)
Create a predicate checking if two values are not equal.
Predicate getOperationName(StringRef name)
Create a predicate comparing the name of an operation to a known value.
This class provides a storage uniquer that is used to allocate predicate instances.
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to ordinal values).
Predicates::Kind getKind() const
Returns the kind of this qualifier.
Qualifier(Predicates::Kind kind)
Kind
An enumeration of the kinds of predicates.
@ ResultCountAtLeastQuestion
@ OperationPos
Positions, ordered by decreasing priority.
@ OperandCountAtLeastQuestion
inline ::llvm::hash_code hash_value(const PolynomialBase< D, T > &arg)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
An Answer representing an Attribute value.
A position describing a literal attribute.
A position describing an attribute of an operation.
AttributePosition(const KeyTy &key)
StringAttr getName() const
Returns the attribute name of this position.
Compare an Attribute to a constant value.
A position describing the result of a native constraint.
unsigned getIndex() const
ConstraintQuestion * getQuestion() const
Returns the ConstraintQuestion to enable keeping track of the native constraint this position stems from.
Apply a parameterized constraint to multiple position values and possibly produce results.
StringRef getName() const
Return the name of the constraint.
ArrayRef< Type > getResultTypes() const
Return the result types of the constraint.
ArrayRef< Position * > getArgs() const
Return the arguments of the constraint.
static ConstraintQuestion * construct(StorageUniquer::StorageAllocator &alloc, KeyTy key)
Construct an instance with the given storage allocator.
static llvm::hash_code hashKey(const KeyTy &key)
Returns a hash suitable for the given keytype.
bool getIsNegated() const
Return the negation status of the constraint.
Compare the equality of two values.
An Answer representing a boolean 'false' value.
A position describing an iterative choice of an operation.
unsigned getID() const
Returns the ID, for differentiating various loops.
ForEachPosition(const KeyTy &key)
Compare a positional value with null, i.e. check if it exists.
Compare the number of operands of an operation with a known value.
A position describing an operand group of an operation.
bool isVariadic() const
Returns if the operand group has unknown size.
std::optional< unsigned > getOperandGroupNumber() const
Returns the group number of this position.
static llvm::hash_code hashKey(const KeyTy &key)
Returns a hash suitable for the given keytype.
OperandGroupPosition(const KeyTy &key)
A position describing an operand of an operation.
OperandPosition(const KeyTy &key)
unsigned getOperandNumber() const
Returns the operand number of this position.
An Answer representing an OperationName value.
Compare the name of an operation with a known value.
An operation position describes an operation node in the IR.
static OperationPosition * getRoot(StorageUniquer &uniquer)
Gets the root position.
bool isRoot() const
Returns if this operation position corresponds to the root.
OperationPosition(const KeyTy &key)
unsigned getDepth() const
Returns the depth of this position.
static llvm::hash_code hashKey(const KeyTy &key)
Returns a hash suitable for the given keytype.
bool isOperandDefiningOp() const
Returns if this operation represents an operand defining op.
static OperationPosition * get(StorageUniquer &uniquer, Position *parent)
Gets an operation position with the given parent.
Compare the number of results of an operation with a known value.
A position describing a result group of an operation.
ResultGroupPosition(const KeyTy &key)
bool isVariadic() const
Returns if the result group has unknown size.
std::optional< unsigned > getResultGroupNumber() const
Returns the group number of this position.
static llvm::hash_code hashKey(const KeyTy &key)
Returns a hash suitable for the given keytype.
A position describing a result of an operation.
ResultPosition(const KeyTy &key)
unsigned getResultNumber() const
Returns the result number of this position.
An Answer representing a boolean true value.
An Answer representing a Type value.
A position describing a literal type or type range.
A position describing the result type of an entity, i.e. an attribute, operand, or result.
TypePosition(const KeyTy &key)
Compare the type of an attribute or value with a known type.
An Answer representing an unsigned value.
A position describing the users of a value or a range of values.
bool useRepresentative() const
Indicates whether to compute a range of a representative.
static llvm::hash_code hashKey(const KeyTy &key)
Returns a hash suitable for the given keytype.
UsersPosition(const KeyTy &key)