14 #include "llvm/ADT/MapVector.h"
15 #include "llvm/ADT/SmallPtrSet.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 #include "llvm/Support/Debug.h"
20 #define DEBUG_TYPE "pdl-predicate-tree"
41 return llvm::count_if(values.
getTypes(),
42 [](
Type type) { return !isa<pdl::RangeType>(type); });
49 assert(isa<pdl::AttributeType>(val.
getType()) &&
"expected attribute type");
52 if (
auto attr = dyn_cast<pdl::AttributeOp>(val.
getDefiningOp())) {
54 if (
Value type = attr.getValueType())
56 else if (
Attribute value = attr.getValueAttr())
67 bool isVariadic = isa<pdl::RangeType>(valueType);
71 .Case<pdl::OperandOp, pdl::OperandsOp>([&](
auto op) {
74 if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
75 cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
78 if (
Value type = op.getValueType())
82 .Case<pdl::ResultOp, pdl::ResultsOp>([&](
auto op) {
83 std::optional<unsigned> index = op.getIndex();
91 predList.emplace_back(parentPos, builder.
getIsNotNull());
96 if (std::is_same<pdl::ResultOp, decltype(op)>::value)
97 resultPos = builder.
getResult(parentPos, *index);
100 predList.emplace_back(resultPos, builder.
getEqualTo(pos));
112 std::optional<unsigned> ignoreOperand = std::nullopt) {
113 assert(isa<pdl::OperationType>(val.
getType()) &&
"expected operation");
114 pdl::OperationOp op = cast<pdl::OperationOp>(val.
getDefiningOp());
122 if (std::optional<StringRef> opName = op.getOpName())
129 if (minOperands != operands.size()) {
140 if (minResults == types.size())
146 for (
auto [attrName, attr] :
147 llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
149 predList, attr, builder, inputs,
159 if (operands.size() == 1 && isa<pdl::RangeType>(operands[0].
getType())) {
166 bool foundVariableLength =
false;
168 bool isVariadic = isa<pdl::RangeType>(operandIt.value().getType());
169 foundVariableLength |= isVariadic;
173 if (ignoreOperand == operandIt.index())
179 : builder.
getOperand(opPos, operandIt.index());
184 if (types.size() == 1 && isa<pdl::RangeType>(types[0].
getType())) {
190 bool foundVariableLength =
false;
192 bool isVariadic = isa<pdl::RangeType>(typeValue.getType());
193 foundVariableLength |= isVariadic;
195 auto *resultPos = foundVariableLength
198 predList.emplace_back(resultPos, builder.
getIsNotNull());
210 if (
Attribute type = typeOp.getConstantTypeAttr())
212 }
else if (pdl::TypesOp typeOp = val.
getDefiningOp<pdl::TypesOp>()) {
213 if (
Attribute typeAttr = typeOp.getConstantTypesAttr())
224 auto it = inputs.try_emplace(val, pos);
228 if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
230 auto minMaxPositions =
232 predList.emplace_back(minMaxPositions.second,
242 .Case<OperandPosition, OperandGroupPosition>([&](
auto *pos) {
245 .Default([](
auto *) { llvm_unreachable(
"unexpected position kind"); });
249 std::vector<PositionalPredicate> &predList,
256 assert(value &&
"expected non-tree `pdl.attribute` to contain a value");
261 std::vector<PositionalPredicate> &predList,
266 std::vector<Position *> allPositions;
267 allPositions.reserve(arguments.size());
268 for (
Value arg : arguments)
269 allPositions.push_back(inputs.lookup(arg));
282 auto [it, inserted] = inputs.try_emplace(result, pos);
289 std::tie(second, first) = std::make_pair(first, second);
291 predList.emplace_back(second, builder.
getEqualTo(first));
294 predList.emplace_back(pos, pred);
298 std::vector<PositionalPredicate> &predList,
306 auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
307 resultPos = builder.
getResult(parentPos, op.getIndex());
308 predList.emplace_back(resultPos, builder.
getIsNotNull());
312 std::vector<PositionalPredicate> &predList,
320 auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
321 bool isVariadic = isa<pdl::RangeType>(op.getType());
322 std::optional<unsigned> index = op.getIndex();
325 predList.emplace_back(resultPos, builder.
getIsNotNull());
332 Position *&typePos = inputs[typeValue];
337 "expected non-tree `pdl.type`/`pdl.types` to contain a value");
344 std::vector<PositionalPredicate> &predList,
347 for (
Operation &op : pattern.getBodyRegion().getOps()) {
349 .Case([&](pdl::AttributeOp attrOp) {
352 .Case<pdl::ApplyNativeConstraintOp>([&](
auto constraintOp) {
355 .Case<pdl::ResultOp, pdl::ResultsOp>([&](
auto resultOp) {
358 .Case([&](pdl::TypeOp typeOp) {
360 typeOp, [&] {
return typeOp.getConstantTypeAttr(); }, builder,
363 .Case([&](pdl::TypesOp typeOp) {
365 typeOp, [&] {
return typeOp.getConstantTypesAttr(); }, builder,
376 std::optional<unsigned> index;
391 for (
auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
392 for (
Value operand : operationOp.getOperandValues())
394 .Case<pdl::ResultOp, pdl::ResultsOp>(
395 [&used](
auto resultOp) { used.insert(resultOp.getParent()); });
400 if (
Value root = pattern.getRewriter().getRoot())
405 for (
Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>())
406 if (!used.contains(operationOp))
407 roots.push_back(operationOp);
420 ParentMaps &parentMaps) {
428 Entry(
Value value,
Value parent, std::optional<unsigned> index,
430 : value(value), parent(parent), index(index), depth(depth) {}
434 std::optional<unsigned> index;
446 llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
449 for (
Value root : roots) {
453 std::queue<Entry> toVisit;
454 toVisit.emplace(root,
Value(), 0, 0);
459 while (!toVisit.empty()) {
460 Entry entry = toVisit.front();
463 if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
467 connectorsRootsDepths[entry.value].push_back({root, entry.depth});
473 .Case<pdl::OperationOp>([&](
auto operationOp) {
477 if (operands.size() == 1 &&
478 isa<pdl::RangeType>(operands[0].
getType())) {
479 toVisit.emplace(operands[0], entry.value, std::nullopt,
487 toVisit.emplace(p.value(), entry.value, p.index(),
490 .Case<pdl::ResultOp, pdl::ResultsOp>([&](
auto resultOp) {
491 toVisit.emplace(resultOp.getParent(), entry.value,
492 resultOp.getIndex(), entry.depth);
500 for (
const auto &connectorRootsDepths : connectorsRootsDepths) {
501 Value value = connectorRootsDepths.first;
505 if (rootsDepths.size() == 1)
508 for (
const RootDepth &p : rootsDepths) {
509 for (
const RootDepth &q : rootsDepths) {
516 entry.
cost.second = nextID++;
517 entry.
cost.first = q.depth;
524 assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
525 "the pattern contains a candidate root disconnected from the others");
532 assert(index < operands.size() &&
"operand index out of range");
533 for (
unsigned i = 0; i <= index; ++i)
534 if (isa<pdl::RangeType>(operands[i].
getType()))
540 static void visitUpward(std::vector<PositionalPredicate> &predList,
544 Value value = opIndex.parent;
546 .Case<pdl::OperationOp>([&](
auto operationOp) {
547 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
556 if (!opIndex.index) {
561 Type type = operationOp.getOperandValues()[*opIndex.index].getType();
562 bool variadic = isa<pdl::RangeType>(type);
566 operandPos = builder.
getOperand(opPos, *opIndex.index);
568 predList.emplace_back(operandPos, builder.
getEqualTo(pos));
574 bool inserted = valueToPosition.try_emplace(value, opPos).second;
576 assert(inserted &&
"duplicate upward visit");
585 .Case<pdl::ResultOp>([&](
auto resultOp) {
587 auto *opPos = dyn_cast<OperationPosition>(pos);
588 assert(opPos &&
"operations and results must be interleaved");
589 pos = builder.
getResult(opPos, *opIndex.index);
592 valueToPosition.try_emplace(value, pos);
594 .Case<pdl::ResultsOp>([&](
auto resultOp) {
596 auto *opPos = dyn_cast<OperationPosition>(pos);
597 assert(opPos &&
"operations and results must be interleaved");
598 bool isVariadic = isa<pdl::RangeType>(value.
getType());
605 valueToPosition.try_emplace(value, pos);
613 std::vector<PositionalPredicate> &predList,
619 ParentMaps parentMaps;
622 llvm::dbgs() <<
"Graph:\n";
623 for (
auto &target : graph) {
624 llvm::dbgs() <<
" * " << target.first.getLoc() <<
" " << target.first
626 for (
auto &source : target.second) {
628 llvm::dbgs() <<
" <- " << source.first <<
": " << entry.
cost.first
629 <<
":" << entry.
cost.second <<
" via "
637 Value bestRoot = pattern.getRewriter().getRoot();
640 unsigned bestCost = 0;
641 LLVM_DEBUG(llvm::dbgs() <<
"Candidate roots:\n");
642 for (
Value root : roots) {
644 unsigned cost = solver.
solve();
645 LLVM_DEBUG(llvm::dbgs() <<
" * " << root <<
": " << cost <<
"\n");
646 if (!bestRoot || bestCost > cost) {
660 llvm::dbgs() <<
"Best tree:\n";
661 for (
const std::pair<Value, Value> &edge : bestEdges) {
662 llvm::dbgs() <<
" * " << edge.first;
664 llvm::dbgs() <<
" <- " << edge.second;
665 llvm::dbgs() <<
"\n";
669 LLVM_DEBUG(llvm::dbgs() <<
"Calling key getTreePredicates:\n");
670 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << bestRoot <<
"\n");
681 Value target = it.value().first;
682 Value source = it.value().second;
688 if (valueToPosition.count(target))
692 Value connector = graph[target][source].connector;
693 assert(connector &&
"invalid edge");
694 LLVM_DEBUG(llvm::dbgs() <<
" * Connector: " << connector.
getLoc() <<
"\n");
696 Position *pos = valueToPosition.lookup(connector);
697 assert(pos &&
"connector has not been traversed yet");
700 for (
Value value = connector; value != target;) {
701 OpIndex opIndex = parentMap.lookup(value);
702 assert(opIndex.parent &&
"missing parent");
703 visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
704 value = opIndex.parent;
722 struct OrderedPredicate {
723 OrderedPredicate(
const std::pair<Position *, Qualifier *> &ip)
724 : position(ip.first), question(ip.second) {}
726 : position(ip.position), question(ip.question) {}
737 unsigned primary = 0;
742 unsigned secondary = 0;
755 bool operator<(
const OrderedPredicate &rhs)
const {
762 auto *rhsPos = rhs.position;
763 return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
764 rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
765 std::make_tuple(rhs.primary, rhs.secondary,
773 struct OrderedPredicateDenseInfo {
776 static OrderedPredicate getEmptyKey() {
return Base::getEmptyKey(); }
777 static OrderedPredicate getTombstoneKey() {
return Base::getTombstoneKey(); }
778 static bool isEqual(
const OrderedPredicate &lhs,
779 const OrderedPredicate &rhs) {
780 return lhs.position == rhs.position && lhs.question == rhs.question;
782 static unsigned getHashValue(
const OrderedPredicate &p) {
783 return llvm::hash_combine(p.position, p.question);
789 struct OrderedPredicateList {
790 OrderedPredicateList(pdl::PatternOp pattern,
Value root)
791 : pattern(pattern), root(root) {}
793 pdl::PatternOp pattern;
803 return node->
getPosition() == predicate->position &&
810 OrderedPredicate *predicate,
811 pdl::PatternOp pattern) {
813 "expected matcher to equal the given predicate");
815 auto it = predicate->patternToAnswer.find(pattern);
816 assert(it != predicate->patternToAnswer.end() &&
817 "expected pattern to exist in predicate");
826 OrderedPredicateList &list,
827 std::vector<OrderedPredicate *>::iterator current,
828 std::vector<OrderedPredicate *>::iterator end) {
829 if (current == end) {
832 std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
835 }
else if (!list.predicates.contains(*current)) {
842 node = std::make_unique<SwitchNode>((*current)->position,
843 (*current)->question);
846 list, std::next(current), end);
853 list, std::next(current), end);
868 if (
SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
870 for (
auto &it : children)
875 if (children.size() == 1) {
876 auto *childIt = children.begin();
877 node = std::make_unique<BoolNode>(
878 node->getPosition(), node->getQuestion(), childIt->first,
879 std::move(childIt->second), std::move(node->getFailureNode()));
881 }
else if (
BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
891 root = &(*root)->getFailureNode();
892 *root = std::make_unique<ExitNode>();
896 template <
typename Iterator,
typename Compare>
898 while (begin != end) {
903 for (
auto i = begin; i != end; ++i) {
904 if (std::none_of(begin, end, [&](
auto const &b) {
return cmp(b, *i); }))
905 sortBeforeOthers.insert(*i);
908 auto const next = std::stable_partition(begin, end, [&](
auto const &a) {
909 return sortBeforeOthers.contains(a);
911 assert(next != begin &&
"not a partial ordering");
917 static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
918 auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
922 auto positionDependsOnA = [&](
Position *p) {
923 auto *cp = dyn_cast<ConstraintPosition>(p);
924 return cp && cp->getQuestion() == cqa;
927 if (
auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) {
929 return llvm::any_of(cqb->getArgs(), positionDependsOnA);
931 if (
auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) {
932 return positionDependsOnA(b->position) ||
933 positionDependsOnA(equalTo->getValue());
935 return positionDependsOnA(b->position);
940 std::unique_ptr<MatcherNode>
945 struct PatternPredicates {
946 PatternPredicates(pdl::PatternOp pattern,
Value root,
947 std::vector<PositionalPredicate> predicates)
948 : pattern(pattern), root(root), predicates(std::move(predicates)) {}
951 pdl::PatternOp pattern;
957 std::vector<PositionalPredicate> predicates;
961 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
962 std::vector<PositionalPredicate> predicateList;
965 patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
970 for (
auto &patternAndPredList : patternsAndPredicates) {
971 for (
auto &predicate : patternAndPredList.predicates) {
972 auto it = uniqued.insert(predicate);
973 it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
977 it.first->id = uniqued.size() - 1;
982 std::vector<OrderedPredicateList> lists;
983 lists.reserve(patternsAndPredicates.size());
984 for (
auto &patternAndPredList : patternsAndPredicates) {
985 OrderedPredicateList list(patternAndPredList.pattern,
986 patternAndPredList.root);
987 for (
auto &predicate : patternAndPredList.predicates) {
988 OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
989 list.predicates.insert(orderedPredicate);
992 ++orderedPredicate->primary;
994 lists.push_back(std::move(list));
1000 for (
auto &list : lists) {
1002 for (
auto *predicate : list.predicates)
1003 total += predicate->primary * predicate->primary;
1004 for (
auto *predicate : list.predicates)
1005 predicate->secondary += total;
1010 std::vector<OrderedPredicate *> ordered;
1011 ordered.reserve(uniqued.size());
1012 for (
auto &ip : uniqued)
1013 ordered.push_back(&ip);
1014 llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
1023 std::unique_ptr<MatcherNode> root;
1024 for (OrderedPredicateList &list : lists)
1038 std::unique_ptr<MatcherNode> failureNode)
1039 : position(p), question(q), failureNode(std::move(failureNode)),
1040 matcherTypeID(matcherTypeID) {}
1047 std::unique_ptr<MatcherNode> successNode,
1048 std::unique_ptr<MatcherNode> failureNode)
1050 std::move(failureNode)),
1051 answer(answer), successNode(std::move(successNode)) {}
1058 std::unique_ptr<MatcherNode> failureNode)
1060 nullptr, std::move(failureNode)),
1061 pattern(pattern), root(root) {}
static Value buildPredicateList(pdl::PatternOp pattern, PredicateBuilder &builder, std::vector< PositionalPredicate > &predList, DenseMap< Value, Position * > &valueToPosition)
Given a pattern operation, build the set of matcher predicates necessary to match this pattern.
static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b)
Returns true if 'b' depends on a result of 'a'.
static void getTypePredicates(Value typeValue, function_ref< Attribute()> typeAttrFn, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getNonTreePredicates(pdl::PatternOp pattern, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
Collect all of the predicates that cannot be determined via walking the tree.
static bool useOperandGroup(pdl::OperationOp op, unsigned index)
Returns true if the operand at the given index needs to be queried using an operand group, i.e., if it is variadic itself or follows a variadic operand.
static void getTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)
Collect the tree predicates anchored at the given value.
static void getResultPredicates(pdl::ResultOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getOperandTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)
Collect all of the predicates for the given operand position.
std::unique_ptr< MatcherNode > & getOrCreateChild(SwitchNode *node, OrderedPredicate *predicate, pdl::PatternOp pattern)
Get or insert a child matcher for the given parent switch node, given a predicate and parent pattern.
static bool comparePosDepth(Position *lhs, Position *rhs)
Compares the depths of two positions.
static void visitUpward(std::vector< PositionalPredicate > &predList, OpIndex opIndex, PredicateBuilder &builder, DenseMap< Value, Position * > &valueToPosition, Position *&pos, unsigned rootID)
Visit a node during upward traversal.
static unsigned getNumNonRangeValues(ValueRange values)
Returns the number of non-range elements within values.
static SmallVector< Value > detectRoots(pdl::PatternOp pattern)
Given a pattern, determines the set of roots present in this pattern.
static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp)
Sorts the range begin/end with the partial order given by cmp.
static void insertExitNode(std::unique_ptr< MatcherNode > *root)
Insert an exit node at the end of the failure path of the root.
static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate)
Returns true if the given matcher refers to the same predicate as the given ordered predicate.
static void foldSwitchToBool(std::unique_ptr< MatcherNode > &node)
Fold any switch nodes nested under node to boolean nodes when possible.
static void buildCostGraph(ArrayRef< Value > roots, RootOrderingGraph &graph, ParentMaps &parentMaps)
Given a list of candidate roots, builds the cost graph for connecting them.
static void getAttributePredicates(pdl::AttributeOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void propagatePattern(std::unique_ptr< MatcherNode > &node, OrderedPredicateList &list, std::vector< OrderedPredicate * >::iterator current, std::vector< OrderedPredicate * >::iterator end)
Build the matcher CFG by "pushing" patterns through by sorted predicate order.
Attributes are known-constant values of operations.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
This class implements the result iterators for the Operation class.
type_range getTypes() const
This class provides an efficient unique identifier for a specific C++ type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This class represents the base of a predicate matcher node.
Position * getPosition() const
Returns the position on which the question predicate should be checked.
Qualifier * getQuestion() const
Returns the predicate checked on this node.
The optimal branching algorithm solver.
unsigned solve()
Runs the Edmonds' algorithm for the current graph, returning the total cost of the minimum-weight spanning arborescence.
std::vector< std::pair< Value, Value > > EdgeList
A list of edges (child, parent).
EdgeList preOrderTraversal(ArrayRef< Value > nodes) const
Returns the computed edges as visited in the preorder traversal.
A position describes a value on the input IR on which a predicate may be applied, such as an operation or attribute.
unsigned getOperationDepth() const
Returns the depth of the first ancestor operation position.
Predicates::Kind getKind() const
Returns the kind of this position.
const KeyTy & getValue() const
Return the key value of this predicate.
This class provides utilities for constructing predicates.
ConstraintPosition * getConstraintPosition(ConstraintQuestion *q, unsigned index)
Position * getTypeLiteral(Attribute attr)
Returns a type position for the given type value.
Predicate getOperandCount(unsigned count)
Create a predicate comparing the number of operands of an operation to a known value.
OperationPosition * getPassthroughOp(Position *p)
Returns the operation position equivalent to the given position.
Predicate getIsNotNull()
Create a predicate comparing a value with null.
Predicate getOperandCountAtLeast(unsigned count)
Predicate getResultCountAtLeast(unsigned count)
Position * getType(Position *p)
Returns a type position for the given entity.
Position * getAttribute(OperationPosition *p, StringRef name)
Returns an attribute position for an attribute of the given operation.
Position * getOperandGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of operands of the given operation.
Position * getForEach(Position *p, unsigned id)
Position * getOperand(OperationPosition *p, unsigned operand)
Returns an operand position for an operand of the given operation.
Position * getResult(OperationPosition *p, unsigned result)
Returns a result position for a result of the given operation.
Position * getRoot()
Returns the root operation position.
Predicate getAttributeConstraint(Attribute attr)
Create a predicate comparing an attribute to a known value.
Position * getResultGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of results of the given operation.
Position * getAllResults(OperationPosition *p)
UsersPosition * getUsers(Position *p, bool useRepresentative)
Returns the users of a position using the value at the given operand.
Predicate getTypeConstraint(Attribute type)
Create a predicate comparing the type of an attribute or value to a known type.
OperationPosition * getOperandDefiningOp(Position *p)
Returns the parent position defining the value held by the given operand.
Predicate getResultCount(unsigned count)
Create a predicate comparing the number of results of an operation to a known value.
std::pair< Qualifier *, Qualifier * > Predicate
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to ordinal values).
Predicate getEqualTo(Position *pos)
Create a predicate checking if two values are equal.
Position * getAllOperands(OperationPosition *p)
Position * getAttributeLiteral(Attribute attr)
Returns an attribute position for the given attribute.
Predicate getConstraint(StringRef name, ArrayRef< Position * > args, ArrayRef< Type > resultTypes, bool isNegated)
Create a predicate that applies a generic constraint.
Predicate getOperationName(StringRef name)
Create a predicate comparing the name of an operation to a known value.
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to ordinal values).
Predicates::Kind getKind() const
Returns the kind of this qualifier.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool operator<(const Fraction &x, const Fraction &y)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
A position describing an attribute of an operation.
A BoolNode denotes a question with a boolean-like result.
BoolNode(Position *position, Qualifier *question, Qualifier *answer, std::unique_ptr< MatcherNode > successNode, std::unique_ptr< MatcherNode > failureNode=nullptr)
A position describing the result of a native constraint.
Apply a parameterized constraint to multiple position values and possibly produce results.
An operation position describes an operation node in the IR.
bool isRoot() const
Returns if this operation position corresponds to the root.
bool isOperandDefiningOp() const
Returns if this operation represents an operand defining op.
A PositionalPredicate is a predicate that is associated with a specific positional value.
The information associated with an edge in the cost graph.
Value connector
The connector value in the intersection of the two subtrees rooted at the source and target root that results in the minimal depth w.r.t. the target root.
std::pair< unsigned, unsigned > cost
The depth of the connector Value w.r.t. the target root.
A SuccessNode denotes that a given high level pattern has successfully been matched.
SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr< MatcherNode > failureNode)
A SwitchNode denotes a question with multiple potential results.
llvm::MapVector< Qualifier *, std::unique_ptr< MatcherNode > > ChildMapT
Returns the children of this switch node.
ChildMapT & getChildren()
SwitchNode(Position *position, Qualifier *question)
A position describing the result type of an entity, i.e. an Attribute, Operand, Result, etc.