17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "pdl-predicate-tree"
44 return llvm::count_if(values.
getTypes(),
45 [](
Type type) { return !isa<pdl::RangeType>(type); });
52 assert(isa<pdl::AttributeType>(val.
getType()) &&
"expected attribute type");
55 if (
auto attr = dyn_cast<pdl::AttributeOp>(val.
getDefiningOp())) {
57 if (
Value type = attr.getValueType())
59 else if (
Attribute value = attr.getValueAttr())
70 bool isVariadic = isa<pdl::RangeType>(valueType);
74 .Case<pdl::OperandOp, pdl::OperandsOp>([&](
auto op) {
77 if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
78 cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
81 if (
Value type = op.getValueType())
85 .Case<pdl::ResultOp, pdl::ResultsOp>([&](
auto op) {
86 std::optional<unsigned> index = op.getIndex();
94 predList.emplace_back(parentPos, builder.
getIsNotNull());
99 if (std::is_same<pdl::ResultOp, decltype(op)>::value)
100 resultPos = builder.
getResult(parentPos, *index);
103 predList.emplace_back(resultPos, builder.
getEqualTo(pos));
115 std::optional<unsigned> ignoreOperand = std::nullopt) {
116 assert(isa<pdl::OperationType>(val.
getType()) &&
"expected operation");
117 pdl::OperationOp op = cast<pdl::OperationOp>(val.
getDefiningOp());
125 if (std::optional<StringRef> opName = op.getOpName())
132 if (minOperands != operands.size()) {
143 if (minResults == types.size())
149 for (
auto [attrName, attr] :
150 llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
152 predList, attr, builder, inputs,
162 if (operands.size() == 1 && isa<pdl::RangeType>(operands[0].
getType())) {
169 bool foundVariableLength =
false;
171 bool isVariadic = isa<pdl::RangeType>(operandIt.value().getType());
172 foundVariableLength |= isVariadic;
176 if (ignoreOperand && *ignoreOperand == operandIt.index())
182 : builder.
getOperand(opPos, operandIt.index());
187 if (types.size() == 1 && isa<pdl::RangeType>(types[0].
getType())) {
193 bool foundVariableLength =
false;
195 bool isVariadic = isa<pdl::RangeType>(typeValue.getType());
196 foundVariableLength |= isVariadic;
198 auto *resultPos = foundVariableLength
201 predList.emplace_back(resultPos, builder.
getIsNotNull());
213 if (
Attribute type = typeOp.getConstantTypeAttr())
215 }
else if (pdl::TypesOp typeOp = val.
getDefiningOp<pdl::TypesOp>()) {
216 if (
Attribute typeAttr = typeOp.getConstantTypesAttr())
227 auto it = inputs.try_emplace(val, pos);
231 if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
233 auto minMaxPositions =
235 predList.emplace_back(minMaxPositions.second,
245 .Case<OperandPosition, OperandGroupPosition>([&](
auto *pos) {
248 .Default([](
auto *) { llvm_unreachable(
"unexpected position kind"); });
252 std::vector<PositionalPredicate> &predList,
259 assert(value &&
"expected non-tree `pdl.attribute` to contain a value");
264 std::vector<PositionalPredicate> &predList,
269 std::vector<Position *> allPositions;
270 allPositions.reserve(arguments.size());
271 for (
Value arg : arguments)
272 allPositions.push_back(inputs.lookup(arg));
285 auto [it, inserted] = inputs.try_emplace(result, pos);
292 std::tie(second, first) = std::make_pair(first, second);
294 predList.emplace_back(second, builder.
getEqualTo(first));
297 predList.emplace_back(pos, pred);
301 std::vector<PositionalPredicate> &predList,
309 auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
310 resultPos = builder.
getResult(parentPos, op.getIndex());
311 predList.emplace_back(resultPos, builder.
getIsNotNull());
315 std::vector<PositionalPredicate> &predList,
323 auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
324 bool isVariadic = isa<pdl::RangeType>(op.getType());
325 std::optional<unsigned> index = op.getIndex();
328 predList.emplace_back(resultPos, builder.
getIsNotNull());
335 Position *&typePos = inputs[typeValue];
340 "expected non-tree `pdl.type`/`pdl.types` to contain a value");
347 std::vector<PositionalPredicate> &predList,
350 for (
Operation &op : pattern.getBodyRegion().getOps()) {
352 .Case([&](pdl::AttributeOp attrOp) {
355 .Case<pdl::ApplyNativeConstraintOp>([&](
auto constraintOp) {
358 .Case<pdl::ResultOp, pdl::ResultsOp>([&](
auto resultOp) {
361 .Case([&](pdl::TypeOp typeOp) {
363 typeOp, [&] {
return typeOp.getConstantTypeAttr(); }, builder,
366 .Case([&](pdl::TypesOp typeOp) {
368 typeOp, [&] {
return typeOp.getConstantTypesAttr(); }, builder,
379 std::optional<unsigned> index;
394 for (
auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
395 for (
Value operand : operationOp.getOperandValues())
397 .Case<pdl::ResultOp, pdl::ResultsOp>(
398 [&used](
auto resultOp) { used.insert(resultOp.getParent()); });
403 if (
Value root = pattern.getRewriter().getRoot())
408 for (
Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>())
409 if (!used.contains(operationOp))
410 roots.push_back(operationOp);
423 ParentMaps &parentMaps) {
431 Entry(
Value value,
Value parent, std::optional<unsigned> index,
433 : value(value), parent(parent), index(index), depth(depth) {}
437 std::optional<unsigned> index;
449 llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
452 for (
Value root : roots) {
456 std::queue<Entry> toVisit;
457 toVisit.emplace(root,
Value(), 0, 0);
462 while (!toVisit.empty()) {
463 Entry entry = toVisit.front();
466 if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
470 connectorsRootsDepths[entry.value].push_back({root, entry.depth});
476 .Case<pdl::OperationOp>([&](
auto operationOp) {
480 if (operands.size() == 1 &&
481 isa<pdl::RangeType>(operands[0].
getType())) {
482 toVisit.emplace(operands[0], entry.value, std::nullopt,
490 toVisit.emplace(p.value(), entry.value, p.index(),
493 .Case<pdl::ResultOp, pdl::ResultsOp>([&](
auto resultOp) {
494 toVisit.emplace(resultOp.getParent(), entry.value,
495 resultOp.getIndex(), entry.depth);
503 for (
const auto &connectorRootsDepths : connectorsRootsDepths) {
504 Value value = connectorRootsDepths.first;
508 if (rootsDepths.size() == 1)
511 for (
const RootDepth &p : rootsDepths) {
512 for (
const RootDepth &q : rootsDepths) {
519 entry.
cost.second = nextID++;
520 entry.
cost.first = q.depth;
527 assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
528 "the pattern contains a candidate root disconnected from the others");
535 assert(index < operands.size() &&
"operand index out of range");
536 for (
unsigned i = 0; i <= index; ++i)
537 if (isa<pdl::RangeType>(operands[i].
getType()))
543 static void visitUpward(std::vector<PositionalPredicate> &predList,
547 Value value = opIndex.parent;
549 .Case<pdl::OperationOp>([&](
auto operationOp) {
550 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
559 if (!opIndex.index) {
564 Type type = operationOp.getOperandValues()[*opIndex.index].getType();
565 bool variadic = isa<pdl::RangeType>(type);
569 operandPos = builder.
getOperand(opPos, *opIndex.index);
571 predList.emplace_back(operandPos, builder.
getEqualTo(pos));
577 bool inserted = valueToPosition.try_emplace(value, opPos).second;
579 assert(inserted &&
"duplicate upward visit");
588 .Case<pdl::ResultOp>([&](
auto resultOp) {
590 auto *opPos = dyn_cast<OperationPosition>(pos);
591 assert(opPos &&
"operations and results must be interleaved");
592 pos = builder.
getResult(opPos, *opIndex.index);
595 valueToPosition.try_emplace(value, pos);
597 .Case<pdl::ResultsOp>([&](
auto resultOp) {
599 auto *opPos = dyn_cast<OperationPosition>(pos);
600 assert(opPos &&
"operations and results must be interleaved");
601 bool isVariadic = isa<pdl::RangeType>(value.
getType());
608 valueToPosition.try_emplace(value, pos);
616 std::vector<PositionalPredicate> &predList,
622 ParentMaps parentMaps;
625 llvm::dbgs() <<
"Graph:\n";
626 for (
auto &target : graph) {
627 llvm::dbgs() <<
" * " << target.first.getLoc() <<
" " << target.first
629 for (
auto &source : target.second) {
631 llvm::dbgs() <<
" <- " << source.first <<
": " << entry.
cost.first
632 <<
":" << entry.
cost.second <<
" via "
640 Value bestRoot = pattern.getRewriter().getRoot();
643 unsigned bestCost = 0;
644 LLVM_DEBUG(llvm::dbgs() <<
"Candidate roots:\n");
645 for (
Value root : roots) {
647 unsigned cost = solver.
solve();
648 LLVM_DEBUG(llvm::dbgs() <<
" * " << root <<
": " << cost <<
"\n");
649 if (!bestRoot || bestCost > cost) {
663 llvm::dbgs() <<
"Best tree:\n";
664 for (
const std::pair<Value, Value> &edge : bestEdges) {
665 llvm::dbgs() <<
" * " << edge.first;
667 llvm::dbgs() <<
" <- " << edge.second;
668 llvm::dbgs() <<
"\n";
672 LLVM_DEBUG(llvm::dbgs() <<
"Calling key getTreePredicates:\n");
673 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << bestRoot <<
"\n");
684 Value target = it.value().first;
685 Value source = it.value().second;
691 if (valueToPosition.count(target))
695 Value connector = graph[target][source].connector;
696 assert(connector &&
"invalid edge");
697 LLVM_DEBUG(llvm::dbgs() <<
" * Connector: " << connector.
getLoc() <<
"\n");
699 Position *pos = valueToPosition.lookup(connector);
700 assert(pos &&
"connector has not been traversed yet");
703 for (
Value value = connector; value != target;) {
704 OpIndex opIndex = parentMap.lookup(value);
705 assert(opIndex.parent &&
"missing parent");
706 visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
707 value = opIndex.parent;
725 struct OrderedPredicate {
726 OrderedPredicate(
const std::pair<Position *, Qualifier *> &ip)
727 : position(ip.first), question(ip.second) {}
729 : position(ip.position), question(ip.question) {}
740 unsigned primary = 0;
745 unsigned secondary = 0;
758 bool operator<(
const OrderedPredicate &rhs)
const {
765 auto *rhsPos = rhs.position;
766 return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
767 rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
768 std::make_tuple(rhs.primary, rhs.secondary,
776 struct OrderedPredicateDenseInfo {
779 static OrderedPredicate getEmptyKey() {
return Base::getEmptyKey(); }
780 static OrderedPredicate getTombstoneKey() {
return Base::getTombstoneKey(); }
781 static bool isEqual(
const OrderedPredicate &lhs,
782 const OrderedPredicate &rhs) {
783 return lhs.position == rhs.position && lhs.question == rhs.question;
785 static unsigned getHashValue(
const OrderedPredicate &p) {
786 return llvm::hash_combine(p.position, p.question);
792 struct OrderedPredicateList {
793 OrderedPredicateList(pdl::PatternOp pattern,
Value root)
794 : pattern(pattern), root(root) {}
796 pdl::PatternOp pattern;
806 return node->
getPosition() == predicate->position &&
813 OrderedPredicate *predicate,
814 pdl::PatternOp pattern) {
816 "expected matcher to equal the given predicate");
818 auto it = predicate->patternToAnswer.find(pattern);
819 assert(it != predicate->patternToAnswer.end() &&
820 "expected pattern to exist in predicate");
821 return node->
getChildren().insert({it->second,
nullptr}).first->second;
829 OrderedPredicateList &list,
830 std::vector<OrderedPredicate *>::iterator current,
831 std::vector<OrderedPredicate *>::iterator end) {
832 if (current == end) {
835 std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
838 }
else if (!list.predicates.contains(*current)) {
845 node = std::make_unique<SwitchNode>((*current)->position,
846 (*current)->question);
849 list, std::next(current), end);
856 list, std::next(current), end);
871 if (
SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
873 for (
auto &it : children)
878 if (children.size() == 1) {
879 auto *childIt = children.begin();
880 node = std::make_unique<BoolNode>(
881 node->getPosition(), node->getQuestion(), childIt->first,
882 std::move(childIt->second), std::move(node->getFailureNode()));
884 }
else if (
BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
894 root = &(*root)->getFailureNode();
895 *root = std::make_unique<ExitNode>();
899 template <
typename Iterator,
typename Compare>
901 while (begin != end) {
906 for (
auto i = begin; i != end; ++i) {
907 if (std::none_of(begin, end, [&](
auto const &b) {
return cmp(b, *i); }))
908 sortBeforeOthers.insert(*i);
911 auto const next = std::stable_partition(begin, end, [&](
auto const &a) {
912 return sortBeforeOthers.contains(a);
914 assert(next != begin &&
"not a partial ordering");
920 static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
921 auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
925 auto positionDependsOnA = [&](
Position *p) {
926 auto *cp = dyn_cast<ConstraintPosition>(p);
927 return cp && cp->getQuestion() == cqa;
930 if (
auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) {
932 return llvm::any_of(cqb->getArgs(), positionDependsOnA);
934 if (
auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) {
935 return positionDependsOnA(b->position) ||
936 positionDependsOnA(equalTo->getValue());
938 return positionDependsOnA(b->position);
943 std::unique_ptr<MatcherNode>
948 struct PatternPredicates {
949 PatternPredicates(pdl::PatternOp pattern,
Value root,
950 std::vector<PositionalPredicate> predicates)
951 : pattern(pattern), root(root), predicates(std::move(predicates)) {}
954 pdl::PatternOp pattern;
960 std::vector<PositionalPredicate> predicates;
964 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
965 std::vector<PositionalPredicate> predicateList;
968 patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
973 for (
auto &patternAndPredList : patternsAndPredicates) {
974 for (
auto &predicate : patternAndPredList.predicates) {
975 auto it = uniqued.insert(predicate);
976 it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
980 it.first->id = uniqued.size() - 1;
985 std::vector<OrderedPredicateList> lists;
986 lists.reserve(patternsAndPredicates.size());
987 for (
auto &patternAndPredList : patternsAndPredicates) {
988 OrderedPredicateList list(patternAndPredList.pattern,
989 patternAndPredList.root);
990 for (
auto &predicate : patternAndPredList.predicates) {
991 OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
992 list.predicates.insert(orderedPredicate);
995 ++orderedPredicate->primary;
997 lists.push_back(std::move(list));
1003 for (
auto &list : lists) {
1005 for (
auto *predicate : list.predicates)
1006 total += predicate->primary * predicate->primary;
1007 for (
auto *predicate : list.predicates)
1008 predicate->secondary += total;
1013 std::vector<OrderedPredicate *> ordered;
1014 ordered.reserve(uniqued.size());
1015 for (
auto &ip : uniqued)
1016 ordered.push_back(&ip);
1017 llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
1026 std::unique_ptr<MatcherNode> root;
1027 for (OrderedPredicateList &list : lists)
1041 std::unique_ptr<MatcherNode> failureNode)
1042 : position(p), question(q), failureNode(std::move(failureNode)),
1043 matcherTypeID(matcherTypeID) {}
1050 std::unique_ptr<MatcherNode> successNode,
1051 std::unique_ptr<MatcherNode> failureNode)
1053 std::move(failureNode)),
1054 answer(answer), successNode(std::move(successNode)) {}
1061 std::unique_ptr<MatcherNode> failureNode)
1063 nullptr, std::move(failureNode)),
1064 pattern(pattern), root(root) {}
static Value buildPredicateList(pdl::PatternOp pattern, PredicateBuilder &builder, std::vector< PositionalPredicate > &predList, DenseMap< Value, Position * > &valueToPosition)
Given a pattern operation, build the set of matcher predicates necessary to match this pattern.
static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b)
Returns true if 'b' depends on a result of 'a'.
static void getTypePredicates(Value typeValue, function_ref< Attribute()> typeAttrFn, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getNonTreePredicates(pdl::PatternOp pattern, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
Collect all of the predicates that cannot be determined via walking the tree.
static bool useOperandGroup(pdl::OperationOp op, unsigned index)
Returns true if the operand at the given index needs to be queried using an operand group,...
static void getTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)
Collect the tree predicates anchored at the given value.
static void getResultPredicates(pdl::ResultOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getOperandTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)
Collect all of the predicates for the given operand position.
std::unique_ptr< MatcherNode > & getOrCreateChild(SwitchNode *node, OrderedPredicate *predicate, pdl::PatternOp pattern)
Get or insert a child matcher for the given parent switch node, given a predicate and parent pattern.
static bool comparePosDepth(Position *lhs, Position *rhs)
Compares the depths of two positions.
static void visitUpward(std::vector< PositionalPredicate > &predList, OpIndex opIndex, PredicateBuilder &builder, DenseMap< Value, Position * > &valueToPosition, Position *&pos, unsigned rootID)
Visit a node during upward traversal.
static unsigned getNumNonRangeValues(ValueRange values)
Returns the number of non-range elements within values.
static SmallVector< Value > detectRoots(pdl::PatternOp pattern)
Given a pattern, determines the set of roots present in this pattern.
static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp)
Sorts the range begin/end with the partial order given by cmp.
static void insertExitNode(std::unique_ptr< MatcherNode > *root)
Insert an exit node at the end of the failure path of the root.
static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate)
Returns true if the given matcher refers to the same predicate as the given ordered predicate.
static void foldSwitchToBool(std::unique_ptr< MatcherNode > &node)
Fold any switch nodes nested under node to boolean nodes when possible.
static void buildCostGraph(ArrayRef< Value > roots, RootOrderingGraph &graph, ParentMaps &parentMaps)
Given a list of candidate roots, builds the cost graph for connecting them.
static void getAttributePredicates(pdl::AttributeOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void propagatePattern(std::unique_ptr< MatcherNode > &node, OrderedPredicateList &list, std::vector< OrderedPredicate * >::iterator current, std::vector< OrderedPredicate * >::iterator end)
Build the matcher CFG by "pushing" patterns through by sorted predicate order.
Attributes are known-constant values of operations.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
result_range getResults()
This class implements the result iterators for the Operation class.
type_range getTypes() const
This class provides an efficient unique identifier for a specific C++ type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This class represents the base of a predicate matcher node.
Position * getPosition() const
Returns the position on which the question predicate should be checked.
Qualifier * getQuestion() const
Returns the predicate checked on this node.
The optimal branching algorithm solver.
unsigned solve()
Runs the Edmonds' algorithm for the current graph, returning the total cost of the minimum-weight spa...
std::vector< std::pair< Value, Value > > EdgeList
A list of edges (child, parent).
EdgeList preOrderTraversal(ArrayRef< Value > nodes) const
Returns the computed edges as visited in the preorder traversal.
A position describes a value on the input IR on which a predicate may be applied, such as an operatio...
unsigned getOperationDepth() const
Returns the depth of the first ancestor operation position.
Predicates::Kind getKind() const
Returns the kind of this position.
const KeyTy & getValue() const
Return the key value of this predicate.
This class provides utilities for constructing predicates.
ConstraintPosition * getConstraintPosition(ConstraintQuestion *q, unsigned index)
Position * getTypeLiteral(Attribute attr)
Returns a type position for the given type value.
Predicate getOperandCount(unsigned count)
Create a predicate comparing the number of operands of an operation to a known value.
OperationPosition * getPassthroughOp(Position *p)
Returns the operation position equivalent to the given position.
Predicate getIsNotNull()
Create a predicate comparing a value with null.
Predicate getOperandCountAtLeast(unsigned count)
Predicate getResultCountAtLeast(unsigned count)
Position * getType(Position *p)
Returns a type position for the given entity.
Position * getAttribute(OperationPosition *p, StringRef name)
Returns an attribute position for an attribute of the given operation.
Position * getOperandGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of operands of the given operation.
Position * getForEach(Position *p, unsigned id)
Position * getOperand(OperationPosition *p, unsigned operand)
Returns an operand position for an operand of the given operation.
Position * getResult(OperationPosition *p, unsigned result)
Returns a result position for a result of the given operation.
Position * getRoot()
Returns the root operation position.
Predicate getAttributeConstraint(Attribute attr)
Create a predicate comparing an attribute to a known value.
Position * getResultGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of results of the given operation.
Position * getAllResults(OperationPosition *p)
UsersPosition * getUsers(Position *p, bool useRepresentative)
Returns the users of a position using the value at the given operand.
Predicate getTypeConstraint(Attribute type)
Create a predicate comparing the type of an attribute or value to a known type.
OperationPosition * getOperandDefiningOp(Position *p)
Returns the parent position defining the value held by the given operand.
Predicate getResultCount(unsigned count)
Create a predicate comparing the number of results of an operation to a known value.
std::pair< Qualifier *, Qualifier * > Predicate
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Predicate getEqualTo(Position *pos)
Create a predicate checking if two values are equal.
Position * getAllOperands(OperationPosition *p)
Position * getAttributeLiteral(Attribute attr)
Returns an attribute position for the given attribute.
Predicate getConstraint(StringRef name, ArrayRef< Position * > args, ArrayRef< Type > resultTypes, bool isNegated)
Create a predicate that applies a generic constraint.
Predicate getOperationName(StringRef name)
Create a predicate comparing the name of an operation to a known value.
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Predicates::Kind getKind() const
Returns the kind of this qualifier.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool operator<(const Fraction &x, const Fraction &y)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A position describing an attribute of an operation.
A BoolNode denotes a question with a boolean-like result.
BoolNode(Position *position, Qualifier *question, Qualifier *answer, std::unique_ptr< MatcherNode > successNode, std::unique_ptr< MatcherNode > failureNode=nullptr)
A position describing the result of a native constraint.
Apply a parameterized constraint to multiple position values and possibly produce results.
An operation position describes an operation node in the IR.
bool isRoot() const
Returns if this operation position corresponds to the root.
bool isOperandDefiningOp() const
Returns if this operation represents an operand defining op.
A PositionalPredicate is a predicate that is associated with a specific positional value.
The information associated with an edge in the cost graph.
Value connector
The connector value in the intersection of the two subtrees rooted at the source and target root that...
std::pair< unsigned, unsigned > cost
The depth of the connector Value w.r.t.
A SuccessNode denotes that a given high level pattern has successfully been matched.
SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr< MatcherNode > failureNode)
A SwitchNode denotes a question with multiple potential results.
llvm::MapVector< Qualifier *, std::unique_ptr< MatcherNode > > ChildMapT
Returns the children of this switch node.
ChildMapT & getChildren()
SwitchNode(Position *position, Qualifier *question)
A position describing the result type of an entity, i.e.