32 #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
33 #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
41 namespace pdl_to_pdl_interp {
42 namespace Predicates {
82 template <
typename ConcreteT,
typename BaseT,
typename Key,
89 template <
typename KeyT>
94 template <
typename... Args>
96 return uniquer.
get<ConcreteT>({}, std::forward<Args>(args)...);
100 template <
typename KeyT>
103 return new (alloc.
allocate<ConcreteT>()) ConcreteT(std::forward<KeyT>(
key));
118 template <
typename ConcreteT,
typename BaseT, Predicates::Kind Kind>
126 return uniquer.
get<ConcreteT>();
135 struct OperationPosition;
174 std::pair<OperationPosition *, StringAttr>,
175 Predicates::AttributePos> {
188 :
public PredicateBase<AttributeLiteralPosition, Position, Attribute,
189 Predicates::AttributeLiteralPos> {
199 std::pair<Position *, unsigned>,
200 Predicates::ForEachPos> {
215 std::pair<OperationPosition *, unsigned>,
216 Predicates::OperandPos> {
230 OperandGroupPosition, Position,
231 std::tuple<OperationPosition *, std::optional<unsigned>, bool>,
232 Predicates::OperandGroupPos> {
243 return std::get<1>(
key);
258 std::pair<Position *, unsigned>,
259 Predicates::OperationPos> {
300 std::pair<ConstraintQuestion *, unsigned>,
301 Predicates::ConstraintResultPos> {
319 std::pair<OperationPosition *, unsigned>,
320 Predicates::ResultPos> {
334 ResultGroupPosition, Position,
335 std::tuple<OperationPosition *, std::optional<unsigned>, bool>,
336 Predicates::ResultGroupPos> {
349 return std::get<1>(
key);
364 Predicates::TypePos> {
368 "expected parent to be an attribute, operand, or result");
380 :
public PredicateBase<TypeLiteralPosition, Position, Attribute,
381 Predicates::TypeLiteralPos> {
393 :
public PredicateBase<UsersPosition, Position, std::pair<Position *, bool>,
394 Predicates::UsersPos> {
440 :
public PredicateBase<AttributeAnswer, Qualifier, Attribute,
441 Predicates::AttributeAnswer> {
447 :
public PredicateBase<OperationNameAnswer, Qualifier, OperationName,
448 Predicates::OperationNameAnswer> {
454 :
PredicateBase<TrueAnswer, Qualifier, void, Predicates::TrueAnswer> {
460 :
PredicateBase<FalseAnswer, Qualifier, void, Predicates::FalseAnswer> {
467 Predicates::TypeAnswer> {
474 Predicates::UnsignedAnswer> {
485 Predicates::AttributeQuestion> {};
491 ConstraintQuestion, Qualifier,
492 std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
493 Predicates::ConstraintQuestion> {
525 :
public PredicateBase<EqualToQuestion, Qualifier, Position *,
526 Predicates::EqualToQuestion> {
533 Predicates::IsNotNullQuestion> {};
537 :
public PredicateBase<OperandCountQuestion, Qualifier, void,
538 Predicates::OperandCountQuestion> {};
540 :
public PredicateBase<OperandCountAtLeastQuestion, Qualifier, void,
541 Predicates::OperandCountAtLeastQuestion> {};
545 :
public PredicateBase<OperationNameQuestion, Qualifier, void,
546 Predicates::OperationNameQuestion> {};
551 Predicates::ResultCountQuestion> {};
553 :
public PredicateBase<ResultCountAtLeastQuestion, Qualifier, void,
554 Predicates::ResultCountAtLeastQuestion> {};
558 Predicates::TypeQuestion> {};
570 registerParametricStorageType<AttributePosition>();
571 registerParametricStorageType<AttributeLiteralPosition>();
572 registerParametricStorageType<ConstraintPosition>();
573 registerParametricStorageType<ForEachPosition>();
574 registerParametricStorageType<OperandPosition>();
575 registerParametricStorageType<OperandGroupPosition>();
576 registerParametricStorageType<OperationPosition>();
577 registerParametricStorageType<ResultPosition>();
578 registerParametricStorageType<ResultGroupPosition>();
579 registerParametricStorageType<TypePosition>();
580 registerParametricStorageType<TypeLiteralPosition>();
581 registerParametricStorageType<UsersPosition>();
584 registerParametricStorageType<AttributeAnswer>();
585 registerParametricStorageType<OperationNameAnswer>();
586 registerParametricStorageType<TypeAnswer>();
587 registerParametricStorageType<UnsignedAnswer>();
588 registerSingletonStorageType<FalseAnswer>();
589 registerSingletonStorageType<TrueAnswer>();
592 registerParametricStorageType<ConstraintQuestion>();
593 registerParametricStorageType<EqualToQuestion>();
594 registerSingletonStorageType<AttributeQuestion>();
595 registerSingletonStorageType<IsNotNullQuestion>();
596 registerSingletonStorageType<OperandCountQuestion>();
597 registerSingletonStorageType<OperandCountAtLeastQuestion>();
598 registerSingletonStorageType<OperationNameQuestion>();
599 registerSingletonStorageType<ResultCountQuestion>();
600 registerSingletonStorageType<ResultCountAtLeastQuestion>();
601 registerSingletonStorageType<TypeQuestion>();
613 : uniquer(uniquer), ctx(ctx) {}
624 assert((isa<OperandPosition, OperandGroupPosition>(p)) &&
625 "expected operand position");
631 assert((isa<ForEachPosition>(p)) &&
"expected users position");
696 "expected result position");
729 uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
union mlir::linalg::@1193::ArityGroupAndKind::Kind kind
Attributes are known-constant values of operations.
MLIRContext is the top-level object for a collection of MLIR operations.
This class acts as the base storage that all storage classes must derived from.
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
A utility class to get or create instances of "storage classes".
Storage * get(function_ref< void(Storage *)> initFn, TypeID id, Args &&...args)
Gets a uniqued instance of 'Storage'.
A position describes a value on the input IR on which a predicate may be applied, such as an operatio...
Position(Predicates::Kind kind)
unsigned getOperationDepth() const
Returns the depth of the first ancestor operation position.
Position * getParent() const
Returns the parent position. The root operation position has no parent.
Predicates::Kind getKind() const
Returns the kind of this position.
Position * parent
Link to the parent position.
Base storage for simple predicates that only unique with the kind.
static bool classof(const BaseT *pred)
static ConcreteT * get(StorageUniquer &uniquer)
Base class for all predicates, used to allow efficient pointer comparison.
bool operator==(const KeyTy &key) const
Utility methods required by the storage allocator.
static ConcreteT * construct(StorageUniquer::StorageAllocator &alloc, KeyT &&key)
Construct an instance with the given storage allocator.
static ConcreteT * get(StorageUniquer &uniquer, Args &&...args)
Get an instance of this position.
PredicateBase(KeyT &&key)
static bool classof(const BaseT *pred)
const KeyTy & getValue() const
Return the key value of this predicate.
PredicateBase< ConcreteT, BaseT, Key, Kind > Base
This class provides utilities for constructing predicates.
ConstraintPosition * getConstraintPosition(ConstraintQuestion *q, unsigned index)
Position * getTypeLiteral(Attribute attr)
Returns a type position for the given type value.
Predicate getOperandCount(unsigned count)
Create a predicate comparing the number of operands of an operation to a known value.
OperationPosition * getPassthroughOp(Position *p)
Returns the operation position equivalent to the given position.
Predicate getIsNotNull()
Create a predicate comparing a value with null.
Predicate getOperandCountAtLeast(unsigned count)
Predicate getResultCountAtLeast(unsigned count)
Position * getType(Position *p)
Returns a type position for the given entity.
Position * getAttribute(OperationPosition *p, StringRef name)
Returns an attribute position for an attribute of the given operation.
Position * getOperandGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of operands of the given operation.
Position * getForEach(Position *p, unsigned id)
Position * getOperand(OperationPosition *p, unsigned operand)
Returns an operand position for an operand of the given operation.
Position * getResult(OperationPosition *p, unsigned result)
Returns a result position for a result of the given operation.
Position * getRoot()
Returns the root operation position.
Predicate getAttributeConstraint(Attribute attr)
Create a predicate comparing an attribute to a known value.
Position * getResultGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of results of the given operation.
Position * getAllResults(OperationPosition *p)
UsersPosition * getUsers(Position *p, bool useRepresentative)
Returns the users of a position using the value at the given operand.
Predicate getTypeConstraint(Attribute type)
Create a predicate comparing the type of an attribute or value to a known type.
OperationPosition * getOperandDefiningOp(Position *p)
Returns the parent position defining the value held by the given operand.
Predicate getResultCount(unsigned count)
Create a predicate comparing the number of results of an operation to a known value.
std::pair< Qualifier *, Qualifier * > Predicate
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Predicate getEqualTo(Position *pos)
Create a predicate checking if two values are equal.
Position * getAllOperands(OperationPosition *p)
PredicateBuilder(PredicateUniquer &uniquer, MLIRContext *ctx)
Position * getAttributeLiteral(Attribute attr)
Returns an attribute position for the given attribute.
Predicate getConstraint(StringRef name, ArrayRef< Position * > args, ArrayRef< Type > resultTypes, bool isNegated)
Create a predicate that applies a generic constraint.
Predicate getNotEqualTo(Position *pos)
Create a predicate checking if two values are not equal.
Predicate getOperationName(StringRef name)
Create a predicate comparing the name of an operation to a known value.
This class provides a storage uniquer that is used to allocate predicate instances.
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Predicates::Kind getKind() const
Returns the kind of this qualifier.
Qualifier(Predicates::Kind kind)
Kind
An enumeration of the kinds of predicates.
@ ResultCountAtLeastQuestion
@ OperationPos
Positions, ordered by decreasing priority.
@ OperandCountAtLeastQuestion
inline ::llvm::hash_code hash_value(const PolynomialBase< D, T > &arg)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
An Answer representing an Attribute value.
A position describing a literal attribute.
A position describing an attribute of an operation.
AttributePosition(const KeyTy &key)
StringAttr getName() const
Returns the attribute name of this position.
Compare an Attribute to a constant value.
A position describing the result of a native constraint.
unsigned getIndex() const
ConstraintQuestion * getQuestion() const
Returns the ConstraintQuestion to enable keeping track of the native constraint this position stems f...
Apply a parameterized constraint to multiple position values and possibly produce results.
StringRef getName() const
Return the name of the constraint.
ArrayRef< Type > getResultTypes() const
Return the result types of the constraint.
ArrayRef< Position * > getArgs() const
Return the arguments of the constraint.
static ConstraintQuestion * construct(StorageUniquer::StorageAllocator &alloc, KeyTy key)
Construct an instance with the given storage allocator.
static llvm::hash_code hashKey(const KeyTy &key)
Returns a hash suitable for the given keytype.
bool getIsNegated() const
Return the negation status of the constraint.
Compare the equality of two values.
An Answer representing a boolean 'false' value.
A position describing an iterative choice of an operation.
unsigned getID() const
Returns the ID, for differentiating various loops.
ForEachPosition(const KeyTy &key)
Compare a positional value with null, i.e. check if it exists.
Compare the number of operands of an operation with a known value.
A position describing an operand group of an operation.
bool isVariadic() const
Returns if the operand group has unknown size.
std::optional< unsigned > getOperandGroupNumber() const
Returns the group number of this position.
static llvm::hash_code hashKey(const KeyTy &key)
Returns a hash suitable for the given keytype.
OperandGroupPosition(const KeyTy &key)
A position describing an operand of an operation.
OperandPosition(const KeyTy &key)
unsigned getOperandNumber() const
Returns the operand number of this position.
An Answer representing an OperationName value.
Compare the name of an operation with a known value.
An operation position describes an operation node in the IR.
static OperationPosition * getRoot(StorageUniquer &uniquer)
Gets the root position.
bool isRoot() const
Returns if this operation position corresponds to the root.
OperationPosition(const KeyTy &key)
unsigned getDepth() const
Returns the depth of this position.
static llvm::hash_code hashKey(const KeyTy &key)
Returns a hash suitable for the given keytype.
bool isOperandDefiningOp() const
Returns if this operation represents an operand defining op.
static OperationPosition * get(StorageUniquer &uniquer, Position *parent)
Gets an operation position with the given parent.
Compare the number of results of an operation with a known value.
A position describing a result group of an operation.
ResultGroupPosition(const KeyTy &key)
bool isVariadic() const
Returns if the result group has unknown size.
std::optional< unsigned > getResultGroupNumber() const
Returns the group number of this position.
static llvm::hash_code hashKey(const KeyTy &key)
Returns a hash suitable for the given keytype.
A position describing a result of an operation.
ResultPosition(const KeyTy &key)
unsigned getResultNumber() const
Returns the result number of this position.
An Answer representing a boolean true value.
An Answer representing a Type value.
A position describing a literal type or type range.
A position describing the result type of an entity, i.e.
TypePosition(const KeyTy &key)
Compare the type of an attribute or value with a known type.
An Answer representing an unsigned value.
A position describing the users of a value or a range of values.
bool useRepresentative() const
Indicates whether to compute a range of a representative.
static llvm::hash_code hashKey(const KeyTy &key)
Returns a hash suitable for the given keytype.
UsersPosition(const KeyTy &key)