MLIR 22.0.0git
PredicateTree.cpp
Go to the documentation of this file.
1//===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PredicateTree.h"
10#include "RootOrdering.h"
11
13#include "mlir/IR/BuiltinOps.h"
14#include "llvm/ADT/MapVector.h"
15#include "llvm/ADT/SmallPtrSet.h"
16#include "llvm/ADT/TypeSwitch.h"
17#include "llvm/Support/Debug.h"
18#include "llvm/Support/DebugLog.h"
19#include <queue>
20
21#define DEBUG_TYPE "pdl-predicate-tree"
22
23using namespace mlir;
24using namespace mlir::pdl_to_pdl_interp;
25
26//===----------------------------------------------------------------------===//
27// Predicate List Building
28//===----------------------------------------------------------------------===//
29
30static void getTreePredicates(std::vector<PositionalPredicate> &predList,
31 Value val, PredicateBuilder &builder,
33 Position *pos);
34
35/// Compares the depths of two positions.
37 return lhs->getOperationDepth() < rhs->getOperationDepth();
38}
39
40/// Returns the number of non-range elements within `values`.
41static unsigned getNumNonRangeValues(ValueRange values) {
42 return llvm::count_if(values.getTypes(),
43 [](Type type) { return !isa<pdl::RangeType>(type); });
44}
45
46static void getTreePredicates(std::vector<PositionalPredicate> &predList,
47 Value val, PredicateBuilder &builder,
49 AttributePosition *pos) {
50 assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
51 predList.emplace_back(pos, builder.getIsNotNull());
52
53 if (auto attr = val.getDefiningOp<pdl::AttributeOp>()) {
54 // If the attribute has a type or value, add a constraint.
55 if (Value type = attr.getValueType())
56 getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
57 else if (Attribute value = attr.getValueAttr())
58 predList.emplace_back(pos, builder.getAttributeConstraint(value));
59 }
60}
61
62/// Collect all of the predicates for the given operand position.
63static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
64 Value val, PredicateBuilder &builder,
66 Position *pos) {
67 Type valueType = val.getType();
68 bool isVariadic = isa<pdl::RangeType>(valueType);
69
70 // If this is a typed operand, add a type constraint.
72 .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
73 // Prevent traversal into a null value if the operand has a proper
74 // index.
75 if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
76 cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
77 predList.emplace_back(pos, builder.getIsNotNull());
78
79 if (Value type = op.getValueType())
80 getTreePredicates(predList, type, builder, inputs,
81 builder.getType(pos));
82 })
83 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
84 std::optional<unsigned> index = op.getIndex();
85
86 // Prevent traversal into a null value if the result has a proper index.
87 if (index)
88 predList.emplace_back(pos, builder.getIsNotNull());
89
90 // Get the parent operation of this operand.
91 OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
92 predList.emplace_back(parentPos, builder.getIsNotNull());
93
94 // Ensure that the operands match the corresponding results of the
95 // parent operation.
96 Position *resultPos = nullptr;
97 if (std::is_same<pdl::ResultOp, decltype(op)>::value)
98 resultPos = builder.getResult(parentPos, *index);
99 else
100 resultPos = builder.getResultGroup(parentPos, index, isVariadic);
101 predList.emplace_back(resultPos, builder.getEqualTo(pos));
102
103 // Collect the predicates of the parent operation.
104 getTreePredicates(predList, op.getParent(), builder, inputs,
105 (Position *)parentPos);
106 });
107}
108
109static void
110getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
111 PredicateBuilder &builder,
113 std::optional<unsigned> ignoreOperand = std::nullopt) {
114 assert(isa<pdl::OperationType>(val.getType()) && "expected operation");
115 pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
116 OperationPosition *opPos = cast<OperationPosition>(pos);
117
118 // Ensure getDefiningOp returns a non-null operation.
119 if (!opPos->isRoot())
120 predList.emplace_back(pos, builder.getIsNotNull());
121
122 // Check that this is the correct root operation.
123 if (std::optional<StringRef> opName = op.getOpName())
124 predList.emplace_back(pos, builder.getOperationName(*opName));
125
126 // Check that the operation has the proper number of operands. If there are
127 // any variable length operands, we check a minimum instead of an exact count.
128 OperandRange operands = op.getOperandValues();
129 unsigned minOperands = getNumNonRangeValues(operands);
130 if (minOperands != operands.size()) {
131 if (minOperands)
132 predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
133 } else {
134 predList.emplace_back(pos, builder.getOperandCount(minOperands));
135 }
136
137 // Check that the operation has the proper number of results. If there are
138 // any variable length results, we check a minimum instead of an exact count.
139 OperandRange types = op.getTypeValues();
140 unsigned minResults = getNumNonRangeValues(types);
141 if (minResults == types.size())
142 predList.emplace_back(pos, builder.getResultCount(types.size()));
143 else if (minResults)
144 predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
145
146 // Recurse into any attributes, operands, or results.
147 for (auto [attrName, attr] :
148 llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
150 predList, attr, builder, inputs,
151 builder.getAttribute(opPos, cast<StringAttr>(attrName).getValue()));
152 }
153
154 // Process the operands and results of the operation. For all values up to
155 // the first variable length value, we use the concrete operand/result
156 // number. After that, we use the "group" given that we can't know the
157 // concrete indices until runtime. If there is only one variadic operand
158 // group, we treat it as all of the operands/results of the operation.
159 /// Operands.
160 if (operands.size() == 1 && isa<pdl::RangeType>(operands[0].getType())) {
161 // Ignore the operands if we are performing an upward traversal (in that
162 // case, they have already been visited).
163 if (opPos->isRoot() || opPos->isOperandDefiningOp())
164 getTreePredicates(predList, operands.front(), builder, inputs,
165 builder.getAllOperands(opPos));
166 } else {
167 bool foundVariableLength = false;
168 for (const auto &operandIt : llvm::enumerate(operands)) {
169 bool isVariadic = isa<pdl::RangeType>(operandIt.value().getType());
170 foundVariableLength |= isVariadic;
171
172 // Ignore the specified operand, usually because this position was
173 // visited in an upward traversal via an iterative choice.
174 if (ignoreOperand == operandIt.index())
175 continue;
176
177 Position *pos =
178 foundVariableLength
179 ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
180 : builder.getOperand(opPos, operandIt.index());
181 getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
182 }
183 }
184 /// Results.
185 if (types.size() == 1 && isa<pdl::RangeType>(types[0].getType())) {
186 getTreePredicates(predList, types.front(), builder, inputs,
187 builder.getType(builder.getAllResults(opPos)));
188 return;
189 }
190
191 bool foundVariableLength = false;
192 for (auto [idx, typeValue] : llvm::enumerate(types)) {
193 bool isVariadic = isa<pdl::RangeType>(typeValue.getType());
194 foundVariableLength |= isVariadic;
195
196 auto *resultPos = foundVariableLength
197 ? builder.getResultGroup(pos, idx, isVariadic)
198 : builder.getResult(pos, idx);
199 predList.emplace_back(resultPos, builder.getIsNotNull());
200 getTreePredicates(predList, typeValue, builder, inputs,
201 builder.getType(resultPos));
202 }
203}
204
205static void getTreePredicates(std::vector<PositionalPredicate> &predList,
206 Value val, PredicateBuilder &builder,
208 TypePosition *pos) {
209 // Check for a constraint on a constant type.
210 if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
211 if (Attribute type = typeOp.getConstantTypeAttr())
212 predList.emplace_back(pos, builder.getTypeConstraint(type));
213 } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
214 if (Attribute typeAttr = typeOp.getConstantTypesAttr())
215 predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
216 }
217}
218
219/// Collect the tree predicates anchored at the given value.
220static void getTreePredicates(std::vector<PositionalPredicate> &predList,
221 Value val, PredicateBuilder &builder,
223 Position *pos) {
224 // Make sure this input value is accessible to the rewrite.
225 auto it = inputs.try_emplace(val, pos);
226 if (!it.second) {
227 // If this is an input value that has been visited in the tree, add a
228 // constraint to ensure that both instances refer to the same value.
229 if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
230 pdl::TypeOp>(val.getDefiningOp())) {
231 auto minMaxPositions =
232 std::minmax(pos, it.first->second, comparePosDepth);
233 predList.emplace_back(minMaxPositions.second,
234 builder.getEqualTo(minMaxPositions.first));
235 }
236 return;
237 }
238
240 .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
241 getTreePredicates(predList, val, builder, inputs, pos);
242 })
243 .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
244 getOperandTreePredicates(predList, val, builder, inputs, pos);
245 })
246 .DefaultUnreachable("unexpected position kind");
247}
248
249static void getAttributePredicates(pdl::AttributeOp op,
250 std::vector<PositionalPredicate> &predList,
251 PredicateBuilder &builder,
253 Position *&attrPos = inputs[op];
254 if (attrPos)
255 return;
256 Attribute value = op.getValueAttr();
257 assert(value && "expected non-tree `pdl.attribute` to contain a value");
258 attrPos = builder.getAttributeLiteral(value);
259}
260
261static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
262 std::vector<PositionalPredicate> &predList,
263 PredicateBuilder &builder,
265 OperandRange arguments = op.getArgs();
266
267 std::vector<Position *> allPositions;
268 allPositions.reserve(arguments.size());
269 for (Value arg : arguments)
270 allPositions.push_back(inputs.lookup(arg));
271
272 // Push the constraint to the furthest position.
273 Position *pos = *llvm::max_element(allPositions, comparePosDepth);
274 ResultRange results = op.getResults();
276 op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
277 op.getIsNegated());
278
279 // For each result register a position so it can be used later
280 for (auto [i, result] : llvm::enumerate(results)) {
281 ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
282 ConstraintPosition *pos = builder.getConstraintPosition(q, i);
283 auto [it, inserted] = inputs.try_emplace(result, pos);
284 // If this is an input value that has been visited in the tree, add a
285 // constraint to ensure that both instances refer to the same value.
286 if (!inserted) {
287 Position *first = pos;
288 Position *second = it->second;
289 if (comparePosDepth(second, first))
290 std::tie(second, first) = std::make_pair(first, second);
291
292 predList.emplace_back(second, builder.getEqualTo(first));
293 }
294 }
295 predList.emplace_back(pos, pred);
296}
297
298static void getResultPredicates(pdl::ResultOp op,
299 std::vector<PositionalPredicate> &predList,
300 PredicateBuilder &builder,
302 Position *&resultPos = inputs[op];
303 if (resultPos)
304 return;
305
306 // Ensure that the result isn't null.
307 auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
308 resultPos = builder.getResult(parentPos, op.getIndex());
309 predList.emplace_back(resultPos, builder.getIsNotNull());
310}
311
312static void getResultPredicates(pdl::ResultsOp op,
313 std::vector<PositionalPredicate> &predList,
314 PredicateBuilder &builder,
316 Position *&resultPos = inputs[op];
317 if (resultPos)
318 return;
319
320 // Ensure that the result isn't null if the result has an index.
321 auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
322 bool isVariadic = isa<pdl::RangeType>(op.getType());
323 std::optional<unsigned> index = op.getIndex();
324 resultPos = builder.getResultGroup(parentPos, index, isVariadic);
325 if (index)
326 predList.emplace_back(resultPos, builder.getIsNotNull());
327}
328
329static void getTypePredicates(Value typeValue,
330 function_ref<Attribute()> typeAttrFn,
331 PredicateBuilder &builder,
333 Position *&typePos = inputs[typeValue];
334 if (typePos)
335 return;
336 Attribute typeAttr = typeAttrFn();
337 assert(typeAttr &&
338 "expected non-tree `pdl.type`/`pdl.types` to contain a value");
339 typePos = builder.getTypeLiteral(typeAttr);
340}
341
342/// Collect all of the predicates that cannot be determined via walking the
343/// tree.
344static void getNonTreePredicates(pdl::PatternOp pattern,
345 std::vector<PositionalPredicate> &predList,
346 PredicateBuilder &builder,
348 for (Operation &op : pattern.getBodyRegion().getOps()) {
350 .Case([&](pdl::AttributeOp attrOp) {
351 getAttributePredicates(attrOp, predList, builder, inputs);
352 })
353 .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
354 getConstraintPredicates(constraintOp, predList, builder, inputs);
355 })
356 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
357 getResultPredicates(resultOp, predList, builder, inputs);
358 })
359 .Case([&](pdl::TypeOp typeOp) {
361 typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder,
362 inputs);
363 })
364 .Case([&](pdl::TypesOp typeOp) {
366 typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder,
367 inputs);
368 });
369 }
370}
371
372namespace {
373
374/// An op accepting a value at an optional index.
375struct OpIndex {
376 Value parent;
377 std::optional<unsigned> index;
378};
379
380/// The parent and operand index of each operation for each root, stored
381/// as a nested map [root][operation].
383
384} // namespace
385
386/// Given a pattern, determines the set of roots present in this pattern.
387/// These are the operations whose results are not consumed by other operations.
388static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
389 // First, collect all the operations that are used as operands
390 // to other operations. These are not roots by default.
391 DenseSet<Value> used;
392 for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
393 for (Value operand : operationOp.getOperandValues())
394 TypeSwitch<Operation *>(operand.getDefiningOp())
395 .Case<pdl::ResultOp, pdl::ResultsOp>(
396 [&used](auto resultOp) { used.insert(resultOp.getParent()); });
397 }
398
399 // Remove the specified root from the use set, so that we can
400 // always select it as a root, even if it is used by other operations.
401 if (Value root = pattern.getRewriter().getRoot())
402 used.erase(root);
403
404 // Finally, collect all the unused operations.
405 SmallVector<Value> roots;
406 for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>())
407 if (!used.contains(operationOp))
408 roots.push_back(operationOp);
409
410 return roots;
411}
412
413/// Given a list of candidate roots, builds the cost graph for connecting them.
414/// The graph is formed by traversing the DAG of operations starting from each
415/// root and marking the depth of each connector value (operand). Then we join
416/// the candidate roots based on the common connector values, taking the one
417/// with the minimum depth. Along the way, we compute, for each candidate root,
418/// a mapping from each operation (in the DAG underneath this root) to its
419/// parent operation and the corresponding operand index.
421 ParentMaps &parentMaps) {
422
423 // The entry of a queue. The entry consists of the following items:
424 // * the value in the DAG underneath the root;
425 // * the parent of the value;
426 // * the operand index of the value in its parent;
427 // * the depth of the visited value.
428 struct Entry {
429 Entry(Value value, Value parent, std::optional<unsigned> index,
430 unsigned depth)
431 : value(value), parent(parent), index(index), depth(depth) {}
432
433 Value value;
434 Value parent;
435 std::optional<unsigned> index;
436 unsigned depth;
437 };
438
439 // A root of a value and its depth (distance from root to the value).
440 struct RootDepth {
441 Value root;
442 unsigned depth = 0;
443 };
444
445 // Map from candidate connector values to their roots and depths. Using a
446 // small vector with 1 entry because most values belong to a single root.
447 llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
448
449 // Perform a breadth-first traversal of the op DAG rooted at each root.
450 for (Value root : roots) {
451 // The queue of visited values. A value may be present multiple times in
452 // the queue, for multiple parents. We only accept the first occurrence,
453 // which is guaranteed to have the lowest depth.
454 std::queue<Entry> toVisit;
455 toVisit.emplace(root, Value(), 0, 0);
456
457 // The map from value to its parent for the current root.
458 DenseMap<Value, OpIndex> &parentMap = parentMaps[root];
459
460 while (!toVisit.empty()) {
461 Entry entry = toVisit.front();
462 toVisit.pop();
463 // Skip if already visited.
464 if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
465 continue;
466
467 // Mark the root and depth of the value.
468 connectorsRootsDepths[entry.value].push_back({root, entry.depth});
469
470 // Traverse the operands of an operation and result ops.
471 // We intentionally do not traverse attributes and types, because those
472 // are expensive to join on.
473 TypeSwitch<Operation *>(entry.value.getDefiningOp())
474 .Case<pdl::OperationOp>([&](auto operationOp) {
475 OperandRange operands = operationOp.getOperandValues();
476 // Special case when we pass all the operands in one range.
477 // For those, the index is empty.
478 if (operands.size() == 1 &&
479 isa<pdl::RangeType>(operands[0].getType())) {
480 toVisit.emplace(operands[0], entry.value, std::nullopt,
481 entry.depth + 1);
482 return;
483 }
484
485 // Default case: visit all the operands.
486 for (const auto &p :
487 llvm::enumerate(operationOp.getOperandValues()))
488 toVisit.emplace(p.value(), entry.value, p.index(),
489 entry.depth + 1);
490 })
491 .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
492 toVisit.emplace(resultOp.getParent(), entry.value,
493 resultOp.getIndex(), entry.depth);
494 });
495 }
496 }
497
498 // Now build the cost graph.
499 // This is simply a minimum over all depths for the target root.
500 unsigned nextID = 0;
501 for (const auto &connectorRootsDepths : connectorsRootsDepths) {
502 Value value = connectorRootsDepths.first;
503 ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second;
504 // If there is only one root for this value, this will not trigger
505 // any edges in the cost graph (a perf optimization).
506 if (rootsDepths.size() == 1)
507 continue;
508
509 for (const RootDepth &p : rootsDepths) {
510 for (const RootDepth &q : rootsDepths) {
511 if (&p == &q)
512 continue;
513 // Insert or retrieve the property of edge from p to q.
514 RootOrderingEntry &entry = graph[q.root][p.root];
515 if (!entry.connector /* new edge */ || entry.cost.first > q.depth) {
516 if (!entry.connector)
517 entry.cost.second = nextID++;
518 entry.cost.first = q.depth;
519 entry.connector = value;
520 }
521 }
522 }
523 }
524
525 assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
526 "the pattern contains a candidate root disconnected from the others");
527}
528
529/// Returns true if the operand at the given index needs to be queried using an
530/// operand group, i.e., if it is variadic itself or follows a variadic operand.
531static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
532 OperandRange operands = op.getOperandValues();
533 assert(index < operands.size() && "operand index out of range");
534 for (unsigned i = 0; i <= index; ++i)
535 if (isa<pdl::RangeType>(operands[i].getType()))
536 return true;
537 return false;
538}
539
540/// Visit a node during upward traversal.
541static void visitUpward(std::vector<PositionalPredicate> &predList,
542 OpIndex opIndex, PredicateBuilder &builder,
543 DenseMap<Value, Position *> &valueToPosition,
544 Position *&pos, unsigned rootID) {
545 Value value = opIndex.parent;
547 .Case<pdl::OperationOp>([&](auto operationOp) {
548 LDBG() << " * Value: " << value;
549
550 // Get users and iterate over them.
551 Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true);
552 Position *foreachPos = builder.getForEach(usersPos, rootID);
553 OperationPosition *opPos = builder.getPassthroughOp(foreachPos);
554
555 // Compare the operand(s) of the user against the input value(s).
556 Position *operandPos;
557 if (!opIndex.index) {
558 // We are querying all the operands of the operation.
559 operandPos = builder.getAllOperands(opPos);
560 } else if (useOperandGroup(operationOp, *opIndex.index)) {
561 // We are querying an operand group.
562 Type type = operationOp.getOperandValues()[*opIndex.index].getType();
563 bool variadic = isa<pdl::RangeType>(type);
564 operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
565 } else {
566 // We are querying an individual operand.
567 operandPos = builder.getOperand(opPos, *opIndex.index);
568 }
569 predList.emplace_back(operandPos, builder.getEqualTo(pos));
570
571 // Guard against duplicate upward visits. These are not possible,
572 // because if this value was already visited, it would have been
573 // cheaper to start the traversal at this value rather than at the
574 // `connector`, violating the optimality of our spanning tree.
575 bool inserted = valueToPosition.try_emplace(value, opPos).second;
576 (void)inserted;
577 assert(inserted && "duplicate upward visit");
578
579 // Obtain the tree predicates at the current value.
580 getTreePredicates(predList, value, builder, valueToPosition, opPos,
581 opIndex.index);
582
583 // Update the position
584 pos = opPos;
585 })
586 .Case<pdl::ResultOp>([&](auto resultOp) {
587 // Traverse up an individual result.
588 auto *opPos = dyn_cast<OperationPosition>(pos);
589 assert(opPos && "operations and results must be interleaved");
590 pos = builder.getResult(opPos, *opIndex.index);
591
592 // Insert the result position in case we have not visited it yet.
593 valueToPosition.try_emplace(value, pos);
594 })
595 .Case<pdl::ResultsOp>([&](auto resultOp) {
596 // Traverse up a group of results.
597 auto *opPos = dyn_cast<OperationPosition>(pos);
598 assert(opPos && "operations and results must be interleaved");
599 bool isVariadic = isa<pdl::RangeType>(value.getType());
600 if (opIndex.index)
601 pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
602 else
603 pos = builder.getAllResults(opPos);
604
605 // Insert the result position in case we have not visited it yet.
606 valueToPosition.try_emplace(value, pos);
607 });
608}
609
610/// Given a pattern operation, build the set of matcher predicates necessary to
611/// match this pattern.
612static Value buildPredicateList(pdl::PatternOp pattern,
613 PredicateBuilder &builder,
614 std::vector<PositionalPredicate> &predList,
615 DenseMap<Value, Position *> &valueToPosition) {
616 SmallVector<Value> roots = detectRoots(pattern);
617
618 // Build the root ordering graph and compute the parent maps.
619 RootOrderingGraph graph;
620 ParentMaps parentMaps;
621 buildCostGraph(roots, graph, parentMaps);
622 LDBG() << "Graph:";
623 for (auto &target : graph) {
624 LDBG() << " * " << target.first.getLoc() << " " << target.first;
625 for (auto &source : target.second) {
626 RootOrderingEntry &entry = source.second;
627 LDBG() << " <- " << source.first << ": " << entry.cost.first << ":"
628 << entry.cost.second << " via " << entry.connector.getLoc();
629 }
630 }
631
632 // Solve the optimal branching problem for each candidate root, or use the
633 // provided one.
634 Value bestRoot = pattern.getRewriter().getRoot();
636 if (!bestRoot) {
637 unsigned bestCost = 0;
638 LDBG() << "Candidate roots:";
639 for (Value root : roots) {
640 OptimalBranching solver(graph, root);
641 unsigned cost = solver.solve();
642 LDBG() << " * " << root << ": " << cost;
643 if (!bestRoot || bestCost > cost) {
644 bestCost = cost;
645 bestRoot = root;
646 bestEdges = solver.preOrderTraversal(roots);
647 }
648 }
649 } else {
650 OptimalBranching solver(graph, bestRoot);
651 solver.solve();
652 bestEdges = solver.preOrderTraversal(roots);
653 }
654
655 // Print the best solution.
656 LDBG() << "Best tree:";
657 for (const std::pair<Value, Value> &edge : bestEdges) {
658 if (edge.second)
659 LDBG() << " * " << edge.first << " <- " << edge.second;
660 else
661 LDBG() << " * " << edge.first;
662 }
663
664 LDBG() << "Calling key getTreePredicates (Value: " << bestRoot << ")";
665
666 // The best root is the starting point for the traversal. Get the tree
667 // predicates for the DAG rooted at bestRoot.
668 getTreePredicates(predList, bestRoot, builder, valueToPosition,
669 builder.getRoot());
670
671 // Traverse the selected optimal branching. For all edges in order, traverse
672 // up starting from the connector, until the candidate root is reached, and
673 // call getTreePredicates at every node along the way.
674 for (const auto &it : llvm::enumerate(bestEdges)) {
675 Value target = it.value().first;
676 Value source = it.value().second;
677
678 // Check if we already visited the target root. This happens in two cases:
679 // 1) the initial root (bestRoot);
680 // 2) a root that is dominated by (contained in the subtree rooted at) an
681 // already visited root.
682 if (valueToPosition.count(target))
683 continue;
684
685 // Determine the connector.
686 Value connector = graph[target][source].connector;
687 assert(connector && "invalid edge");
688 LDBG() << " * Connector: " << connector.getLoc();
689 DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
690 Position *pos = valueToPosition.lookup(connector);
691 assert(pos && "connector has not been traversed yet");
692
693 // Traverse from the connector upwards towards the target root.
694 for (Value value = connector; value != target;) {
695 OpIndex opIndex = parentMap.lookup(value);
696 assert(opIndex.parent && "missing parent");
697 visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
698 value = opIndex.parent;
699 }
700 }
701
702 getNonTreePredicates(pattern, predList, builder, valueToPosition);
703
704 return bestRoot;
705}
706
707//===----------------------------------------------------------------------===//
708// Pattern Predicate Tree Merging
709//===----------------------------------------------------------------------===//
710
711namespace {
712
713/// This class represents a specific predicate applied to a position, and
714/// provides hashing and ordering operators. This class allows for computing a
715/// frequence sum and ordering predicates based on a cost model.
716struct OrderedPredicate {
717 OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
718 : position(ip.first), question(ip.second) {}
719 OrderedPredicate(const PositionalPredicate &ip)
720 : position(ip.position), question(ip.question) {}
721
722 /// The position this predicate is applied to.
723 Position *position;
724
725 /// The question that is applied by this predicate onto the position.
726 Qualifier *question;
727
728 /// The first and second order benefit sums.
729 /// The primary sum is the number of occurrences of this predicate among all
730 /// of the patterns.
731 unsigned primary = 0;
732 /// The secondary sum is a squared summation of the primary sum of all of the
733 /// predicates within each pattern that contains this predicate. This allows
734 /// for favoring predicates that are more commonly shared within a pattern, as
735 /// opposed to those shared across patterns.
736 unsigned secondary = 0;
737
738 /// The tie breaking ID, used to preserve a deterministic (insertion) order
739 /// among all the predicates with the same priority, depth, and position /
740 /// predicate dependency.
741 unsigned id = 0;
742
743 /// A map between a pattern operation and the answer to the predicate question
744 /// within that pattern.
746
747 /// Returns true if this predicate is ordered before `rhs`, based on the cost
748 /// model.
749 bool operator<(const OrderedPredicate &rhs) const {
750 // Sort by:
751 // * higher first and secondary order sums
752 // * lower depth
753 // * lower position dependency
754 // * lower predicate dependency
755 // * lower tie breaking ID
756 auto *rhsPos = rhs.position;
757 return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
758 rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
759 std::make_tuple(rhs.primary, rhs.secondary,
760 position->getOperationDepth(), position->getKind(),
761 question->getKind(), id);
762 }
763};
764
765/// A DenseMapInfo for OrderedPredicate based solely on the position and
766/// question.
767struct OrderedPredicateDenseInfo {
769
770 static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
771 static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
772 static bool isEqual(const OrderedPredicate &lhs,
773 const OrderedPredicate &rhs) {
774 return lhs.position == rhs.position && lhs.question == rhs.question;
775 }
776 static unsigned getHashValue(const OrderedPredicate &p) {
777 return llvm::hash_combine(p.position, p.question);
778 }
779};
780
781/// This class wraps a set of ordered predicates that are used within a specific
782/// pattern operation.
783struct OrderedPredicateList {
784 OrderedPredicateList(pdl::PatternOp pattern, Value root)
785 : pattern(pattern), root(root) {}
786
787 pdl::PatternOp pattern;
788 Value root;
790};
791} // namespace
792
793/// Returns true if the given matcher refers to the same predicate as the given
794/// ordered predicate. This means that the position and questions of the two
795/// match.
796static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
797 return node->getPosition() == predicate->position &&
798 node->getQuestion() == predicate->question;
799}
800
801/// Get or insert a child matcher for the given parent switch node, given a
802/// predicate and parent pattern.
803static std::unique_ptr<MatcherNode> &
804getOrCreateChild(SwitchNode *node, OrderedPredicate *predicate,
805 pdl::PatternOp pattern) {
806 assert(isSamePredicate(node, predicate) &&
807 "expected matcher to equal the given predicate");
808
809 auto it = predicate->patternToAnswer.find(pattern);
810 assert(it != predicate->patternToAnswer.end() &&
811 "expected pattern to exist in predicate");
812 return node->getChildren()[it->second];
813}
814
815/// Build the matcher CFG by "pushing" patterns through by sorted predicate
816/// order. A pattern will traverse as far as possible using common predicates
817/// and then either diverge from the CFG or reach the end of a branch and start
818/// creating new nodes.
819static void propagatePattern(std::unique_ptr<MatcherNode> &node,
820 OrderedPredicateList &list,
821 std::vector<OrderedPredicate *>::iterator current,
822 std::vector<OrderedPredicate *>::iterator end) {
823 if (current == end) {
824 // We've hit the end of a pattern, so create a successful result node.
825 node =
826 std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
827
828 // If the pattern doesn't contain this predicate, ignore it.
829 } else if (!list.predicates.contains(*current)) {
830 propagatePattern(node, list, std::next(current), end);
831
832 // If the current matcher node is invalid, create a new one for this
833 // position and continue propagation.
834 } else if (!node) {
835 // Create a new node at this position and continue
836 node = std::make_unique<SwitchNode>((*current)->position,
837 (*current)->question);
839 getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
840 list, std::next(current), end);
841
842 // If the matcher has already been created, and it is for this predicate we
843 // continue propagation to the child.
844 } else if (isSamePredicate(node.get(), *current)) {
846 getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
847 list, std::next(current), end);
848
849 // If the matcher doesn't match the current predicate, insert a branch as
850 // the common set of matchers has diverged.
851 } else {
852 propagatePattern(node->getFailureNode(), list, current, end);
853 }
854}
855
856/// Fold any switch nodes nested under `node` to boolean nodes when possible.
857/// `node` is updated in-place if it is a switch.
858static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
859 if (!node)
860 return;
861
862 if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
863 SwitchNode::ChildMapT &children = switchNode->getChildren();
864 for (auto &it : children)
865 foldSwitchToBool(it.second);
866
867 // If the node only contains one child, collapse it into a boolean predicate
868 // node.
869 if (children.size() == 1) {
870 auto *childIt = children.begin();
871 node = std::make_unique<BoolNode>(
872 node->getPosition(), node->getQuestion(), childIt->first,
873 std::move(childIt->second), std::move(node->getFailureNode()));
874 }
875 } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
876 foldSwitchToBool(boolNode->getSuccessNode());
877 }
878
879 foldSwitchToBool(node->getFailureNode());
880}
881
882/// Insert an exit node at the end of the failure path of the `root`.
883static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
884 while (*root)
885 root = &(*root)->getFailureNode();
886 *root = std::make_unique<ExitNode>();
887}
888
889/// Sorts the range begin/end with the partial order given by cmp.
890template <typename Iterator, typename Compare>
891static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
892 while (begin != end) {
893 // Cannot compute sortBeforeOthers in the predicate of stable_partition
894 // because stable_partition will not keep the [begin, end) range intact
895 // while it runs.
897 for (auto i = begin; i != end; ++i) {
898 if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
899 sortBeforeOthers.insert(*i);
900 }
901
902 auto const next = std::stable_partition(begin, end, [&](auto const &a) {
903 return sortBeforeOthers.contains(a);
904 });
905 assert(next != begin && "not a partial ordering");
906 begin = next;
907 }
908}
909
910/// Returns true if 'b' depends on a result of 'a'.
911static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
912 auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
913 if (!cqa)
914 return false;
915
916 auto positionDependsOnA = [&](Position *p) {
917 auto *cp = dyn_cast<ConstraintPosition>(p);
918 return cp && cp->getQuestion() == cqa;
919 };
920
921 if (auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) {
922 // Does any argument of b use a?
923 return llvm::any_of(cqb->getArgs(), positionDependsOnA);
924 }
925 if (auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) {
926 return positionDependsOnA(b->position) ||
927 positionDependsOnA(equalTo->getValue());
928 }
929 return positionDependsOnA(b->position);
930}
931
932/// Given a module containing PDL pattern operations, generate a matcher tree
933/// using the patterns within the given module and return the root matcher node.
934std::unique_ptr<MatcherNode>
936 DenseMap<Value, Position *> &valueToPosition) {
937 // The set of predicates contained within the pattern operations of the
938 // module.
939 struct PatternPredicates {
940 PatternPredicates(pdl::PatternOp pattern, Value root,
941 std::vector<PositionalPredicate> predicates)
942 : pattern(pattern), root(root), predicates(std::move(predicates)) {}
943
944 /// A pattern.
945 pdl::PatternOp pattern;
946
947 /// A root of the pattern chosen among the candidate roots in pdl.rewrite.
948 Value root;
949
950 /// The extracted predicates for this pattern and root.
951 std::vector<PositionalPredicate> predicates;
952 };
953
954 SmallVector<PatternPredicates, 16> patternsAndPredicates;
955 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
956 std::vector<PositionalPredicate> predicateList;
957 Value root =
958 buildPredicateList(pattern, builder, predicateList, valueToPosition);
959 patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
960 }
961
962 // Associate a pattern result with each unique predicate.
964 for (auto &patternAndPredList : patternsAndPredicates) {
965 for (auto &predicate : patternAndPredList.predicates) {
966 auto it = uniqued.insert(predicate);
967 it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
968 predicate.answer);
969 // Mark the insertion order (0-based indexing).
970 if (it.second)
971 it.first->id = uniqued.size() - 1;
972 }
973 }
974
975 // Associate each pattern to a set of its ordered predicates for later lookup.
976 std::vector<OrderedPredicateList> lists;
977 lists.reserve(patternsAndPredicates.size());
978 for (auto &patternAndPredList : patternsAndPredicates) {
979 OrderedPredicateList list(patternAndPredList.pattern,
980 patternAndPredList.root);
981 for (auto &predicate : patternAndPredList.predicates) {
982 OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
983 list.predicates.insert(orderedPredicate);
984
985 // Increment the primary sum for each reference to a particular predicate.
986 ++orderedPredicate->primary;
987 }
988 lists.push_back(std::move(list));
989 }
990
991 // For a particular pattern, get the total primary sum and add it to the
992 // secondary sum of each predicate. Square the primary sums to emphasize
993 // shared predicates within rather than across patterns.
994 for (auto &list : lists) {
995 unsigned total = 0;
996 for (auto *predicate : list.predicates)
997 total += predicate->primary * predicate->primary;
998 for (auto *predicate : list.predicates)
999 predicate->secondary += total;
1000 }
1001
1002 // Sort the set of predicates now that the cost primary and secondary sums
1003 // have been computed.
1004 std::vector<OrderedPredicate *> ordered;
1005 ordered.reserve(uniqued.size());
1006 for (auto &ip : uniqued)
1007 ordered.push_back(&ip);
1008 llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
1009 return *lhs < *rhs;
1010 });
1011
1012 // Mostly keep the now established order, but also ensure that
1013 // ConstraintQuestions come after the results they use.
1014 stableTopologicalSort(ordered.begin(), ordered.end(), dependsOn);
1015
1016 // Build the matchers for each of the pattern predicate lists.
1017 std::unique_ptr<MatcherNode> root;
1018 for (OrderedPredicateList &list : lists)
1019 propagatePattern(root, list, ordered.begin(), ordered.end());
1020
1021 // Collapse the graph and insert the exit node.
1022 foldSwitchToBool(root);
1023 insertExitNode(&root);
1024 return root;
1025}
1026
1027//===----------------------------------------------------------------------===//
1028// MatcherNode
1029//===----------------------------------------------------------------------===//
1030
1032 std::unique_ptr<MatcherNode> failureNode)
1033 : position(p), question(q), failureNode(std::move(failureNode)),
1034 matcherTypeID(matcherTypeID) {}
1035
1036//===----------------------------------------------------------------------===//
1037// BoolNode
1038//===----------------------------------------------------------------------===//
1039
1040BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
1041 std::unique_ptr<MatcherNode> successNode,
1042 std::unique_ptr<MatcherNode> failureNode)
1043 : MatcherNode(TypeID::get<BoolNode>(), position, question,
1044 std::move(failureNode)),
1045 answer(answer), successNode(std::move(successNode)) {}
1046
1047//===----------------------------------------------------------------------===//
1048// SuccessNode
1049//===----------------------------------------------------------------------===//
1050
1051SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
1052 std::unique_ptr<MatcherNode> failureNode)
1053 : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
1054 /*question=*/nullptr, std::move(failureNode)),
1055 pattern(pattern), root(root) {}
1056
1057//===----------------------------------------------------------------------===//
1058// SwitchNode
1059//===----------------------------------------------------------------------===//
1060
1062 : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static Value buildPredicateList(pdl::PatternOp pattern, PredicateBuilder &builder, std::vector< PositionalPredicate > &predList, DenseMap< Value, Position * > &valueToPosition)
Given a pattern operation, build the set of matcher predicates necessary to match this pattern.
static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b)
Returns true if 'b' depends on a result of 'a'.
static void getTypePredicates(Value typeValue, function_ref< Attribute()> typeAttrFn, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getNonTreePredicates(pdl::PatternOp pattern, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
Collect all of the predicates that cannot be determined via walking the tree.
static std::unique_ptr< MatcherNode > & getOrCreateChild(SwitchNode *node, OrderedPredicate *predicate, pdl::PatternOp pattern)
Get or insert a child matcher for the given parent switch node, given a predicate and parent pattern.
static SmallVector< Value > detectRoots(pdl::PatternOp pattern)
Given a pattern, determines the set of roots present in this pattern.
static bool useOperandGroup(pdl::OperationOp op, unsigned index)
Returns true if the operand at the given index needs to be queried using an operand group,...
static void getTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)
Collect the tree predicates anchored at the given value.
static void getResultPredicates(pdl::ResultOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getOperandTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)
Collect all of the predicates for the given operand position.
static bool comparePosDepth(Position *lhs, Position *rhs)
Compares the depths of two positions.
static void visitUpward(std::vector< PositionalPredicate > &predList, OpIndex opIndex, PredicateBuilder &builder, DenseMap< Value, Position * > &valueToPosition, Position *&pos, unsigned rootID)
Visit a node during upward traversal.
static unsigned getNumNonRangeValues(ValueRange values)
Returns the number of non-range elements within values.
static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp)
Sorts the range begin/end with the partial order given by cmp.
static void insertExitNode(std::unique_ptr< MatcherNode > *root)
Insert an exit node at the end of the failure path of the root.
static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate)
Returns true if the given matcher refers to the same predicate as the given ordered predicate.
static void foldSwitchToBool(std::unique_ptr< MatcherNode > &node)
Fold any switch nodes nested under node to boolean nodes when possible.
static void buildCostGraph(ArrayRef< Value > roots, RootOrderingGraph &graph, ParentMaps &parentMaps)
Given a list of candidate roots, builds the cost graph for connecting them.
static void getAttributePredicates(pdl::AttributeOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void propagatePattern(std::unique_ptr< MatcherNode > &node, OrderedPredicateList &list, std::vector< OrderedPredicate * >::iterator current, std::vector< OrderedPredicate * >::iterator end)
Build the matcher CFG by "pushing" patterns through by sorted predicate order.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class implements the result iterators for the Operation class.
Definition ValueRange.h:247
type_range getTypes() const
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
This class represents the base of a predicate matcher node.
Position * getPosition() const
Returns the position on which the question predicate should be checked.
MatcherNode(TypeID matcherTypeID, Position *position=nullptr, Qualifier *question=nullptr, std::unique_ptr< MatcherNode > failureNode=nullptr)
Qualifier * getQuestion() const
Returns the predicate checked on this node.
static std::unique_ptr< MatcherNode > generateMatcherTree(ModuleOp module, PredicateBuilder &builder, DenseMap< Value, Position * > &valueToPosition)
Given a module containing PDL pattern operations, generate a matcher tree using the patterns within t...
The optimal branching algorithm solver.
unsigned solve()
Runs the Edmonds' algorithm for the current graph, returning the total cost of the minimum-weight spa...
std::vector< std::pair< Value, Value > > EdgeList
A list of edges (child, parent).
EdgeList preOrderTraversal(ArrayRef< Value > nodes) const
Returns the computed edges as visited in the preorder traversal.
A position describes a value on the input IR on which a predicate may be applied, such as an operatio...
Definition Predicate.h:144
const KeyTy & getValue() const
Return the key value of this predicate.
Definition Predicate.h:111
This class provides utilities for constructing predicates.
Definition Predicate.h:610
Position * getResultGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of results of the given operation.
Definition Predicate.h:675
Predicate getOperandCount(unsigned count)
Create a predicate comparing the number of operands of an operation to a known value.
Definition Predicate.h:740
Position * getAllOperands(OperationPosition *p)
Definition Predicate.h:665
ConstraintPosition * getConstraintPosition(ConstraintQuestion *q, unsigned index)
Definition Predicate.h:636
Predicate getIsNotNull()
Create a predicate comparing a value with null.
Definition Predicate.h:734
Position * getAllResults(OperationPosition *p)
Definition Predicate.h:679
Predicate getOperandCountAtLeast(unsigned count)
Definition Predicate.h:744
Predicate getResultCountAtLeast(unsigned count)
Definition Predicate.h:761
Position * getResult(OperationPosition *p, unsigned result)
Returns a result position for a result of the given operation.
Definition Predicate.h:670
Position * getForEach(Position *p, unsigned id)
Definition Predicate.h:651
UsersPosition * getUsers(Position *p, bool useRepresentative)
Returns the users of a position using the value at the given operand.
Definition Predicate.h:693
Predicate getAttributeConstraint(Attribute attr)
Create a predicate comparing an attribute to a known value.
Definition Predicate.h:710
Position * getType(Position *p)
Returns a type position for the given entity.
Definition Predicate.h:684
Position * getOperand(OperationPosition *p, unsigned operand)
Returns an operand position for an operand of the given operation.
Definition Predicate.h:656
Predicate getTypeConstraint(Attribute type)
Create a predicate comparing the type of an attribute or value to a known type.
Definition Predicate.h:769
Position * getRoot()
Returns the root operation position.
Definition Predicate.h:620
Predicate getResultCount(unsigned count)
Create a predicate comparing the number of results of an operation to a known value.
Definition Predicate.h:757
Position * getOperandGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of operands of the given operation.
Definition Predicate.h:661
Position * getTypeLiteral(Attribute attr)
Returns a type position for the given type value.
Definition Predicate.h:688
std::pair< Qualifier *, Qualifier * > Predicate
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition Predicate.h:707
Predicate getEqualTo(Position *pos)
Create a predicate checking if two values are equal.
Definition Predicate.h:716
OperationPosition * getPassthroughOp(Position *p)
Returns the operation position equivalent to the given position.
Definition Predicate.h:630
Position * getAttribute(OperationPosition *p, StringRef name)
Returns an attribute position for an attribute of the given operation.
Definition Predicate.h:642
OperationPosition * getOperandDefiningOp(Position *p)
Returns the parent position defining the value held by the given operand.
Definition Predicate.h:623
Predicate getConstraint(StringRef name, ArrayRef< Position * > args, ArrayRef< Type > resultTypes, bool isNegated)
Create a predicate that applies a generic constraint.
Definition Predicate.h:726
Predicate getOperationName(StringRef name)
Create a predicate comparing the name of an operation to a known value.
Definition Predicate.h:750
Position * getAttributeLiteral(Attribute attr)
Returns an attribute position for the given attribute.
Definition Predicate.h:647
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition Predicate.h:422
DenseMap< Value, DenseMap< Value, RootOrderingEntry > > RootOrderingGraph
A directed graph representing the cost of ordering the roots in the predicate tree.
bool operator<(const Fraction &x, const Fraction &y)
Definition Fraction.h:83
Include the generated interface declarations.
llvm::DenseMapInfo< T, Enable > DenseMapInfo
Definition LLVM.h:122
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
A position describing an attribute of an operation.
Definition Predicate.h:175
A BoolNode denotes a question with a boolean-like result.
BoolNode(Position *position, Qualifier *question, Qualifier *answer, std::unique_ptr< MatcherNode > successNode, std::unique_ptr< MatcherNode > failureNode=nullptr)
A position describing the result of a native constraint.
Definition Predicate.h:301
Apply a parameterized constraint to multiple position values and possibly produce results.
Definition Predicate.h:493
An operation position describes an operation node in the IR.
Definition Predicate.h:259
bool isRoot() const
Returns if this operation position corresponds to the root.
Definition Predicate.h:283
bool isOperandDefiningOp() const
Returns if this operation represents an operand defining op.
Definition Predicate.cpp:55
The information associated with an edge in the cost graph.
Value connector
The connector value in the intersection of the two subtrees rooted at the source and target root that...
std::pair< unsigned, unsigned > cost
The depth of the connector Value w.r.t.
SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr< MatcherNode > failureNode)
A SwitchNode denotes a question with multiple potential results.
SwitchNode(Position *position, Qualifier *question)
llvm::MapVector< Qualifier *, std::unique_ptr< MatcherNode > > ChildMapT
Returns the children of this switch node.
A position describing the result type of an entity, i.e.
Definition Predicate.h:364