MLIR  14.0.0git
PredicateTree.cpp
Go to the documentation of this file.
1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PredicateTree.h"
10 #include "RootOrdering.h"
11 
15 #include "mlir/IR/BuiltinOps.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/Debug.h"
20 #include <queue>
21 
22 #define DEBUG_TYPE "pdl-predicate-tree"
23 
24 using namespace mlir;
25 using namespace mlir::pdl_to_pdl_interp;
26 
27 //===----------------------------------------------------------------------===//
28 // Predicate List Building
29 //===----------------------------------------------------------------------===//
30 
31 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
32  Value val, PredicateBuilder &builder,
34  Position *pos);
35 
36 /// Compares the depths of two positions.
37 static bool comparePosDepth(Position *lhs, Position *rhs) {
38  return lhs->getOperationDepth() < rhs->getOperationDepth();
39 }
40 
41 /// Returns the number of non-range elements within `values`.
42 static unsigned getNumNonRangeValues(ValueRange values) {
43  return llvm::count_if(values.getTypes(),
44  [](Type type) { return !type.isa<pdl::RangeType>(); });
45 }
46 
47 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
48  Value val, PredicateBuilder &builder,
50  AttributePosition *pos) {
51  assert(val.getType().isa<pdl::AttributeType>() && "expected attribute type");
52  pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
53  predList.emplace_back(pos, builder.getIsNotNull());
54 
55  // If the attribute has a type or value, add a constraint.
56  if (Value type = attr.type())
57  getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
58  else if (Attribute value = attr.valueAttr())
59  predList.emplace_back(pos, builder.getAttributeConstraint(value));
60 }
61 
62 /// Collect all of the predicates for the given operand position.
63 static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
64  Value val, PredicateBuilder &builder,
66  Position *pos) {
67  Type valueType = val.getType();
68  bool isVariadic = valueType.isa<pdl::RangeType>();
69 
70  // If this is a typed operand, add a type constraint.
72  .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
73  // Prevent traversal into a null value if the operand has a proper
74  // index.
75  if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
76  cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
77  predList.emplace_back(pos, builder.getIsNotNull());
78 
79  if (Value type = op.type())
80  getTreePredicates(predList, type, builder, inputs,
81  builder.getType(pos));
82  })
83  .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
84  Optional<unsigned> index = op.index();
85 
86  // Prevent traversal into a null value if the result has a proper index.
87  if (index)
88  predList.emplace_back(pos, builder.getIsNotNull());
89 
90  // Get the parent operation of this operand.
91  OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
92  predList.emplace_back(parentPos, builder.getIsNotNull());
93 
94  // Ensure that the operands match the corresponding results of the
95  // parent operation.
96  Position *resultPos = nullptr;
97  if (std::is_same<pdl::ResultOp, decltype(op)>::value)
98  resultPos = builder.getResult(parentPos, *index);
99  else
100  resultPos = builder.getResultGroup(parentPos, index, isVariadic);
101  predList.emplace_back(resultPos, builder.getEqualTo(pos));
102 
103  // Collect the predicates of the parent operation.
104  getTreePredicates(predList, op.parent(), builder, inputs,
105  (Position *)parentPos);
106  });
107 }
108 
109 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
110  Value val, PredicateBuilder &builder,
112  OperationPosition *pos,
113  Optional<unsigned> ignoreOperand = llvm::None) {
114  assert(val.getType().isa<pdl::OperationType>() && "expected operation");
115  pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
116  OperationPosition *opPos = cast<OperationPosition>(pos);
117 
118  // Ensure getDefiningOp returns a non-null operation.
119  if (!opPos->isRoot())
120  predList.emplace_back(pos, builder.getIsNotNull());
121 
122  // Check that this is the correct root operation.
123  if (Optional<StringRef> opName = op.name())
124  predList.emplace_back(pos, builder.getOperationName(*opName));
125 
126  // Check that the operation has the proper number of operands. If there are
127  // any variable length operands, we check a minimum instead of an exact count.
128  OperandRange operands = op.operands();
129  unsigned minOperands = getNumNonRangeValues(operands);
130  if (minOperands != operands.size()) {
131  if (minOperands)
132  predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
133  } else {
134  predList.emplace_back(pos, builder.getOperandCount(minOperands));
135  }
136 
137  // Check that the operation has the proper number of results. If there are
138  // any variable length results, we check a minimum instead of an exact count.
139  OperandRange types = op.types();
140  unsigned minResults = getNumNonRangeValues(types);
141  if (minResults == types.size())
142  predList.emplace_back(pos, builder.getResultCount(types.size()));
143  else if (minResults)
144  predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
145 
146  // Recurse into any attributes, operands, or results.
147  for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
149  predList, std::get<1>(it), builder, inputs,
150  builder.getAttribute(opPos,
151  std::get<0>(it).cast<StringAttr>().getValue()));
152  }
153 
154  // Process the operands and results of the operation. For all values up to
155  // the first variable length value, we use the concrete operand/result
156  // number. After that, we use the "group" given that we can't know the
157  // concrete indices until runtime. If there is only one variadic operand
158  // group, we treat it as all of the operands/results of the operation.
159  /// Operands.
160  if (operands.size() == 1 && operands[0].getType().isa<pdl::RangeType>()) {
161  // Ignore the operands if we are performing an upward traversal (in that
162  // case, they have already been visited).
163  if (opPos->isRoot() || opPos->isOperandDefiningOp())
164  getTreePredicates(predList, operands.front(), builder, inputs,
165  builder.getAllOperands(opPos));
166  } else {
167  bool foundVariableLength = false;
168  for (const auto &operandIt : llvm::enumerate(operands)) {
169  bool isVariadic = operandIt.value().getType().isa<pdl::RangeType>();
170  foundVariableLength |= isVariadic;
171 
172  // Ignore the specified operand, usually because this position was
173  // visited in an upward traversal via an iterative choice.
174  if (ignoreOperand && *ignoreOperand == operandIt.index())
175  continue;
176 
177  Position *pos =
178  foundVariableLength
179  ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
180  : builder.getOperand(opPos, operandIt.index());
181  getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
182  }
183  }
184  /// Results.
185  if (types.size() == 1 && types[0].getType().isa<pdl::RangeType>()) {
186  getTreePredicates(predList, types.front(), builder, inputs,
187  builder.getType(builder.getAllResults(opPos)));
188  } else {
189  bool foundVariableLength = false;
190  for (auto &resultIt : llvm::enumerate(types)) {
191  bool isVariadic = resultIt.value().getType().isa<pdl::RangeType>();
192  foundVariableLength |= isVariadic;
193 
194  auto *resultPos =
195  foundVariableLength
196  ? builder.getResultGroup(pos, resultIt.index(), isVariadic)
197  : builder.getResult(pos, resultIt.index());
198  predList.emplace_back(resultPos, builder.getIsNotNull());
199  getTreePredicates(predList, resultIt.value(), builder, inputs,
200  builder.getType(resultPos));
201  }
202  }
203 }
204 
205 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
206  Value val, PredicateBuilder &builder,
208  TypePosition *pos) {
209  // Check for a constraint on a constant type.
210  if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
211  if (Attribute type = typeOp.typeAttr())
212  predList.emplace_back(pos, builder.getTypeConstraint(type));
213  } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
214  if (Attribute typeAttr = typeOp.typesAttr())
215  predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
216  }
217 }
218 
219 /// Collect the tree predicates anchored at the given value.
220 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
221  Value val, PredicateBuilder &builder,
223  Position *pos) {
224  // Make sure this input value is accessible to the rewrite.
225  auto it = inputs.try_emplace(val, pos);
226  if (!it.second) {
227  // If this is an input value that has been visited in the tree, add a
228  // constraint to ensure that both instances refer to the same value.
229  if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
230  pdl::TypeOp>(val.getDefiningOp())) {
231  auto minMaxPositions =
232  std::minmax(pos, it.first->second, comparePosDepth);
233  predList.emplace_back(minMaxPositions.second,
234  builder.getEqualTo(minMaxPositions.first));
235  }
236  return;
237  }
238 
240  .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
241  getTreePredicates(predList, val, builder, inputs, pos);
242  })
243  .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
244  getOperandTreePredicates(predList, val, builder, inputs, pos);
245  })
246  .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
247 }
248 
249 static void getAttributePredicates(pdl::AttributeOp op,
250  std::vector<PositionalPredicate> &predList,
251  PredicateBuilder &builder,
252  DenseMap<Value, Position *> &inputs) {
253  Position *&attrPos = inputs[op];
254  if (attrPos)
255  return;
256  Attribute value = op.valueAttr();
257  assert(value && "expected non-tree `pdl.attribute` to contain a value");
258  attrPos = builder.getAttributeLiteral(value);
259 }
260 
261 static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
262  std::vector<PositionalPredicate> &predList,
263  PredicateBuilder &builder,
264  DenseMap<Value, Position *> &inputs) {
265  OperandRange arguments = op.args();
266  ArrayAttr parameters = op.constParamsAttr();
267 
268  std::vector<Position *> allPositions;
269  allPositions.reserve(arguments.size());
270  for (Value arg : arguments)
271  allPositions.push_back(inputs.lookup(arg));
272 
273  // Push the constraint to the furthest position.
274  Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
277  builder.getConstraint(op.name(), allPositions, parameters);
278  predList.emplace_back(pos, pred);
279 }
280 
281 static void getResultPredicates(pdl::ResultOp op,
282  std::vector<PositionalPredicate> &predList,
283  PredicateBuilder &builder,
284  DenseMap<Value, Position *> &inputs) {
285  Position *&resultPos = inputs[op];
286  if (resultPos)
287  return;
288 
289  // Ensure that the result isn't null.
290  auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
291  resultPos = builder.getResult(parentPos, op.index());
292  predList.emplace_back(resultPos, builder.getIsNotNull());
293 }
294 
295 static void getResultPredicates(pdl::ResultsOp op,
296  std::vector<PositionalPredicate> &predList,
297  PredicateBuilder &builder,
298  DenseMap<Value, Position *> &inputs) {
299  Position *&resultPos = inputs[op];
300  if (resultPos)
301  return;
302 
303  // Ensure that the result isn't null if the result has an index.
304  auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
305  bool isVariadic = op.getType().isa<pdl::RangeType>();
306  Optional<unsigned> index = op.index();
307  resultPos = builder.getResultGroup(parentPos, index, isVariadic);
308  if (index)
309  predList.emplace_back(resultPos, builder.getIsNotNull());
310 }
311 
312 static void getTypePredicates(Value typeValue,
313  function_ref<Attribute()> typeAttrFn,
314  PredicateBuilder &builder,
315  DenseMap<Value, Position *> &inputs) {
316  Position *&typePos = inputs[typeValue];
317  if (typePos)
318  return;
319  Attribute typeAttr = typeAttrFn();
320  assert(typeAttr &&
321  "expected non-tree `pdl.type`/`pdl.types` to contain a value");
322  typePos = builder.getTypeLiteral(typeAttr);
323 }
324 
325 /// Collect all of the predicates that cannot be determined via walking the
326 /// tree.
327 static void getNonTreePredicates(pdl::PatternOp pattern,
328  std::vector<PositionalPredicate> &predList,
329  PredicateBuilder &builder,
330  DenseMap<Value, Position *> &inputs) {
331  for (Operation &op : pattern.body().getOps()) {
333  .Case([&](pdl::AttributeOp attrOp) {
334  getAttributePredicates(attrOp, predList, builder, inputs);
335  })
336  .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
337  getConstraintPredicates(constraintOp, predList, builder, inputs);
338  })
339  .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
340  getResultPredicates(resultOp, predList, builder, inputs);
341  })
342  .Case([&](pdl::TypeOp typeOp) {
344  typeOp, [&] { return typeOp.typeAttr(); }, builder, inputs);
345  })
346  .Case([&](pdl::TypesOp typeOp) {
348  typeOp, [&] { return typeOp.typesAttr(); }, builder, inputs);
349  });
350  }
351 }
352 
353 namespace {
354 
355 /// An op accepting a value at an optional index.
356 struct OpIndex {
357  Value parent;
358  Optional<unsigned> index;
359 };
360 
361 /// The parent and operand index of each operation for each root, stored
362 /// as a nested map [root][operation].
363 using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>;
364 
365 } // namespace
366 
367 /// Given a pattern, determines the set of roots present in this pattern.
368 /// These are the operations whose results are not consumed by other operations.
369 static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
370  // First, collect all the operations that are used as operands
371  // to other operations. These are not roots by default.
372  DenseSet<Value> used;
373  for (auto operationOp : pattern.body().getOps<pdl::OperationOp>()) {
374  for (Value operand : operationOp.operands())
375  TypeSwitch<Operation *>(operand.getDefiningOp())
376  .Case<pdl::ResultOp, pdl::ResultsOp>(
377  [&used](auto resultOp) { used.insert(resultOp.parent()); });
378  }
379 
380  // Remove the specified root from the use set, so that we can
381  // always select it as a root, even if it is used by other operations.
382  if (Value root = pattern.getRewriter().root())
383  used.erase(root);
384 
385  // Finally, collect all the unused operations.
386  SmallVector<Value> roots;
387  for (Value operationOp : pattern.body().getOps<pdl::OperationOp>())
388  if (!used.contains(operationOp))
389  roots.push_back(operationOp);
390 
391  return roots;
392 }
393 
394 /// Given a list of candidate roots, builds the cost graph for connecting them.
395 /// The graph is formed by traversing the DAG of operations starting from each
396 /// root and marking the depth of each connector value (operand). Then we join
397 /// the candidate roots based on the common connector values, taking the one
398 /// with the minimum depth. Along the way, we compute, for each candidate root,
399 /// a mapping from each operation (in the DAG underneath this root) to its
400 /// parent operation and the corresponding operand index.
402  ParentMaps &parentMaps) {
403 
404  // The entry of a queue. The entry consists of the following items:
405  // * the value in the DAG underneath the root;
406  // * the parent of the value;
407  // * the operand index of the value in its parent;
408  // * the depth of the visited value.
409  struct Entry {
410  Entry(Value value, Value parent, Optional<unsigned> index, unsigned depth)
411  : value(value), parent(parent), index(index), depth(depth) {}
412 
413  Value value;
414  Value parent;
415  Optional<unsigned> index;
416  unsigned depth;
417  };
418 
419  // A root of a value and its depth (distance from root to the value).
420  struct RootDepth {
421  Value root;
422  unsigned depth = 0;
423  };
424 
425  // Map from candidate connector values to their roots and depths. Using a
426  // small vector with 1 entry because most values belong to a single root.
427  llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
428 
429  // Perform a breadth-first traversal of the op DAG rooted at each root.
430  for (Value root : roots) {
431  // The queue of visited values. A value may be present multiple times in
432  // the queue, for multiple parents. We only accept the first occurrence,
433  // which is guaranteed to have the lowest depth.
434  std::queue<Entry> toVisit;
435  toVisit.emplace(root, Value(), 0, 0);
436 
437  // The map from value to its parent for the current root.
438  DenseMap<Value, OpIndex> &parentMap = parentMaps[root];
439 
440  while (!toVisit.empty()) {
441  Entry entry = toVisit.front();
442  toVisit.pop();
443  // Skip if already visited.
444  if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
445  continue;
446 
447  // Mark the root and depth of the value.
448  connectorsRootsDepths[entry.value].push_back({root, entry.depth});
449 
450  // Traverse the operands of an operation and result ops.
451  // We intentionally do not traverse attributes and types, because those
452  // are expensive to join on.
453  TypeSwitch<Operation *>(entry.value.getDefiningOp())
454  .Case<pdl::OperationOp>([&](auto operationOp) {
455  OperandRange operands = operationOp.operands();
456  // Special case when we pass all the operands in one range.
457  // For those, the index is empty.
458  if (operands.size() == 1 &&
459  operands[0].getType().isa<pdl::RangeType>()) {
460  toVisit.emplace(operands[0], entry.value, llvm::None,
461  entry.depth + 1);
462  return;
463  }
464 
465  // Default case: visit all the operands.
466  for (const auto &p : llvm::enumerate(operationOp.operands()))
467  toVisit.emplace(p.value(), entry.value, p.index(),
468  entry.depth + 1);
469  })
470  .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
471  toVisit.emplace(resultOp.parent(), entry.value, resultOp.index(),
472  entry.depth);
473  });
474  }
475  }
476 
477  // Now build the cost graph.
478  // This is simply a minimum over all depths for the target root.
479  unsigned nextID = 0;
480  for (const auto &connectorRootsDepths : connectorsRootsDepths) {
481  Value value = connectorRootsDepths.first;
482  ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second;
483  // If there is only one root for this value, this will not trigger
484  // any edges in the cost graph (a perf optimization).
485  if (rootsDepths.size() == 1)
486  continue;
487 
488  for (const RootDepth &p : rootsDepths) {
489  for (const RootDepth &q : rootsDepths) {
490  if (&p == &q)
491  continue;
492  // Insert or retrieve the property of edge from p to q.
493  RootOrderingEntry &entry = graph[q.root][p.root];
494  if (!entry.connector /* new edge */ || entry.cost.first > q.depth) {
495  if (!entry.connector)
496  entry.cost.second = nextID++;
497  entry.cost.first = q.depth;
498  entry.connector = value;
499  }
500  }
501  }
502  }
503 
504  assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
505  "the pattern contains a candidate root disconnected from the others");
506 }
507 
508 /// Returns true if the operand at the given index needs to be queried using an
509 /// operand group, i.e., if it is variadic itself or follows a variadic operand.
510 static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
511  OperandRange operands = op.operands();
512  assert(index < operands.size() && "operand index out of range");
513  for (unsigned i = 0; i <= index; ++i)
514  if (operands[i].getType().isa<pdl::RangeType>())
515  return true;
516  return false;
517 }
518 
519 /// Visit a node during upward traversal.
520 static void visitUpward(std::vector<PositionalPredicate> &predList,
521  OpIndex opIndex, PredicateBuilder &builder,
522  DenseMap<Value, Position *> &valueToPosition,
523  Position *&pos, unsigned rootID) {
524  Value value = opIndex.parent;
526  .Case<pdl::OperationOp>([&](auto operationOp) {
527  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
528 
529  // Get users and iterate over them.
530  Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true);
531  Position *foreachPos = builder.getForEach(usersPos, rootID);
532  OperationPosition *opPos = builder.getPassthroughOp(foreachPos);
533 
534  // Compare the operand(s) of the user against the input value(s).
535  Position *operandPos;
536  if (!opIndex.index) {
537  // We are querying all the operands of the operation.
538  operandPos = builder.getAllOperands(opPos);
539  } else if (useOperandGroup(operationOp, *opIndex.index)) {
540  // We are querying an operand group.
541  Type type = operationOp.operands()[*opIndex.index].getType();
542  bool variadic = type.isa<pdl::RangeType>();
543  operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
544  } else {
545  // We are querying an individual operand.
546  operandPos = builder.getOperand(opPos, *opIndex.index);
547  }
548  predList.emplace_back(operandPos, builder.getEqualTo(pos));
549 
550  // Guard against duplicate upward visits. These are not possible,
551  // because if this value was already visited, it would have been
552  // cheaper to start the traversal at this value rather than at the
553  // `connector`, violating the optimality of our spanning tree.
554  bool inserted = valueToPosition.try_emplace(value, opPos).second;
555  (void)inserted;
556  assert(inserted && "duplicate upward visit");
557 
558  // Obtain the tree predicates at the current value.
559  getTreePredicates(predList, value, builder, valueToPosition, opPos,
560  opIndex.index);
561 
562  // Update the position
563  pos = opPos;
564  })
565  .Case<pdl::ResultOp>([&](auto resultOp) {
566  // Traverse up an individual result.
567  auto *opPos = dyn_cast<OperationPosition>(pos);
568  assert(opPos && "operations and results must be interleaved");
569  pos = builder.getResult(opPos, *opIndex.index);
570 
571  // Insert the result position in case we have not visited it yet.
572  valueToPosition.try_emplace(value, pos);
573  })
574  .Case<pdl::ResultsOp>([&](auto resultOp) {
575  // Traverse up a group of results.
576  auto *opPos = dyn_cast<OperationPosition>(pos);
577  assert(opPos && "operations and results must be interleaved");
578  bool isVariadic = value.getType().isa<pdl::RangeType>();
579  if (opIndex.index)
580  pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
581  else
582  pos = builder.getAllResults(opPos);
583 
584  // Insert the result position in case we have not visited it yet.
585  valueToPosition.try_emplace(value, pos);
586  });
587 }
588 
589 /// Given a pattern operation, build the set of matcher predicates necessary to
590 /// match this pattern.
591 static Value buildPredicateList(pdl::PatternOp pattern,
592  PredicateBuilder &builder,
593  std::vector<PositionalPredicate> &predList,
594  DenseMap<Value, Position *> &valueToPosition) {
595  SmallVector<Value> roots = detectRoots(pattern);
596 
597  // Build the root ordering graph and compute the parent maps.
598  RootOrderingGraph graph;
599  ParentMaps parentMaps;
600  buildCostGraph(roots, graph, parentMaps);
601  LLVM_DEBUG({
602  llvm::dbgs() << "Graph:\n";
603  for (auto &target : graph) {
604  llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first
605  << "\n";
606  for (auto &source : target.second) {
607  RootOrderingEntry &entry = source.second;
608  llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first
609  << ":" << entry.cost.second << " via "
610  << entry.connector.getLoc() << "\n";
611  }
612  }
613  });
614 
615  // Solve the optimal branching problem for each candidate root, or use the
616  // provided one.
617  Value bestRoot = pattern.getRewriter().root();
618  OptimalBranching::EdgeList bestEdges;
619  if (!bestRoot) {
620  unsigned bestCost = 0;
621  LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
622  for (Value root : roots) {
623  OptimalBranching solver(graph, root);
624  unsigned cost = solver.solve();
625  LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n");
626  if (!bestRoot || bestCost > cost) {
627  bestCost = cost;
628  bestRoot = root;
629  bestEdges = solver.preOrderTraversal(roots);
630  }
631  }
632  } else {
633  OptimalBranching solver(graph, bestRoot);
634  solver.solve();
635  bestEdges = solver.preOrderTraversal(roots);
636  }
637 
638  // Print the best solution.
639  LLVM_DEBUG({
640  llvm::dbgs() << "Best tree:\n";
641  for (const std::pair<Value, Value> &edge : bestEdges) {
642  llvm::dbgs() << " * " << edge.first;
643  if (edge.second)
644  llvm::dbgs() << " <- " << edge.second;
645  llvm::dbgs() << "\n";
646  }
647  });
648 
649  LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
650  LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n");
651 
652  // The best root is the starting point for the traversal. Get the tree
653  // predicates for the DAG rooted at bestRoot.
654  getTreePredicates(predList, bestRoot, builder, valueToPosition,
655  builder.getRoot());
656 
657  // Traverse the selected optimal branching. For all edges in order, traverse
658  // up starting from the connector, until the candidate root is reached, and
659  // call getTreePredicates at every node along the way.
660  for (const auto &it : llvm::enumerate(bestEdges)) {
661  Value target = it.value().first;
662  Value source = it.value().second;
663 
664  // Check if we already visited the target root. This happens in two cases:
665  // 1) the initial root (bestRoot);
666  // 2) a root that is dominated by (contained in the subtree rooted at) an
667  // already visited root.
668  if (valueToPosition.count(target))
669  continue;
670 
671  // Determine the connector.
672  Value connector = graph[target][source].connector;
673  assert(connector && "invalid edge");
674  LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n");
675  DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
676  Position *pos = valueToPosition.lookup(connector);
677  assert(pos && "connector has not been traversed yet");
678 
679  // Traverse from the connector upwards towards the target root.
680  for (Value value = connector; value != target;) {
681  OpIndex opIndex = parentMap.lookup(value);
682  assert(opIndex.parent && "missing parent");
683  visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
684  value = opIndex.parent;
685  }
686  }
687 
688  getNonTreePredicates(pattern, predList, builder, valueToPosition);
689 
690  return bestRoot;
691 }
692 
693 //===----------------------------------------------------------------------===//
694 // Pattern Predicate Tree Merging
695 //===----------------------------------------------------------------------===//
696 
697 namespace {
698 
699 /// This class represents a specific predicate applied to a position, and
700 /// provides hashing and ordering operators. This class allows for computing a
701 /// frequence sum and ordering predicates based on a cost model.
702 struct OrderedPredicate {
703  OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
704  : position(ip.first), question(ip.second) {}
705  OrderedPredicate(const PositionalPredicate &ip)
706  : position(ip.position), question(ip.question) {}
707 
708  /// The position this predicate is applied to.
709  Position *position;
710 
711  /// The question that is applied by this predicate onto the position.
712  Qualifier *question;
713 
714  /// The first and second order benefit sums.
715  /// The primary sum is the number of occurrences of this predicate among all
716  /// of the patterns.
717  unsigned primary = 0;
718  /// The secondary sum is a squared summation of the primary sum of all of the
719  /// predicates within each pattern that contains this predicate. This allows
720  /// for favoring predicates that are more commonly shared within a pattern, as
721  /// opposed to those shared across patterns.
722  unsigned secondary = 0;
723 
724  /// The tie breaking ID, used to preserve a deterministic (insertion) order
725  /// among all the predicates with the same priority, depth, and position /
726  /// predicate dependency.
727  unsigned id = 0;
728 
729  /// A map between a pattern operation and the answer to the predicate question
730  /// within that pattern.
731  DenseMap<Operation *, Qualifier *> patternToAnswer;
732 
733  /// Returns true if this predicate is ordered before `rhs`, based on the cost
734  /// model.
735  bool operator<(const OrderedPredicate &rhs) const {
736  // Sort by:
737  // * higher first and secondary order sums
738  // * lower depth
739  // * lower position dependency
740  // * lower predicate dependency
741  // * lower tie breaking ID
742  auto *rhsPos = rhs.position;
743  return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
744  rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
745  std::make_tuple(rhs.primary, rhs.secondary,
746  position->getOperationDepth(), position->getKind(),
747  question->getKind(), id);
748  }
749 };
750 
751 /// A DenseMapInfo for OrderedPredicate based solely on the position and
752 /// question.
753 struct OrderedPredicateDenseInfo {
755 
756  static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
757  static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
758  static bool isEqual(const OrderedPredicate &lhs,
759  const OrderedPredicate &rhs) {
760  return lhs.position == rhs.position && lhs.question == rhs.question;
761  }
762  static unsigned getHashValue(const OrderedPredicate &p) {
763  return llvm::hash_combine(p.position, p.question);
764  }
765 };
766 
767 /// This class wraps a set of ordered predicates that are used within a specific
768 /// pattern operation.
769 struct OrderedPredicateList {
770  OrderedPredicateList(pdl::PatternOp pattern, Value root)
771  : pattern(pattern), root(root) {}
772 
773  pdl::PatternOp pattern;
774  Value root;
775  DenseSet<OrderedPredicate *> predicates;
776 };
777 } // namespace
778 
779 /// Returns true if the given matcher refers to the same predicate as the given
780 /// ordered predicate. This means that the position and questions of the two
781 /// match.
782 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
783  return node->getPosition() == predicate->position &&
784  node->getQuestion() == predicate->question;
785 }
786 
787 /// Get or insert a child matcher for the given parent switch node, given a
788 /// predicate and parent pattern.
789 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
790  OrderedPredicate *predicate,
791  pdl::PatternOp pattern) {
792  assert(isSamePredicate(node, predicate) &&
793  "expected matcher to equal the given predicate");
794 
795  auto it = predicate->patternToAnswer.find(pattern);
796  assert(it != predicate->patternToAnswer.end() &&
797  "expected pattern to exist in predicate");
798  return node->getChildren().insert({it->second, nullptr}).first->second;
799 }
800 
801 /// Build the matcher CFG by "pushing" patterns through by sorted predicate
802 /// order. A pattern will traverse as far as possible using common predicates
803 /// and then either diverge from the CFG or reach the end of a branch and start
804 /// creating new nodes.
805 static void propagatePattern(std::unique_ptr<MatcherNode> &node,
806  OrderedPredicateList &list,
807  std::vector<OrderedPredicate *>::iterator current,
808  std::vector<OrderedPredicate *>::iterator end) {
809  if (current == end) {
810  // We've hit the end of a pattern, so create a successful result node.
811  node =
812  std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
813 
814  // If the pattern doesn't contain this predicate, ignore it.
815  } else if (list.predicates.find(*current) == list.predicates.end()) {
816  propagatePattern(node, list, std::next(current), end);
817 
818  // If the current matcher node is invalid, create a new one for this
819  // position and continue propagation.
820  } else if (!node) {
821  // Create a new node at this position and continue
822  node = std::make_unique<SwitchNode>((*current)->position,
823  (*current)->question);
825  getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
826  list, std::next(current), end);
827 
828  // If the matcher has already been created, and it is for this predicate we
829  // continue propagation to the child.
830  } else if (isSamePredicate(node.get(), *current)) {
832  getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
833  list, std::next(current), end);
834 
835  // If the matcher doesn't match the current predicate, insert a branch as
836  // the common set of matchers has diverged.
837  } else {
838  propagatePattern(node->getFailureNode(), list, current, end);
839  }
840 }
841 
842 /// Fold any switch nodes nested under `node` to boolean nodes when possible.
843 /// `node` is updated in-place if it is a switch.
844 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
845  if (!node)
846  return;
847 
848  if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
849  SwitchNode::ChildMapT &children = switchNode->getChildren();
850  for (auto &it : children)
851  foldSwitchToBool(it.second);
852 
853  // If the node only contains one child, collapse it into a boolean predicate
854  // node.
855  if (children.size() == 1) {
856  auto childIt = children.begin();
857  node = std::make_unique<BoolNode>(
858  node->getPosition(), node->getQuestion(), childIt->first,
859  std::move(childIt->second), std::move(node->getFailureNode()));
860  }
861  } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
862  foldSwitchToBool(boolNode->getSuccessNode());
863  }
864 
865  foldSwitchToBool(node->getFailureNode());
866 }
867 
868 /// Insert an exit node at the end of the failure path of the `root`.
869 static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
870  while (*root)
871  root = &(*root)->getFailureNode();
872  *root = std::make_unique<ExitNode>();
873 }
874 
875 /// Given a module containing PDL pattern operations, generate a matcher tree
876 /// using the patterns within the given module and return the root matcher node.
877 std::unique_ptr<MatcherNode>
879  DenseMap<Value, Position *> &valueToPosition) {
880  // The set of predicates contained within the pattern operations of the
881  // module.
882  struct PatternPredicates {
883  PatternPredicates(pdl::PatternOp pattern, Value root,
884  std::vector<PositionalPredicate> predicates)
885  : pattern(pattern), root(root), predicates(std::move(predicates)) {}
886 
887  /// A pattern.
888  pdl::PatternOp pattern;
889 
890  /// A root of the pattern chosen among the candidate roots in pdl.rewrite.
891  Value root;
892 
893  /// The extracted predicates for this pattern and root.
894  std::vector<PositionalPredicate> predicates;
895  };
896 
897  SmallVector<PatternPredicates, 16> patternsAndPredicates;
898  for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
899  std::vector<PositionalPredicate> predicateList;
900  Value root =
901  buildPredicateList(pattern, builder, predicateList, valueToPosition);
902  patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
903  }
904 
905  // Associate a pattern result with each unique predicate.
907  for (auto &patternAndPredList : patternsAndPredicates) {
908  for (auto &predicate : patternAndPredList.predicates) {
909  auto it = uniqued.insert(predicate);
910  it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
911  predicate.answer);
912  // Mark the insertion order (0-based indexing).
913  if (it.second)
914  it.first->id = uniqued.size() - 1;
915  }
916  }
917 
918  // Associate each pattern to a set of its ordered predicates for later lookup.
919  std::vector<OrderedPredicateList> lists;
920  lists.reserve(patternsAndPredicates.size());
921  for (auto &patternAndPredList : patternsAndPredicates) {
922  OrderedPredicateList list(patternAndPredList.pattern,
923  patternAndPredList.root);
924  for (auto &predicate : patternAndPredList.predicates) {
925  OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
926  list.predicates.insert(orderedPredicate);
927 
928  // Increment the primary sum for each reference to a particular predicate.
929  ++orderedPredicate->primary;
930  }
931  lists.push_back(std::move(list));
932  }
933 
934  // For a particular pattern, get the total primary sum and add it to the
935  // secondary sum of each predicate. Square the primary sums to emphasize
936  // shared predicates within rather than across patterns.
937  for (auto &list : lists) {
938  unsigned total = 0;
939  for (auto *predicate : list.predicates)
940  total += predicate->primary * predicate->primary;
941  for (auto *predicate : list.predicates)
942  predicate->secondary += total;
943  }
944 
945  // Sort the set of predicates now that the cost primary and secondary sums
946  // have been computed.
947  std::vector<OrderedPredicate *> ordered;
948  ordered.reserve(uniqued.size());
949  for (auto &ip : uniqued)
950  ordered.push_back(&ip);
951  llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
952  return *lhs < *rhs;
953  });
954 
955  // Build the matchers for each of the pattern predicate lists.
956  std::unique_ptr<MatcherNode> root;
957  for (OrderedPredicateList &list : lists)
958  propagatePattern(root, list, ordered.begin(), ordered.end());
959 
960  // Collapse the graph and insert the exit node.
961  foldSwitchToBool(root);
962  insertExitNode(&root);
963  return root;
964 }
965 
966 //===----------------------------------------------------------------------===//
967 // MatcherNode
968 //===----------------------------------------------------------------------===//
969 
971  std::unique_ptr<MatcherNode> failureNode)
972  : position(p), question(q), failureNode(std::move(failureNode)),
973  matcherTypeID(matcherTypeID) {}
974 
975 //===----------------------------------------------------------------------===//
976 // BoolNode
977 //===----------------------------------------------------------------------===//
978 
979 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
980  std::unique_ptr<MatcherNode> successNode,
981  std::unique_ptr<MatcherNode> failureNode)
982  : MatcherNode(TypeID::get<BoolNode>(), position, question,
983  std::move(failureNode)),
984  answer(answer), successNode(std::move(successNode)) {}
985 
986 //===----------------------------------------------------------------------===//
987 // SuccessNode
988 //===----------------------------------------------------------------------===//
989 
990 SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
991  std::unique_ptr<MatcherNode> failureNode)
992  : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
993  /*question=*/nullptr, std::move(failureNode)),
994  pattern(pattern), root(root) {}
995 
996 //===----------------------------------------------------------------------===//
997 // SwitchNode
998 //===----------------------------------------------------------------------===//
999 
1001  : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
Position * getForEach(Position *p, unsigned id)
Definition: Predicate.h:595
Include the generated interface declarations.
static Value buildPredicateList(pdl::PatternOp pattern, PredicateBuilder &builder, std::vector< PositionalPredicate > &predList, DenseMap< Value, Position *> &valueToPosition)
Given a pattern operation, build the set of matcher predicates necessary to match this pattern...
Position * getResultGroup(OperationPosition *p, Optional< unsigned > group, bool isVariadic)
Returns a position for a group of results of the given operation.
Definition: Predicate.h:619
Qualifier * getQuestion() const
Returns the predicate checked on this node.
Definition: PredicateTree.h:66
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
A position describing the result type of an entity, i.e.
Definition: Predicate.h:326
A SwitchNode denotes a question with multiple potential results.
static SmallVector< Value > detectRoots(pdl::PatternOp pattern)
Given a pattern, determines the set of roots present in this pattern.
Predicate getTypeConstraint(Attribute type)
Create a predicate comparing the type of an attribute or value to a known type.
Definition: Predicate.h:713
static void insertExitNode(std::unique_ptr< MatcherNode > *root)
Insert an exit node at the end of the failure path of the root.
Predicate getEqualTo(Position *pos)
Create a predicate checking if two values are equal.
Definition: Predicate.h:660
unsigned getOperationDepth() const
Returns the depth of the first ancestor operation position.
Definition: Predicate.cpp:21
Position * getAllResults(OperationPosition *p)
Definition: Predicate.h:623
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition: Predicate.h:383
Position * position
The position the predicate is applied to.
Definition: PredicateTree.h:36
Predicate getResultCount(unsigned count)
Create a predicate comparing the number of results of an operation to a known value.
Definition: Predicate.h:701
static void visitUpward(std::vector< PositionalPredicate > &predList, OpIndex opIndex, PredicateBuilder &builder, DenseMap< Value, Position *> &valueToPosition, Position *&pos, unsigned rootID)
Visit a node during upward traversal.
This class represents the base of a predicate matcher node.
Definition: PredicateTree.h:50
static unsigned getNumNonRangeValues(ValueRange values)
Returns the number of non-range elements within values.
llvm::MapVector< Qualifier *, std::unique_ptr< MatcherNode > > ChildMapT
Returns the children of this switch node.
OperationPosition * getPassthroughOp(Position *p)
Returns the operation position equivalent to the given position.
Definition: Predicate.h:580
Position * getOperand(OperationPosition *p, unsigned operand)
Returns an operand position for an operand of the given operation.
Definition: Predicate.h:600
SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr< MatcherNode > failureNode)
std::vector< std::pair< Value, Value > > EdgeList
A list of edges (child, parent).
Definition: RootOrdering.h:93
static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate)
Returns true if the given matcher refers to the same predicate as the given ordered predicate...
static bool useOperandGroup(pdl::OperationOp op, unsigned index)
Returns true if the operand at the given index needs to be queried using an operand group...
A position describing an attribute of an operation.
Definition: Predicate.h:170
static void getResultPredicates(pdl::ResultOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position *> &inputs)
Position * getAllOperands(OperationPosition *p)
Definition: Predicate.h:609
The information associated with an edge in the cost graph.
Definition: RootOrdering.h:59
static constexpr const bool value
Predicate getOperandCount(unsigned count)
Create a predicate comparing the number of operands of an operation to a known value.
Definition: Predicate.h:684
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:52
A position describes a value on the input IR on which a predicate may be applied, such as an operatio...
Definition: Predicate.h:143
static void getOperandTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position *> &inputs, Position *pos)
Collect all of the predicates for the given operand position.
Position * getAttribute(OperationPosition *p, StringRef name)
Returns an attribute position for an attribute of the given operation.
Definition: Predicate.h:586
Position * getRoot()
Returns the root operation position.
Definition: Predicate.h:570
A PositionalPredicate is a predicate that is associated with a specific positional value...
Definition: PredicateTree.h:30
An operation position describes an operation node in the IR.
Definition: Predicate.h:248
OperationPosition * getOperandDefiningOp(Position *p)
Returns the parent position defining the value held by the given operand.
Definition: Predicate.h:573
static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position *> &inputs)
std::pair< unsigned, unsigned > cost
The depth of the connector Value w.r.t.
Definition: RootOrdering.h:65
Position * getAttributeLiteral(Attribute attr)
Returns an attribute position for the given attribute.
Definition: Predicate.h:591
EdgeList preOrderTraversal(ArrayRef< Value > nodes) const
Returns the computed edges as visited in the preorder traversal.
Value connector
The connector value in the intersection of the two subtrees rooted at the source and target root that...
Definition: RootOrdering.h:70
Predicate getOperationName(StringRef name)
Create a predicate comparing the name of an operation to a known value.
Definition: Predicate.h:694
std::unique_ptr< MatcherNode > & getOrCreateChild(SwitchNode *node, OrderedPredicate *predicate, pdl::PatternOp pattern)
Get or insert a child matcher for the given parent switch node, given a predicate and parent pattern...
Attributes are known-constant values of operations.
Definition: Attributes.h:24
SwitchNode(Position *position, Qualifier *question)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
Predicate getOperandCountAtLeast(unsigned count)
Definition: Predicate.h:688
Predicate getIsNotNull()
Create a predicate comparing a value with null.
Definition: Predicate.h:678
static void buildCostGraph(ArrayRef< Value > roots, RootOrderingGraph &graph, ParentMaps &parentMaps)
Given a list of candidate roots, builds the cost graph for connecting them.
MatcherNode(TypeID matcherTypeID, Position *position=nullptr, Qualifier *question=nullptr, std::unique_ptr< MatcherNode > failureNode=nullptr)
A BoolNode denotes a question with a boolean-like result.
static bool comparePosDepth(Position *lhs, Position *rhs)
Compares the depths of two positions.
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Predicate getAttributeConstraint(Attribute attr)
Create a predicate comparing an attribute to a known value.
Definition: Predicate.h:654
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
Predicates::Kind getKind() const
Returns the kind of this position.
Definition: Predicate.h:155
Position * getTypeLiteral(Attribute attr)
Returns a type position for the given type value.
Definition: Predicate.h:632
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
Position * getPosition() const
Returns the position on which the question predicate should be checked.
Definition: PredicateTree.h:63
The optimal branching algorithm solver.
Definition: RootOrdering.h:90
Qualifier * question
The question that the predicate applies.
Definition: PredicateTree.h:39
Position * getOperandGroup(OperationPosition *p, Optional< unsigned > group, bool isVariadic)
Returns a position for a group of operands of the given operation.
Definition: Predicate.h:605
Type getType() const
Return the type of this value.
Definition: Value.h:117
static void getTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position *> &inputs, Position *pos)
Collect the tree predicates anchored at the given value.
UsersPosition * getUsers(Position *p, bool useRepresentative)
Returns the users of a position using the value at the given operand.
Definition: Predicate.h:637
type_range getTypes() const
static void getNonTreePredicates(pdl::PatternOp pattern, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position *> &inputs)
Collect all of the predicates that cannot be determined via walking the tree.
static void getAttributePredicates(pdl::AttributeOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position *> &inputs)
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static void propagatePattern(std::unique_ptr< MatcherNode > &node, OrderedPredicateList &list, std::vector< OrderedPredicate *>::iterator current, std::vector< OrderedPredicate *>::iterator end)
Build the matcher CFG by "pushing" patterns through by sorted predicate order.
This class implements the operand iterators for the Operation class.
std::pair< Qualifier *, Qualifier * > Predicate
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition: Predicate.h:651
Predicate getConstraint(StringRef name, ArrayRef< Position *> pos, Attribute params)
Create a predicate that applies a generic constraint.
Definition: Predicate.h:670
A SuccessNode denotes that a given high level pattern has successfully been matched.
static void foldSwitchToBool(std::unique_ptr< MatcherNode > &node)
Fold any switch nodes nested under node to boolean nodes when possible.
bool isa() const
Definition: Types.h:234
bool operator<(Fraction x, Fraction y)
Definition: Fraction.h:61
static void getTypePredicates(Value typeValue, function_ref< Attribute()> typeAttrFn, PredicateBuilder &builder, DenseMap< Value, Position *> &inputs)
This class provides an abstraction over the different types of ranges over Values.
unsigned solve()
Runs Edmonds' algorithm for the current graph, returning the total cost of the minimum-weight spanning arborescence.
Position * getType(Position *p)
Returns a type position for the given entity.
Definition: Predicate.h:628
Predicate getResultCountAtLeast(unsigned count)
Definition: Predicate.h:705
This class provides utilities for constructing predicates.
Definition: Predicate.h:560
Predicates::Kind getKind() const
Returns the kind of this qualifier.
Definition: Predicate.h:388
static std::unique_ptr< MatcherNode > generateMatcherTree(ModuleOp module, PredicateBuilder &builder, DenseMap< Value, Position *> &valueToPosition)
Given a module containing PDL pattern operations, generate a matcher tree using the patterns within t...
BoolNode(Position *position, Qualifier *question, Qualifier *answer, std::unique_ptr< MatcherNode > successNode, std::unique_ptr< MatcherNode > failureNode=nullptr)
Position * getResult(OperationPosition *p, unsigned result)
Returns a result position for a result of the given operation.
Definition: Predicate.h:614