14 #include "llvm/ADT/MapVector.h"
15 #include "llvm/ADT/SmallPtrSet.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 #include "llvm/Support/Debug.h"
18 #include "llvm/Support/DebugLog.h"
21 #define DEBUG_TYPE "pdl-predicate-tree"
42 return llvm::count_if(values.
getTypes(),
43 [](
Type type) { return !isa<pdl::RangeType>(type); });
50 assert(isa<pdl::AttributeType>(val.
getType()) &&
"expected attribute type");
55 if (
Value type = attr.getValueType())
57 else if (
Attribute value = attr.getValueAttr())
68 bool isVariadic = isa<pdl::RangeType>(valueType);
72 .Case<pdl::OperandOp, pdl::OperandsOp>([&](
auto op) {
75 if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
76 cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
79 if (
Value type = op.getValueType())
83 .Case<pdl::ResultOp, pdl::ResultsOp>([&](
auto op) {
84 std::optional<unsigned> index = op.getIndex();
92 predList.emplace_back(parentPos, builder.
getIsNotNull());
97 if (std::is_same<pdl::ResultOp, decltype(op)>::value)
98 resultPos = builder.
getResult(parentPos, *index);
101 predList.emplace_back(resultPos, builder.
getEqualTo(pos));
113 std::optional<unsigned> ignoreOperand = std::nullopt) {
114 assert(isa<pdl::OperationType>(val.
getType()) &&
"expected operation");
115 pdl::OperationOp op = cast<pdl::OperationOp>(val.
getDefiningOp());
123 if (std::optional<StringRef> opName = op.getOpName())
130 if (minOperands != operands.size()) {
141 if (minResults == types.size())
147 for (
auto [attrName, attr] :
148 llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
150 predList, attr, builder, inputs,
160 if (operands.size() == 1 && isa<pdl::RangeType>(operands[0].
getType())) {
167 bool foundVariableLength =
false;
169 bool isVariadic = isa<pdl::RangeType>(operandIt.value().getType());
170 foundVariableLength |= isVariadic;
174 if (ignoreOperand == operandIt.index())
180 : builder.
getOperand(opPos, operandIt.index());
185 if (types.size() == 1 && isa<pdl::RangeType>(types[0].
getType())) {
191 bool foundVariableLength =
false;
193 bool isVariadic = isa<pdl::RangeType>(typeValue.getType());
194 foundVariableLength |= isVariadic;
196 auto *resultPos = foundVariableLength
199 predList.emplace_back(resultPos, builder.
getIsNotNull());
211 if (
Attribute type = typeOp.getConstantTypeAttr())
213 }
else if (pdl::TypesOp typeOp = val.
getDefiningOp<pdl::TypesOp>()) {
214 if (
Attribute typeAttr = typeOp.getConstantTypesAttr())
225 auto it = inputs.try_emplace(val, pos);
229 if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
231 auto minMaxPositions =
233 predList.emplace_back(minMaxPositions.second,
243 .Case<OperandPosition, OperandGroupPosition>([&](
auto *pos) {
246 .Default([](
auto *) { llvm_unreachable(
"unexpected position kind"); });
250 std::vector<PositionalPredicate> &predList,
257 assert(value &&
"expected non-tree `pdl.attribute` to contain a value");
262 std::vector<PositionalPredicate> &predList,
267 std::vector<Position *> allPositions;
268 allPositions.reserve(arguments.size());
269 for (
Value arg : arguments)
270 allPositions.push_back(inputs.lookup(arg));
283 auto [it, inserted] = inputs.try_emplace(result, pos);
290 std::tie(second, first) = std::make_pair(first, second);
292 predList.emplace_back(second, builder.
getEqualTo(first));
295 predList.emplace_back(pos, pred);
299 std::vector<PositionalPredicate> &predList,
307 auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
308 resultPos = builder.
getResult(parentPos, op.getIndex());
309 predList.emplace_back(resultPos, builder.
getIsNotNull());
313 std::vector<PositionalPredicate> &predList,
321 auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
322 bool isVariadic = isa<pdl::RangeType>(op.getType());
323 std::optional<unsigned> index = op.getIndex();
326 predList.emplace_back(resultPos, builder.
getIsNotNull());
333 Position *&typePos = inputs[typeValue];
338 "expected non-tree `pdl.type`/`pdl.types` to contain a value");
345 std::vector<PositionalPredicate> &predList,
348 for (
Operation &op : pattern.getBodyRegion().getOps()) {
350 .Case([&](pdl::AttributeOp attrOp) {
353 .Case<pdl::ApplyNativeConstraintOp>([&](
auto constraintOp) {
356 .Case<pdl::ResultOp, pdl::ResultsOp>([&](
auto resultOp) {
359 .Case([&](pdl::TypeOp typeOp) {
361 typeOp, [&] {
return typeOp.getConstantTypeAttr(); }, builder,
364 .Case([&](pdl::TypesOp typeOp) {
366 typeOp, [&] {
return typeOp.getConstantTypesAttr(); }, builder,
377 std::optional<unsigned> index;
392 for (
auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
393 for (
Value operand : operationOp.getOperandValues())
395 .Case<pdl::ResultOp, pdl::ResultsOp>(
396 [&used](
auto resultOp) { used.insert(resultOp.getParent()); });
401 if (
Value root = pattern.getRewriter().getRoot())
406 for (
Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>())
407 if (!used.contains(operationOp))
408 roots.push_back(operationOp);
421 ParentMaps &parentMaps) {
429 Entry(
Value value,
Value parent, std::optional<unsigned> index,
431 : value(value), parent(parent), index(index), depth(depth) {}
435 std::optional<unsigned> index;
447 llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
450 for (
Value root : roots) {
454 std::queue<Entry> toVisit;
455 toVisit.emplace(root,
Value(), 0, 0);
460 while (!toVisit.empty()) {
461 Entry entry = toVisit.front();
464 if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
468 connectorsRootsDepths[entry.value].push_back({root, entry.depth});
474 .Case<pdl::OperationOp>([&](
auto operationOp) {
478 if (operands.size() == 1 &&
479 isa<pdl::RangeType>(operands[0].
getType())) {
480 toVisit.emplace(operands[0], entry.value, std::nullopt,
488 toVisit.emplace(p.value(), entry.value, p.index(),
491 .Case<pdl::ResultOp, pdl::ResultsOp>([&](
auto resultOp) {
492 toVisit.emplace(resultOp.getParent(), entry.value,
493 resultOp.getIndex(), entry.depth);
501 for (
const auto &connectorRootsDepths : connectorsRootsDepths) {
502 Value value = connectorRootsDepths.first;
506 if (rootsDepths.size() == 1)
509 for (
const RootDepth &p : rootsDepths) {
510 for (
const RootDepth &q : rootsDepths) {
517 entry.
cost.second = nextID++;
518 entry.
cost.first = q.depth;
525 assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
526 "the pattern contains a candidate root disconnected from the others");
533 assert(index < operands.size() &&
"operand index out of range");
534 for (
unsigned i = 0; i <= index; ++i)
535 if (isa<pdl::RangeType>(operands[i].
getType()))
541 static void visitUpward(std::vector<PositionalPredicate> &predList,
545 Value value = opIndex.parent;
547 .Case<pdl::OperationOp>([&](
auto operationOp) {
548 LDBG() <<
" * Value: " << value;
557 if (!opIndex.index) {
562 Type type = operationOp.getOperandValues()[*opIndex.index].getType();
563 bool variadic = isa<pdl::RangeType>(type);
567 operandPos = builder.
getOperand(opPos, *opIndex.index);
569 predList.emplace_back(operandPos, builder.
getEqualTo(pos));
575 bool inserted = valueToPosition.try_emplace(value, opPos).second;
577 assert(inserted &&
"duplicate upward visit");
586 .Case<pdl::ResultOp>([&](
auto resultOp) {
588 auto *opPos = dyn_cast<OperationPosition>(pos);
589 assert(opPos &&
"operations and results must be interleaved");
590 pos = builder.
getResult(opPos, *opIndex.index);
593 valueToPosition.try_emplace(value, pos);
595 .Case<pdl::ResultsOp>([&](
auto resultOp) {
597 auto *opPos = dyn_cast<OperationPosition>(pos);
598 assert(opPos &&
"operations and results must be interleaved");
599 bool isVariadic = isa<pdl::RangeType>(value.
getType());
606 valueToPosition.try_emplace(value, pos);
614 std::vector<PositionalPredicate> &predList,
620 ParentMaps parentMaps;
623 for (
auto &target : graph) {
624 LDBG() <<
" * " << target.first.getLoc() <<
" " << target.first;
625 for (
auto &source : target.second) {
627 LDBG() <<
" <- " << source.first <<
": " << entry.
cost.first <<
":"
634 Value bestRoot = pattern.getRewriter().getRoot();
637 unsigned bestCost = 0;
638 LDBG() <<
"Candidate roots:";
639 for (
Value root : roots) {
641 unsigned cost = solver.
solve();
642 LDBG() <<
" * " << root <<
": " << cost;
643 if (!bestRoot || bestCost > cost) {
656 LDBG() <<
"Best tree:";
657 for (
const std::pair<Value, Value> &edge : bestEdges) {
659 LDBG() <<
" * " << edge.first <<
" <- " << edge.second;
661 LDBG() <<
" * " << edge.first;
664 LDBG() <<
"Calling key getTreePredicates (Value: " << bestRoot <<
")";
675 Value target = it.value().first;
676 Value source = it.value().second;
682 if (valueToPosition.count(target))
686 Value connector = graph[target][source].connector;
687 assert(connector &&
"invalid edge");
688 LDBG() <<
" * Connector: " << connector.
getLoc();
690 Position *pos = valueToPosition.lookup(connector);
691 assert(pos &&
"connector has not been traversed yet");
694 for (
Value value = connector; value != target;) {
695 OpIndex opIndex = parentMap.lookup(value);
696 assert(opIndex.parent &&
"missing parent");
697 visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
698 value = opIndex.parent;
716 struct OrderedPredicate {
717 OrderedPredicate(
const std::pair<Position *, Qualifier *> &ip)
718 : position(ip.first), question(ip.second) {}
720 : position(ip.position), question(ip.question) {}
731 unsigned primary = 0;
736 unsigned secondary = 0;
749 bool operator<(
const OrderedPredicate &rhs)
const {
756 auto *rhsPos = rhs.position;
757 return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
758 rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
759 std::make_tuple(rhs.primary, rhs.secondary,
767 struct OrderedPredicateDenseInfo {
770 static OrderedPredicate getEmptyKey() {
return Base::getEmptyKey(); }
771 static OrderedPredicate getTombstoneKey() {
return Base::getTombstoneKey(); }
772 static bool isEqual(
const OrderedPredicate &lhs,
773 const OrderedPredicate &rhs) {
774 return lhs.position == rhs.position && lhs.question == rhs.question;
776 static unsigned getHashValue(
const OrderedPredicate &p) {
777 return llvm::hash_combine(p.position, p.question);
783 struct OrderedPredicateList {
784 OrderedPredicateList(pdl::PatternOp pattern,
Value root)
785 : pattern(pattern), root(root) {}
787 pdl::PatternOp pattern;
797 return node->
getPosition() == predicate->position &&
803 static std::unique_ptr<MatcherNode> &
805 pdl::PatternOp pattern) {
807 "expected matcher to equal the given predicate");
809 auto it = predicate->patternToAnswer.find(pattern);
810 assert(it != predicate->patternToAnswer.end() &&
811 "expected pattern to exist in predicate");
820 OrderedPredicateList &list,
821 std::vector<OrderedPredicate *>::iterator current,
822 std::vector<OrderedPredicate *>::iterator end) {
823 if (current == end) {
826 std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
829 }
else if (!list.predicates.contains(*current)) {
836 node = std::make_unique<SwitchNode>((*current)->position,
837 (*current)->question);
840 list, std::next(current), end);
847 list, std::next(current), end);
862 if (
SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
864 for (
auto &it : children)
869 if (children.size() == 1) {
870 auto *childIt = children.begin();
871 node = std::make_unique<BoolNode>(
872 node->getPosition(), node->getQuestion(), childIt->first,
873 std::move(childIt->second), std::move(node->getFailureNode()));
875 }
else if (
BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
885 root = &(*root)->getFailureNode();
886 *root = std::make_unique<ExitNode>();
890 template <
typename Iterator,
typename Compare>
892 while (begin != end) {
897 for (
auto i = begin; i != end; ++i) {
898 if (std::none_of(begin, end, [&](
auto const &b) {
return cmp(b, *i); }))
899 sortBeforeOthers.insert(*i);
902 auto const next = std::stable_partition(begin, end, [&](
auto const &a) {
903 return sortBeforeOthers.contains(a);
905 assert(next != begin &&
"not a partial ordering");
911 static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
912 auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
916 auto positionDependsOnA = [&](
Position *p) {
917 auto *cp = dyn_cast<ConstraintPosition>(p);
918 return cp && cp->getQuestion() == cqa;
921 if (
auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) {
923 return llvm::any_of(cqb->getArgs(), positionDependsOnA);
925 if (
auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) {
926 return positionDependsOnA(b->position) ||
927 positionDependsOnA(equalTo->getValue());
929 return positionDependsOnA(b->position);
934 std::unique_ptr<MatcherNode>
939 struct PatternPredicates {
940 PatternPredicates(pdl::PatternOp pattern,
Value root,
941 std::vector<PositionalPredicate> predicates)
942 : pattern(pattern), root(root), predicates(std::move(predicates)) {}
945 pdl::PatternOp pattern;
951 std::vector<PositionalPredicate> predicates;
955 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
956 std::vector<PositionalPredicate> predicateList;
959 patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
964 for (
auto &patternAndPredList : patternsAndPredicates) {
965 for (
auto &predicate : patternAndPredList.predicates) {
966 auto it = uniqued.insert(predicate);
967 it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
971 it.first->id = uniqued.size() - 1;
976 std::vector<OrderedPredicateList> lists;
977 lists.reserve(patternsAndPredicates.size());
978 for (
auto &patternAndPredList : patternsAndPredicates) {
979 OrderedPredicateList list(patternAndPredList.pattern,
980 patternAndPredList.root);
981 for (
auto &predicate : patternAndPredList.predicates) {
982 OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
983 list.predicates.insert(orderedPredicate);
986 ++orderedPredicate->primary;
988 lists.push_back(std::move(list));
994 for (
auto &list : lists) {
996 for (
auto *predicate : list.predicates)
997 total += predicate->primary * predicate->primary;
998 for (
auto *predicate : list.predicates)
999 predicate->secondary += total;
1004 std::vector<OrderedPredicate *> ordered;
1005 ordered.reserve(uniqued.size());
1006 for (
auto &ip : uniqued)
1007 ordered.push_back(&ip);
1008 llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
1017 std::unique_ptr<MatcherNode> root;
1018 for (OrderedPredicateList &list : lists)
1032 std::unique_ptr<MatcherNode> failureNode)
1033 : position(p), question(q), failureNode(std::move(failureNode)),
1034 matcherTypeID(matcherTypeID) {}
1041 std::unique_ptr<MatcherNode> successNode,
1042 std::unique_ptr<MatcherNode> failureNode)
1044 std::move(failureNode)),
1045 answer(answer), successNode(std::move(successNode)) {}
1052 std::unique_ptr<MatcherNode> failureNode)
1054 nullptr, std::move(failureNode)),
1055 pattern(pattern), root(root) {}
static Value buildPredicateList(pdl::PatternOp pattern, PredicateBuilder &builder, std::vector< PositionalPredicate > &predList, DenseMap< Value, Position * > &valueToPosition)
Given a pattern operation, build the set of matcher predicates necessary to match this pattern.
static std::unique_ptr< MatcherNode > & getOrCreateChild(SwitchNode *node, OrderedPredicate *predicate, pdl::PatternOp pattern)
Get or insert a child matcher for the given parent switch node, given a predicate and parent pattern.
static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b)
Returns true if 'b' depends on a result of 'a'.
static void getTypePredicates(Value typeValue, function_ref< Attribute()> typeAttrFn, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getNonTreePredicates(pdl::PatternOp pattern, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
Collect all of the predicates that cannot be determined via walking the tree.
static bool useOperandGroup(pdl::OperationOp op, unsigned index)
Returns true if the operand at the given index needs to be queried using an operand group,...
static void getTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)
Collect the tree predicates anchored at the given value.
static void getResultPredicates(pdl::ResultOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getOperandTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)
Collect all of the predicates for the given operand position.
static bool comparePosDepth(Position *lhs, Position *rhs)
Compares the depths of two positions.
static void visitUpward(std::vector< PositionalPredicate > &predList, OpIndex opIndex, PredicateBuilder &builder, DenseMap< Value, Position * > &valueToPosition, Position *&pos, unsigned rootID)
Visit a node during upward traversal.
static unsigned getNumNonRangeValues(ValueRange values)
Returns the number of non-range elements within values.
static SmallVector< Value > detectRoots(pdl::PatternOp pattern)
Given a pattern, determines the set of roots present in this pattern.
static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp)
Sorts the range begin/end with the partial order given by cmp.
static void insertExitNode(std::unique_ptr< MatcherNode > *root)
Insert an exit node at the end of the failure path of the root.
static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate)
Returns true if the given matcher refers to the same predicate as the given ordered predicate.
static void foldSwitchToBool(std::unique_ptr< MatcherNode > &node)
Fold any switch nodes nested under node to boolean nodes when possible.
static void buildCostGraph(ArrayRef< Value > roots, RootOrderingGraph &graph, ParentMaps &parentMaps)
Given a list of candidate roots, builds the cost graph for connecting them.
static void getAttributePredicates(pdl::AttributeOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void propagatePattern(std::unique_ptr< MatcherNode > &node, OrderedPredicateList &list, std::vector< OrderedPredicate * >::iterator current, std::vector< OrderedPredicate * >::iterator end)
Build the matcher CFG by "pushing" patterns through by sorted predicate order.
Attributes are known-constant values of operations.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
This class implements the result iterators for the Operation class.
type_range getTypes() const
This class provides an efficient unique identifier for a specific C++ type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This class represents the base of a predicate matcher node.
Position * getPosition() const
Returns the position on which the question predicate should be checked.
Qualifier * getQuestion() const
Returns the predicate checked on this node.
The optimal branching algorithm solver.
unsigned solve()
Runs the Edmonds' algorithm for the current graph, returning the total cost of the minimum-weight spa...
std::vector< std::pair< Value, Value > > EdgeList
A list of edges (child, parent).
EdgeList preOrderTraversal(ArrayRef< Value > nodes) const
Returns the computed edges as visited in the preorder traversal.
A position describes a value on the input IR on which a predicate may be applied, such as an operatio...
unsigned getOperationDepth() const
Returns the depth of the first ancestor operation position.
Predicates::Kind getKind() const
Returns the kind of this position.
const KeyTy & getValue() const
Return the key value of this predicate.
This class provides utilities for constructing predicates.
ConstraintPosition * getConstraintPosition(ConstraintQuestion *q, unsigned index)
Position * getTypeLiteral(Attribute attr)
Returns a type position for the given type value.
Predicate getOperandCount(unsigned count)
Create a predicate comparing the number of operands of an operation to a known value.
OperationPosition * getPassthroughOp(Position *p)
Returns the operation position equivalent to the given position.
Predicate getIsNotNull()
Create a predicate comparing a value with null.
Predicate getOperandCountAtLeast(unsigned count)
Predicate getResultCountAtLeast(unsigned count)
Position * getType(Position *p)
Returns a type position for the given entity.
Position * getAttribute(OperationPosition *p, StringRef name)
Returns an attribute position for an attribute of the given operation.
Position * getOperandGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of operands of the given operation.
Position * getForEach(Position *p, unsigned id)
Position * getOperand(OperationPosition *p, unsigned operand)
Returns an operand position for an operand of the given operation.
Position * getResult(OperationPosition *p, unsigned result)
Returns a result position for a result of the given operation.
Position * getRoot()
Returns the root operation position.
Predicate getAttributeConstraint(Attribute attr)
Create a predicate comparing an attribute to a known value.
Position * getResultGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of results of the given operation.
Position * getAllResults(OperationPosition *p)
UsersPosition * getUsers(Position *p, bool useRepresentative)
Returns the users of a position using the value at the given operand.
Predicate getTypeConstraint(Attribute type)
Create a predicate comparing the type of an attribute or value to a known type.
OperationPosition * getOperandDefiningOp(Position *p)
Returns the parent position defining the value held by the given operand.
Predicate getResultCount(unsigned count)
Create a predicate comparing the number of results of an operation to a known value.
std::pair< Qualifier *, Qualifier * > Predicate
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Predicate getEqualTo(Position *pos)
Create a predicate checking if two values are equal.
Position * getAllOperands(OperationPosition *p)
Position * getAttributeLiteral(Attribute attr)
Returns an attribute position for the given attribute.
Predicate getConstraint(StringRef name, ArrayRef< Position * > args, ArrayRef< Type > resultTypes, bool isNegated)
Create a predicate that applies a generic constraint.
Predicate getOperationName(StringRef name)
Create a predicate comparing the name of an operation to a known value.
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Predicates::Kind getKind() const
Returns the kind of this qualifier.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool operator<(const Fraction &x, const Fraction &y)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A position describing an attribute of an operation.
A BoolNode denotes a question with a boolean-like result.
BoolNode(Position *position, Qualifier *question, Qualifier *answer, std::unique_ptr< MatcherNode > successNode, std::unique_ptr< MatcherNode > failureNode=nullptr)
A position describing the result of a native constraint.
Apply a parameterized constraint to multiple position values and possibly produce results.
An operation position describes an operation node in the IR.
bool isRoot() const
Returns if this operation position corresponds to the root.
bool isOperandDefiningOp() const
Returns if this operation represents an operand defining op.
A PositionalPredicate is a predicate that is associated with a specific positional value.
The information associated with an edge in the cost graph.
Value connector
The connector value in the intersection of the two subtrees rooted at the source and target root that...
std::pair< unsigned, unsigned > cost
The depth of the connector Value w.r.t.
A SuccessNode denotes that a given high level pattern has successfully been matched.
SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr< MatcherNode > failureNode)
A SwitchNode denotes a question with multiple potential results.
llvm::MapVector< Qualifier *, std::unique_ptr< MatcherNode > > ChildMapT
Returns the children of this switch node.
ChildMapT & getChildren()
SwitchNode(Position *position, Qualifier *question)
A position describing the result type of an entity, i.e.