MLIR  18.0.0git
PredicateTree.cpp
Go to the documentation of this file.
1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PredicateTree.h"
10 #include "RootOrdering.h"
11 
15 #include "mlir/IR/BuiltinOps.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/Debug.h"
20 #include <queue>
21 
22 #define DEBUG_TYPE "pdl-predicate-tree"
23 
24 using namespace mlir;
25 using namespace mlir::pdl_to_pdl_interp;
26 
27 //===----------------------------------------------------------------------===//
28 // Predicate List Building
29 //===----------------------------------------------------------------------===//
30 
31 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
32  Value val, PredicateBuilder &builder,
34  Position *pos);
35 
36 /// Compares the depths of two positions.
37 static bool comparePosDepth(Position *lhs, Position *rhs) {
38  return lhs->getOperationDepth() < rhs->getOperationDepth();
39 }
40 
41 /// Returns the number of non-range elements within `values`.
42 static unsigned getNumNonRangeValues(ValueRange values) {
43  return llvm::count_if(values.getTypes(),
44  [](Type type) { return !isa<pdl::RangeType>(type); });
45 }
46 
47 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
48  Value val, PredicateBuilder &builder,
50  AttributePosition *pos) {
51  assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
52  pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
53  predList.emplace_back(pos, builder.getIsNotNull());
54 
55  // If the attribute has a type or value, add a constraint.
56  if (Value type = attr.getValueType())
57  getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
58  else if (Attribute value = attr.getValueAttr())
59  predList.emplace_back(pos, builder.getAttributeConstraint(value));
60 }
61 
62 /// Collect all of the predicates for the given operand position.
63 static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
64  Value val, PredicateBuilder &builder,
66  Position *pos) {
67  Type valueType = val.getType();
68  bool isVariadic = isa<pdl::RangeType>(valueType);
69 
70  // If this is a typed operand, add a type constraint.
72  .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
73  // Prevent traversal into a null value if the operand has a proper
74  // index.
75  if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
76  cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
77  predList.emplace_back(pos, builder.getIsNotNull());
78 
79  if (Value type = op.getValueType())
80  getTreePredicates(predList, type, builder, inputs,
81  builder.getType(pos));
82  })
83  .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
84  std::optional<unsigned> index = op.getIndex();
85 
86  // Prevent traversal into a null value if the result has a proper index.
87  if (index)
88  predList.emplace_back(pos, builder.getIsNotNull());
89 
90  // Get the parent operation of this operand.
91  OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
92  predList.emplace_back(parentPos, builder.getIsNotNull());
93 
94  // Ensure that the operands match the corresponding results of the
95  // parent operation.
96  Position *resultPos = nullptr;
97  if (std::is_same<pdl::ResultOp, decltype(op)>::value)
98  resultPos = builder.getResult(parentPos, *index);
99  else
100  resultPos = builder.getResultGroup(parentPos, index, isVariadic);
101  predList.emplace_back(resultPos, builder.getEqualTo(pos));
102 
103  // Collect the predicates of the parent operation.
104  getTreePredicates(predList, op.getParent(), builder, inputs,
105  (Position *)parentPos);
106  });
107 }
108 
109 static void
110 getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
111  PredicateBuilder &builder,
113  std::optional<unsigned> ignoreOperand = std::nullopt) {
114  assert(isa<pdl::OperationType>(val.getType()) && "expected operation");
115  pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
116  OperationPosition *opPos = cast<OperationPosition>(pos);
117 
118  // Ensure getDefiningOp returns a non-null operation.
119  if (!opPos->isRoot())
120  predList.emplace_back(pos, builder.getIsNotNull());
121 
122  // Check that this is the correct root operation.
123  if (std::optional<StringRef> opName = op.getOpName())
124  predList.emplace_back(pos, builder.getOperationName(*opName));
125 
126  // Check that the operation has the proper number of operands. If there are
127  // any variable length operands, we check a minimum instead of an exact count.
128  OperandRange operands = op.getOperandValues();
129  unsigned minOperands = getNumNonRangeValues(operands);
130  if (minOperands != operands.size()) {
131  if (minOperands)
132  predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
133  } else {
134  predList.emplace_back(pos, builder.getOperandCount(minOperands));
135  }
136 
137  // Check that the operation has the proper number of results. If there are
138  // any variable length results, we check a minimum instead of an exact count.
139  OperandRange types = op.getTypeValues();
140  unsigned minResults = getNumNonRangeValues(types);
141  if (minResults == types.size())
142  predList.emplace_back(pos, builder.getResultCount(types.size()));
143  else if (minResults)
144  predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
145 
146  // Recurse into any attributes, operands, or results.
147  for (auto [attrName, attr] :
148  llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
150  predList, attr, builder, inputs,
151  builder.getAttribute(opPos, cast<StringAttr>(attrName).getValue()));
152  }
153 
154  // Process the operands and results of the operation. For all values up to
155  // the first variable length value, we use the concrete operand/result
156  // number. After that, we use the "group" given that we can't know the
157  // concrete indices until runtime. If there is only one variadic operand
158  // group, we treat it as all of the operands/results of the operation.
159  /// Operands.
160  if (operands.size() == 1 && isa<pdl::RangeType>(operands[0].getType())) {
161  // Ignore the operands if we are performing an upward traversal (in that
162  // case, they have already been visited).
163  if (opPos->isRoot() || opPos->isOperandDefiningOp())
164  getTreePredicates(predList, operands.front(), builder, inputs,
165  builder.getAllOperands(opPos));
166  } else {
167  bool foundVariableLength = false;
168  for (const auto &operandIt : llvm::enumerate(operands)) {
169  bool isVariadic = isa<pdl::RangeType>(operandIt.value().getType());
170  foundVariableLength |= isVariadic;
171 
172  // Ignore the specified operand, usually because this position was
173  // visited in an upward traversal via an iterative choice.
174  if (ignoreOperand && *ignoreOperand == operandIt.index())
175  continue;
176 
177  Position *pos =
178  foundVariableLength
179  ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
180  : builder.getOperand(opPos, operandIt.index());
181  getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
182  }
183  }
184  /// Results.
185  if (types.size() == 1 && isa<pdl::RangeType>(types[0].getType())) {
186  getTreePredicates(predList, types.front(), builder, inputs,
187  builder.getType(builder.getAllResults(opPos)));
188  return;
189  }
190 
191  bool foundVariableLength = false;
192  for (auto [idx, typeValue] : llvm::enumerate(types)) {
193  bool isVariadic = isa<pdl::RangeType>(typeValue.getType());
194  foundVariableLength |= isVariadic;
195 
196  auto *resultPos = foundVariableLength
197  ? builder.getResultGroup(pos, idx, isVariadic)
198  : builder.getResult(pos, idx);
199  predList.emplace_back(resultPos, builder.getIsNotNull());
200  getTreePredicates(predList, typeValue, builder, inputs,
201  builder.getType(resultPos));
202  }
203 }
204 
205 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
206  Value val, PredicateBuilder &builder,
208  TypePosition *pos) {
209  // Check for a constraint on a constant type.
210  if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
211  if (Attribute type = typeOp.getConstantTypeAttr())
212  predList.emplace_back(pos, builder.getTypeConstraint(type));
213  } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
214  if (Attribute typeAttr = typeOp.getConstantTypesAttr())
215  predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
216  }
217 }
218 
219 /// Collect the tree predicates anchored at the given value.
220 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
221  Value val, PredicateBuilder &builder,
223  Position *pos) {
224  // Make sure this input value is accessible to the rewrite.
225  auto it = inputs.try_emplace(val, pos);
226  if (!it.second) {
227  // If this is an input value that has been visited in the tree, add a
228  // constraint to ensure that both instances refer to the same value.
229  if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
230  pdl::TypeOp>(val.getDefiningOp())) {
231  auto minMaxPositions =
232  std::minmax(pos, it.first->second, comparePosDepth);
233  predList.emplace_back(minMaxPositions.second,
234  builder.getEqualTo(minMaxPositions.first));
235  }
236  return;
237  }
238 
240  .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
241  getTreePredicates(predList, val, builder, inputs, pos);
242  })
243  .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
244  getOperandTreePredicates(predList, val, builder, inputs, pos);
245  })
246  .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
247 }
248 
249 static void getAttributePredicates(pdl::AttributeOp op,
250  std::vector<PositionalPredicate> &predList,
251  PredicateBuilder &builder,
252  DenseMap<Value, Position *> &inputs) {
253  Position *&attrPos = inputs[op];
254  if (attrPos)
255  return;
256  Attribute value = op.getValueAttr();
257  assert(value && "expected non-tree `pdl.attribute` to contain a value");
258  attrPos = builder.getAttributeLiteral(value);
259 }
260 
261 static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
262  std::vector<PositionalPredicate> &predList,
263  PredicateBuilder &builder,
264  DenseMap<Value, Position *> &inputs) {
265  OperandRange arguments = op.getArgs();
266 
267  std::vector<Position *> allPositions;
268  allPositions.reserve(arguments.size());
269  for (Value arg : arguments)
270  allPositions.push_back(inputs.lookup(arg));
271 
272  // Push the constraint to the furthest position.
273  Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
276  builder.getConstraint(op.getName(), allPositions, op.getIsNegated());
277  predList.emplace_back(pos, pred);
278 }
279 
280 static void getResultPredicates(pdl::ResultOp op,
281  std::vector<PositionalPredicate> &predList,
282  PredicateBuilder &builder,
283  DenseMap<Value, Position *> &inputs) {
284  Position *&resultPos = inputs[op];
285  if (resultPos)
286  return;
287 
288  // Ensure that the result isn't null.
289  auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
290  resultPos = builder.getResult(parentPos, op.getIndex());
291  predList.emplace_back(resultPos, builder.getIsNotNull());
292 }
293 
294 static void getResultPredicates(pdl::ResultsOp op,
295  std::vector<PositionalPredicate> &predList,
296  PredicateBuilder &builder,
297  DenseMap<Value, Position *> &inputs) {
298  Position *&resultPos = inputs[op];
299  if (resultPos)
300  return;
301 
302  // Ensure that the result isn't null if the result has an index.
303  auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
304  bool isVariadic = isa<pdl::RangeType>(op.getType());
305  std::optional<unsigned> index = op.getIndex();
306  resultPos = builder.getResultGroup(parentPos, index, isVariadic);
307  if (index)
308  predList.emplace_back(resultPos, builder.getIsNotNull());
309 }
310 
311 static void getTypePredicates(Value typeValue,
312  function_ref<Attribute()> typeAttrFn,
313  PredicateBuilder &builder,
314  DenseMap<Value, Position *> &inputs) {
315  Position *&typePos = inputs[typeValue];
316  if (typePos)
317  return;
318  Attribute typeAttr = typeAttrFn();
319  assert(typeAttr &&
320  "expected non-tree `pdl.type`/`pdl.types` to contain a value");
321  typePos = builder.getTypeLiteral(typeAttr);
322 }
323 
324 /// Collect all of the predicates that cannot be determined via walking the
325 /// tree.
326 static void getNonTreePredicates(pdl::PatternOp pattern,
327  std::vector<PositionalPredicate> &predList,
328  PredicateBuilder &builder,
329  DenseMap<Value, Position *> &inputs) {
330  for (Operation &op : pattern.getBodyRegion().getOps()) {
332  .Case([&](pdl::AttributeOp attrOp) {
333  getAttributePredicates(attrOp, predList, builder, inputs);
334  })
335  .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
336  getConstraintPredicates(constraintOp, predList, builder, inputs);
337  })
338  .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
339  getResultPredicates(resultOp, predList, builder, inputs);
340  })
341  .Case([&](pdl::TypeOp typeOp) {
343  typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder,
344  inputs);
345  })
346  .Case([&](pdl::TypesOp typeOp) {
348  typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder,
349  inputs);
350  });
351  }
352 }
353 
354 namespace {
355 
356 /// An op accepting a value at an optional index.
357 struct OpIndex {
358  Value parent;
359  std::optional<unsigned> index;
360 };
361 
362 /// The parent and operand index of each operation for each root, stored
363 /// as a nested map [root][operation].
364 using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>;
365 
366 } // namespace
367 
368 /// Given a pattern, determines the set of roots present in this pattern.
369 /// These are the operations whose results are not consumed by other operations.
370 static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
371  // First, collect all the operations that are used as operands
372  // to other operations. These are not roots by default.
373  DenseSet<Value> used;
374  for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
375  for (Value operand : operationOp.getOperandValues())
376  TypeSwitch<Operation *>(operand.getDefiningOp())
377  .Case<pdl::ResultOp, pdl::ResultsOp>(
378  [&used](auto resultOp) { used.insert(resultOp.getParent()); });
379  }
380 
381  // Remove the specified root from the use set, so that we can
382  // always select it as a root, even if it is used by other operations.
383  if (Value root = pattern.getRewriter().getRoot())
384  used.erase(root);
385 
386  // Finally, collect all the unused operations.
387  SmallVector<Value> roots;
388  for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>())
389  if (!used.contains(operationOp))
390  roots.push_back(operationOp);
391 
392  return roots;
393 }
394 
395 /// Given a list of candidate roots, builds the cost graph for connecting them.
396 /// The graph is formed by traversing the DAG of operations starting from each
397 /// root and marking the depth of each connector value (operand). Then we join
398 /// the candidate roots based on the common connector values, taking the one
399 /// with the minimum depth. Along the way, we compute, for each candidate root,
400 /// a mapping from each operation (in the DAG underneath this root) to its
401 /// parent operation and the corresponding operand index.
403  ParentMaps &parentMaps) {
404 
405  // The entry of a queue. The entry consists of the following items:
406  // * the value in the DAG underneath the root;
407  // * the parent of the value;
408  // * the operand index of the value in its parent;
409  // * the depth of the visited value.
410  struct Entry {
411  Entry(Value value, Value parent, std::optional<unsigned> index,
412  unsigned depth)
413  : value(value), parent(parent), index(index), depth(depth) {}
414 
415  Value value;
416  Value parent;
417  std::optional<unsigned> index;
418  unsigned depth;
419  };
420 
421  // A root of a value and its depth (distance from root to the value).
422  struct RootDepth {
423  Value root;
424  unsigned depth = 0;
425  };
426 
427  // Map from candidate connector values to their roots and depths. Using a
428  // small vector with 1 entry because most values belong to a single root.
429  llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
430 
431  // Perform a breadth-first traversal of the op DAG rooted at each root.
432  for (Value root : roots) {
433  // The queue of visited values. A value may be present multiple times in
434  // the queue, for multiple parents. We only accept the first occurrence,
435  // which is guaranteed to have the lowest depth.
436  std::queue<Entry> toVisit;
437  toVisit.emplace(root, Value(), 0, 0);
438 
439  // The map from value to its parent for the current root.
440  DenseMap<Value, OpIndex> &parentMap = parentMaps[root];
441 
442  while (!toVisit.empty()) {
443  Entry entry = toVisit.front();
444  toVisit.pop();
445  // Skip if already visited.
446  if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
447  continue;
448 
449  // Mark the root and depth of the value.
450  connectorsRootsDepths[entry.value].push_back({root, entry.depth});
451 
452  // Traverse the operands of an operation and result ops.
453  // We intentionally do not traverse attributes and types, because those
454  // are expensive to join on.
455  TypeSwitch<Operation *>(entry.value.getDefiningOp())
456  .Case<pdl::OperationOp>([&](auto operationOp) {
457  OperandRange operands = operationOp.getOperandValues();
458  // Special case when we pass all the operands in one range.
459  // For those, the index is empty.
460  if (operands.size() == 1 &&
461  isa<pdl::RangeType>(operands[0].getType())) {
462  toVisit.emplace(operands[0], entry.value, std::nullopt,
463  entry.depth + 1);
464  return;
465  }
466 
467  // Default case: visit all the operands.
468  for (const auto &p :
469  llvm::enumerate(operationOp.getOperandValues()))
470  toVisit.emplace(p.value(), entry.value, p.index(),
471  entry.depth + 1);
472  })
473  .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
474  toVisit.emplace(resultOp.getParent(), entry.value,
475  resultOp.getIndex(), entry.depth);
476  });
477  }
478  }
479 
480  // Now build the cost graph.
481  // This is simply a minimum over all depths for the target root.
482  unsigned nextID = 0;
483  for (const auto &connectorRootsDepths : connectorsRootsDepths) {
484  Value value = connectorRootsDepths.first;
485  ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second;
486  // If there is only one root for this value, this will not trigger
487  // any edges in the cost graph (a perf optimization).
488  if (rootsDepths.size() == 1)
489  continue;
490 
491  for (const RootDepth &p : rootsDepths) {
492  for (const RootDepth &q : rootsDepths) {
493  if (&p == &q)
494  continue;
495  // Insert or retrieve the property of edge from p to q.
496  RootOrderingEntry &entry = graph[q.root][p.root];
497  if (!entry.connector /* new edge */ || entry.cost.first > q.depth) {
498  if (!entry.connector)
499  entry.cost.second = nextID++;
500  entry.cost.first = q.depth;
501  entry.connector = value;
502  }
503  }
504  }
505  }
506 
507  assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
508  "the pattern contains a candidate root disconnected from the others");
509 }
510 
511 /// Returns true if the operand at the given index needs to be queried using an
512 /// operand group, i.e., if it is variadic itself or follows a variadic operand.
513 static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
514  OperandRange operands = op.getOperandValues();
515  assert(index < operands.size() && "operand index out of range");
516  for (unsigned i = 0; i <= index; ++i)
517  if (isa<pdl::RangeType>(operands[i].getType()))
518  return true;
519  return false;
520 }
521 
522 /// Visit a node during upward traversal.
523 static void visitUpward(std::vector<PositionalPredicate> &predList,
524  OpIndex opIndex, PredicateBuilder &builder,
525  DenseMap<Value, Position *> &valueToPosition,
526  Position *&pos, unsigned rootID) {
527  Value value = opIndex.parent;
529  .Case<pdl::OperationOp>([&](auto operationOp) {
530  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
531 
532  // Get users and iterate over them.
533  Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true);
534  Position *foreachPos = builder.getForEach(usersPos, rootID);
535  OperationPosition *opPos = builder.getPassthroughOp(foreachPos);
536 
537  // Compare the operand(s) of the user against the input value(s).
538  Position *operandPos;
539  if (!opIndex.index) {
540  // We are querying all the operands of the operation.
541  operandPos = builder.getAllOperands(opPos);
542  } else if (useOperandGroup(operationOp, *opIndex.index)) {
543  // We are querying an operand group.
544  Type type = operationOp.getOperandValues()[*opIndex.index].getType();
545  bool variadic = isa<pdl::RangeType>(type);
546  operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
547  } else {
548  // We are querying an individual operand.
549  operandPos = builder.getOperand(opPos, *opIndex.index);
550  }
551  predList.emplace_back(operandPos, builder.getEqualTo(pos));
552 
553  // Guard against duplicate upward visits. These are not possible,
554  // because if this value was already visited, it would have been
555  // cheaper to start the traversal at this value rather than at the
556  // `connector`, violating the optimality of our spanning tree.
557  bool inserted = valueToPosition.try_emplace(value, opPos).second;
558  (void)inserted;
559  assert(inserted && "duplicate upward visit");
560 
561  // Obtain the tree predicates at the current value.
562  getTreePredicates(predList, value, builder, valueToPosition, opPos,
563  opIndex.index);
564 
565  // Update the position
566  pos = opPos;
567  })
568  .Case<pdl::ResultOp>([&](auto resultOp) {
569  // Traverse up an individual result.
570  auto *opPos = dyn_cast<OperationPosition>(pos);
571  assert(opPos && "operations and results must be interleaved");
572  pos = builder.getResult(opPos, *opIndex.index);
573 
574  // Insert the result position in case we have not visited it yet.
575  valueToPosition.try_emplace(value, pos);
576  })
577  .Case<pdl::ResultsOp>([&](auto resultOp) {
578  // Traverse up a group of results.
579  auto *opPos = dyn_cast<OperationPosition>(pos);
580  assert(opPos && "operations and results must be interleaved");
581  bool isVariadic = isa<pdl::RangeType>(value.getType());
582  if (opIndex.index)
583  pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
584  else
585  pos = builder.getAllResults(opPos);
586 
587  // Insert the result position in case we have not visited it yet.
588  valueToPosition.try_emplace(value, pos);
589  });
590 }
591 
592 /// Given a pattern operation, build the set of matcher predicates necessary to
593 /// match this pattern.
594 static Value buildPredicateList(pdl::PatternOp pattern,
595  PredicateBuilder &builder,
596  std::vector<PositionalPredicate> &predList,
597  DenseMap<Value, Position *> &valueToPosition) {
598  SmallVector<Value> roots = detectRoots(pattern);
599 
600  // Build the root ordering graph and compute the parent maps.
601  RootOrderingGraph graph;
602  ParentMaps parentMaps;
603  buildCostGraph(roots, graph, parentMaps);
604  LLVM_DEBUG({
605  llvm::dbgs() << "Graph:\n";
606  for (auto &target : graph) {
607  llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first
608  << "\n";
609  for (auto &source : target.second) {
610  RootOrderingEntry &entry = source.second;
611  llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first
612  << ":" << entry.cost.second << " via "
613  << entry.connector.getLoc() << "\n";
614  }
615  }
616  });
617 
618  // Solve the optimal branching problem for each candidate root, or use the
619  // provided one.
620  Value bestRoot = pattern.getRewriter().getRoot();
621  OptimalBranching::EdgeList bestEdges;
622  if (!bestRoot) {
623  unsigned bestCost = 0;
624  LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
625  for (Value root : roots) {
626  OptimalBranching solver(graph, root);
627  unsigned cost = solver.solve();
628  LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n");
629  if (!bestRoot || bestCost > cost) {
630  bestCost = cost;
631  bestRoot = root;
632  bestEdges = solver.preOrderTraversal(roots);
633  }
634  }
635  } else {
636  OptimalBranching solver(graph, bestRoot);
637  solver.solve();
638  bestEdges = solver.preOrderTraversal(roots);
639  }
640 
641  // Print the best solution.
642  LLVM_DEBUG({
643  llvm::dbgs() << "Best tree:\n";
644  for (const std::pair<Value, Value> &edge : bestEdges) {
645  llvm::dbgs() << " * " << edge.first;
646  if (edge.second)
647  llvm::dbgs() << " <- " << edge.second;
648  llvm::dbgs() << "\n";
649  }
650  });
651 
652  LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
653  LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n");
654 
655  // The best root is the starting point for the traversal. Get the tree
656  // predicates for the DAG rooted at bestRoot.
657  getTreePredicates(predList, bestRoot, builder, valueToPosition,
658  builder.getRoot());
659 
660  // Traverse the selected optimal branching. For all edges in order, traverse
661  // up starting from the connector, until the candidate root is reached, and
662  // call getTreePredicates at every node along the way.
663  for (const auto &it : llvm::enumerate(bestEdges)) {
664  Value target = it.value().first;
665  Value source = it.value().second;
666 
667  // Check if we already visited the target root. This happens in two cases:
668  // 1) the initial root (bestRoot);
669  // 2) a root that is dominated by (contained in the subtree rooted at) an
670  // already visited root.
671  if (valueToPosition.count(target))
672  continue;
673 
674  // Determine the connector.
675  Value connector = graph[target][source].connector;
676  assert(connector && "invalid edge");
677  LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n");
678  DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
679  Position *pos = valueToPosition.lookup(connector);
680  assert(pos && "connector has not been traversed yet");
681 
682  // Traverse from the connector upwards towards the target root.
683  for (Value value = connector; value != target;) {
684  OpIndex opIndex = parentMap.lookup(value);
685  assert(opIndex.parent && "missing parent");
686  visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
687  value = opIndex.parent;
688  }
689  }
690 
691  getNonTreePredicates(pattern, predList, builder, valueToPosition);
692 
693  return bestRoot;
694 }
695 
696 //===----------------------------------------------------------------------===//
697 // Pattern Predicate Tree Merging
698 //===----------------------------------------------------------------------===//
699 
700 namespace {
701 
702 /// This class represents a specific predicate applied to a position, and
703 /// provides hashing and ordering operators. This class allows for computing a
704 /// frequence sum and ordering predicates based on a cost model.
705 struct OrderedPredicate {
706  OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
707  : position(ip.first), question(ip.second) {}
708  OrderedPredicate(const PositionalPredicate &ip)
709  : position(ip.position), question(ip.question) {}
710 
711  /// The position this predicate is applied to.
712  Position *position;
713 
714  /// The question that is applied by this predicate onto the position.
715  Qualifier *question;
716 
717  /// The first and second order benefit sums.
718  /// The primary sum is the number of occurrences of this predicate among all
719  /// of the patterns.
720  unsigned primary = 0;
721  /// The secondary sum is a squared summation of the primary sum of all of the
722  /// predicates within each pattern that contains this predicate. This allows
723  /// for favoring predicates that are more commonly shared within a pattern, as
724  /// opposed to those shared across patterns.
725  unsigned secondary = 0;
726 
727  /// The tie breaking ID, used to preserve a deterministic (insertion) order
728  /// among all the predicates with the same priority, depth, and position /
729  /// predicate dependency.
730  unsigned id = 0;
731 
732  /// A map between a pattern operation and the answer to the predicate question
733  /// within that pattern.
734  DenseMap<Operation *, Qualifier *> patternToAnswer;
735 
736  /// Returns true if this predicate is ordered before `rhs`, based on the cost
737  /// model.
738  bool operator<(const OrderedPredicate &rhs) const {
739  // Sort by:
740  // * higher first and secondary order sums
741  // * lower depth
742  // * lower position dependency
743  // * lower predicate dependency
744  // * lower tie breaking ID
745  auto *rhsPos = rhs.position;
746  return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
747  rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
748  std::make_tuple(rhs.primary, rhs.secondary,
749  position->getOperationDepth(), position->getKind(),
750  question->getKind(), id);
751  }
752 };
753 
754 /// A DenseMapInfo for OrderedPredicate based solely on the position and
755 /// question.
756 struct OrderedPredicateDenseInfo {
758 
759  static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
760  static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
761  static bool isEqual(const OrderedPredicate &lhs,
762  const OrderedPredicate &rhs) {
763  return lhs.position == rhs.position && lhs.question == rhs.question;
764  }
765  static unsigned getHashValue(const OrderedPredicate &p) {
766  return llvm::hash_combine(p.position, p.question);
767  }
768 };
769 
770 /// This class wraps a set of ordered predicates that are used within a specific
771 /// pattern operation.
772 struct OrderedPredicateList {
773  OrderedPredicateList(pdl::PatternOp pattern, Value root)
774  : pattern(pattern), root(root) {}
775 
776  pdl::PatternOp pattern;
777  Value root;
778  DenseSet<OrderedPredicate *> predicates;
779 };
780 } // namespace
781 
782 /// Returns true if the given matcher refers to the same predicate as the given
783 /// ordered predicate. This means that the position and questions of the two
784 /// match.
785 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
786  return node->getPosition() == predicate->position &&
787  node->getQuestion() == predicate->question;
788 }
789 
790 /// Get or insert a child matcher for the given parent switch node, given a
791 /// predicate and parent pattern.
792 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
793  OrderedPredicate *predicate,
794  pdl::PatternOp pattern) {
795  assert(isSamePredicate(node, predicate) &&
796  "expected matcher to equal the given predicate");
797 
798  auto it = predicate->patternToAnswer.find(pattern);
799  assert(it != predicate->patternToAnswer.end() &&
800  "expected pattern to exist in predicate");
801  return node->getChildren().insert({it->second, nullptr}).first->second;
802 }
803 
804 /// Build the matcher CFG by "pushing" patterns through by sorted predicate
805 /// order. A pattern will traverse as far as possible using common predicates
806 /// and then either diverge from the CFG or reach the end of a branch and start
807 /// creating new nodes.
808 static void propagatePattern(std::unique_ptr<MatcherNode> &node,
809  OrderedPredicateList &list,
810  std::vector<OrderedPredicate *>::iterator current,
811  std::vector<OrderedPredicate *>::iterator end) {
812  if (current == end) {
813  // We've hit the end of a pattern, so create a successful result node.
814  node =
815  std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
816 
817  // If the pattern doesn't contain this predicate, ignore it.
818  } else if (!list.predicates.contains(*current)) {
819  propagatePattern(node, list, std::next(current), end);
820 
821  // If the current matcher node is invalid, create a new one for this
822  // position and continue propagation.
823  } else if (!node) {
824  // Create a new node at this position and continue
825  node = std::make_unique<SwitchNode>((*current)->position,
826  (*current)->question);
828  getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
829  list, std::next(current), end);
830 
831  // If the matcher has already been created, and it is for this predicate we
832  // continue propagation to the child.
833  } else if (isSamePredicate(node.get(), *current)) {
835  getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
836  list, std::next(current), end);
837 
838  // If the matcher doesn't match the current predicate, insert a branch as
839  // the common set of matchers has diverged.
840  } else {
841  propagatePattern(node->getFailureNode(), list, current, end);
842  }
843 }
844 
845 /// Fold any switch nodes nested under `node` to boolean nodes when possible.
846 /// `node` is updated in-place if it is a switch.
847 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
848  if (!node)
849  return;
850 
851  if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
852  SwitchNode::ChildMapT &children = switchNode->getChildren();
853  for (auto &it : children)
854  foldSwitchToBool(it.second);
855 
856  // If the node only contains one child, collapse it into a boolean predicate
857  // node.
858  if (children.size() == 1) {
859  auto childIt = children.begin();
860  node = std::make_unique<BoolNode>(
861  node->getPosition(), node->getQuestion(), childIt->first,
862  std::move(childIt->second), std::move(node->getFailureNode()));
863  }
864  } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
865  foldSwitchToBool(boolNode->getSuccessNode());
866  }
867 
868  foldSwitchToBool(node->getFailureNode());
869 }
870 
871 /// Insert an exit node at the end of the failure path of the `root`.
872 static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
873  while (*root)
874  root = &(*root)->getFailureNode();
875  *root = std::make_unique<ExitNode>();
876 }
877 
878 /// Given a module containing PDL pattern operations, generate a matcher tree
879 /// using the patterns within the given module and return the root matcher node.
880 std::unique_ptr<MatcherNode>
881 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
882  DenseMap<Value, Position *> &valueToPosition) {
883  // The set of predicates contained within the pattern operations of the
884  // module.
885  struct PatternPredicates {
886  PatternPredicates(pdl::PatternOp pattern, Value root,
887  std::vector<PositionalPredicate> predicates)
888  : pattern(pattern), root(root), predicates(std::move(predicates)) {}
889 
890  /// A pattern.
891  pdl::PatternOp pattern;
892 
893  /// A root of the pattern chosen among the candidate roots in pdl.rewrite.
894  Value root;
895 
896  /// The extracted predicates for this pattern and root.
897  std::vector<PositionalPredicate> predicates;
898  };
899 
900  SmallVector<PatternPredicates, 16> patternsAndPredicates;
901  for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
902  std::vector<PositionalPredicate> predicateList;
903  Value root =
904  buildPredicateList(pattern, builder, predicateList, valueToPosition);
905  patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
906  }
907 
908  // Associate a pattern result with each unique predicate.
910  for (auto &patternAndPredList : patternsAndPredicates) {
911  for (auto &predicate : patternAndPredList.predicates) {
912  auto it = uniqued.insert(predicate);
913  it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
914  predicate.answer);
915  // Mark the insertion order (0-based indexing).
916  if (it.second)
917  it.first->id = uniqued.size() - 1;
918  }
919  }
920 
921  // Associate each pattern to a set of its ordered predicates for later lookup.
922  std::vector<OrderedPredicateList> lists;
923  lists.reserve(patternsAndPredicates.size());
924  for (auto &patternAndPredList : patternsAndPredicates) {
925  OrderedPredicateList list(patternAndPredList.pattern,
926  patternAndPredList.root);
927  for (auto &predicate : patternAndPredList.predicates) {
928  OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
929  list.predicates.insert(orderedPredicate);
930 
931  // Increment the primary sum for each reference to a particular predicate.
932  ++orderedPredicate->primary;
933  }
934  lists.push_back(std::move(list));
935  }
936 
937  // For a particular pattern, get the total primary sum and add it to the
938  // secondary sum of each predicate. Square the primary sums to emphasize
939  // shared predicates within rather than across patterns.
940  for (auto &list : lists) {
941  unsigned total = 0;
942  for (auto *predicate : list.predicates)
943  total += predicate->primary * predicate->primary;
944  for (auto *predicate : list.predicates)
945  predicate->secondary += total;
946  }
947 
948  // Sort the set of predicates now that the cost primary and secondary sums
949  // have been computed.
950  std::vector<OrderedPredicate *> ordered;
951  ordered.reserve(uniqued.size());
952  for (auto &ip : uniqued)
953  ordered.push_back(&ip);
954  llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
955  return *lhs < *rhs;
956  });
957 
958  // Build the matchers for each of the pattern predicate lists.
959  std::unique_ptr<MatcherNode> root;
960  for (OrderedPredicateList &list : lists)
961  propagatePattern(root, list, ordered.begin(), ordered.end());
962 
963  // Collapse the graph and insert the exit node.
964  foldSwitchToBool(root);
965  insertExitNode(&root);
966  return root;
967 }
968 
969 //===----------------------------------------------------------------------===//
970 // MatcherNode
971 //===----------------------------------------------------------------------===//
972 
973 MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
974  std::unique_ptr<MatcherNode> failureNode)
975  : position(p), question(q), failureNode(std::move(failureNode)),
976  matcherTypeID(matcherTypeID) {}
977 
978 //===----------------------------------------------------------------------===//
979 // BoolNode
980 //===----------------------------------------------------------------------===//
981 
982 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
983  std::unique_ptr<MatcherNode> successNode,
984  std::unique_ptr<MatcherNode> failureNode)
985  : MatcherNode(TypeID::get<BoolNode>(), position, question,
986  std::move(failureNode)),
987  answer(answer), successNode(std::move(successNode)) {}
988 
989 //===----------------------------------------------------------------------===//
990 // SuccessNode
991 //===----------------------------------------------------------------------===//
992 
993 SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
994  std::unique_ptr<MatcherNode> failureNode)
995  : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
996  /*question=*/nullptr, std::move(failureNode)),
997  pattern(pattern), root(root) {}
998 
999 //===----------------------------------------------------------------------===//
1000 // SwitchNode
1001 //===----------------------------------------------------------------------===//
1002 
1004  : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
static Value buildPredicateList(pdl::PatternOp pattern, PredicateBuilder &builder, std::vector< PositionalPredicate > &predList, DenseMap< Value, Position * > &valueToPosition)
Given a pattern operation, build the set of matcher predicates necessary to match this pattern.
static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getTypePredicates(Value typeValue, function_ref< Attribute()> typeAttrFn, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getNonTreePredicates(pdl::PatternOp pattern, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
Collect all of the predicates that cannot be determined via walking the tree.
static bool useOperandGroup(pdl::OperationOp op, unsigned index)
Returns true if the operand at the given index needs to be queried using an operand group,...
static void getTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)
Collect the tree predicates anchored at the given value.
static void getResultPredicates(pdl::ResultOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getOperandTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)
Collect all of the predicates for the given operand position.
std::unique_ptr< MatcherNode > & getOrCreateChild(SwitchNode *node, OrderedPredicate *predicate, pdl::PatternOp pattern)
Get or insert a child matcher for the given parent switch node, given a predicate and parent pattern.
static bool comparePosDepth(Position *lhs, Position *rhs)
Compares the depths of two positions.
static void visitUpward(std::vector< PositionalPredicate > &predList, OpIndex opIndex, PredicateBuilder &builder, DenseMap< Value, Position * > &valueToPosition, Position *&pos, unsigned rootID)
Visit a node during upward traversal.
static unsigned getNumNonRangeValues(ValueRange values)
Returns the number of non-range elements within values.
static SmallVector< Value > detectRoots(pdl::PatternOp pattern)
Given a pattern, determines the set of roots present in this pattern.
static void insertExitNode(std::unique_ptr< MatcherNode > *root)
Insert an exit node at the end of the failure path of the root.
static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate)
Returns true if the given matcher refers to the same predicate as the given ordered predicate.
static void foldSwitchToBool(std::unique_ptr< MatcherNode > &node)
Fold any switch nodes nested under node to boolean nodes when possible.
static void buildCostGraph(ArrayRef< Value > roots, RootOrderingGraph &graph, ParentMaps &parentMaps)
Given a list of candidate roots, builds the cost graph for connecting them.
static void getAttributePredicates(pdl::AttributeOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void propagatePattern(std::unique_ptr< MatcherNode > &node, OrderedPredicateList &list, std::vector< OrderedPredicate * >::iterator current, std::vector< OrderedPredicate * >::iterator end)
Build the matcher CFG by "pushing" patterns through by sorted predicate order.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getType() const
Definition: ValueRange.cpp:30
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This class represents the base of a predicate matcher node.
Definition: PredicateTree.h:50
Position * getPosition() const
Returns the position on which the question predicate should be checked.
Definition: PredicateTree.h:63
Qualifier * getQuestion() const
Returns the predicate checked on this node.
Definition: PredicateTree.h:66
The optimal branching algorithm solver.
Definition: RootOrdering.h:90
unsigned solve()
Runs the Edmonds' algorithm for the current graph, returning the total cost of the minimum-weight spa...
std::vector< std::pair< Value, Value > > EdgeList
A list of edges (child, parent).
Definition: RootOrdering.h:93
EdgeList preOrderTraversal(ArrayRef< Value > nodes) const
Returns the computed edges as visited in the preorder traversal.
A position describes a value on the input IR on which a predicate may be applied, such as an operatio...
Definition: Predicate.h:143
unsigned getOperationDepth() const
Returns the depth of the first ancestor operation position.
Definition: Predicate.cpp:21
Predicates::Kind getKind() const
Returns the kind of this position.
Definition: Predicate.h:155
const KeyTy & getValue() const
Return the key value of this predicate.
Definition: Predicate.h:110
This class provides utilities for constructing predicates.
Definition: Predicate.h:566
Position * getTypeLiteral(Attribute attr)
Returns a type position for the given type value.
Definition: Predicate.h:638
Predicate getOperandCount(unsigned count)
Create a predicate comparing the number of operands of an operation to a known value.
Definition: Predicate.h:690
OperationPosition * getPassthroughOp(Position *p)
Returns the operation position equivalent to the given position.
Definition: Predicate.h:586
Predicate getIsNotNull()
Create a predicate comparing a value with null.
Definition: Predicate.h:684
Predicate getOperandCountAtLeast(unsigned count)
Definition: Predicate.h:694
Predicate getResultCountAtLeast(unsigned count)
Definition: Predicate.h:711
Position * getType(Position *p)
Returns a type position for the given entity.
Definition: Predicate.h:634
Position * getAttribute(OperationPosition *p, StringRef name)
Returns an attribute position for an attribute of the given operation.
Definition: Predicate.h:592
Position * getOperandGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of operands of the given operation.
Definition: Predicate.h:611
Position * getForEach(Position *p, unsigned id)
Definition: Predicate.h:601
Position * getOperand(OperationPosition *p, unsigned operand)
Returns an operand position for an operand of the given operation.
Definition: Predicate.h:606
Position * getResult(OperationPosition *p, unsigned result)
Returns a result position for a result of the given operation.
Definition: Predicate.h:620
Position * getRoot()
Returns the root operation position.
Definition: Predicate.h:576
Predicate getAttributeConstraint(Attribute attr)
Create a predicate comparing an attribute to a known value.
Definition: Predicate.h:660
Position * getResultGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of results of the given operation.
Definition: Predicate.h:625
Position * getAllResults(OperationPosition *p)
Definition: Predicate.h:629
UsersPosition * getUsers(Position *p, bool useRepresentative)
Returns the users of a position using the value at the given operand.
Definition: Predicate.h:643
Predicate getTypeConstraint(Attribute type)
Create a predicate comparing the type of an attribute or value to a known type.
Definition: Predicate.h:719
OperationPosition * getOperandDefiningOp(Position *p)
Returns the parent position defining the value held by the given operand.
Definition: Predicate.h:579
Predicate getResultCount(unsigned count)
Create a predicate comparing the number of results of an operation to a known value.
Definition: Predicate.h:707
std::pair< Qualifier *, Qualifier * > Predicate
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition: Predicate.h:657
Predicate getEqualTo(Position *pos)
Create a predicate checking if two values are equal.
Definition: Predicate.h:666
Position * getAllOperands(OperationPosition *p)
Definition: Predicate.h:615
Position * getAttributeLiteral(Attribute attr)
Returns an attribute position for the given attribute.
Definition: Predicate.h:597
Predicate getOperationName(StringRef name)
Create a predicate comparing the name of an operation to a known value.
Definition: Predicate.h:700
Predicate getConstraint(StringRef name, ArrayRef< Position * > pos, bool isNegated)
Create a predicate that applies a generic constraint.
Definition: Predicate.h:676
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition: Predicate.h:387
Predicates::Kind getKind() const
Returns the kind of this qualifier.
Definition: Predicate.h:392
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
bool operator<(const Fraction &x, const Fraction &y)
Definition: Fraction.h:80
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A position describing an attribute of an operation.
Definition: Predicate.h:173
A BoolNode denotes a question with a boolean-like result.
BoolNode(Position *position, Qualifier *question, Qualifier *answer, std::unique_ptr< MatcherNode > successNode, std::unique_ptr< MatcherNode > failureNode=nullptr)
An operation position describes an operation node in the IR.
Definition: Predicate.h:252
bool isRoot() const
Returns if this operation position corresponds to the root.
Definition: Predicate.h:276
bool isOperandDefiningOp() const
Returns if this operation represents an operand defining op.
Definition: Predicate.cpp:51
A PositionalPredicate is a predicate that is associated with a specific positional value.
Definition: PredicateTree.h:30
The information associated with an edge in the cost graph.
Definition: RootOrdering.h:59
Value connector
The connector value in the intersection of the two subtrees rooted at the source and target root that...
Definition: RootOrdering.h:70
std::pair< unsigned, unsigned > cost
The depth of the connector Value w.r.t.
Definition: RootOrdering.h:65
A SuccessNode denotes that a given high level pattern has successfully been matched.
SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr< MatcherNode > failureNode)
A SwitchNode denotes a question with multiple potential results.
llvm::MapVector< Qualifier *, std::unique_ptr< MatcherNode > > ChildMapT
Returns the children of this switch node.
SwitchNode(Position *position, Qualifier *question)
A position describing the result type of an entity, i.e.
Definition: Predicate.h:331