MLIR  16.0.0git
PredicateTree.cpp
Go to the documentation of this file.
1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PredicateTree.h"
10 #include "RootOrdering.h"
11 
15 #include "mlir/IR/BuiltinOps.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/Debug.h"
20 #include <queue>
21 
22 #define DEBUG_TYPE "pdl-predicate-tree"
23 
24 using namespace mlir;
25 using namespace mlir::pdl_to_pdl_interp;
26 
27 //===----------------------------------------------------------------------===//
28 // Predicate List Building
29 //===----------------------------------------------------------------------===//
30 
31 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
32  Value val, PredicateBuilder &builder,
34  Position *pos);
35 
36 /// Compares the depths of two positions.
37 static bool comparePosDepth(Position *lhs, Position *rhs) {
38  return lhs->getOperationDepth() < rhs->getOperationDepth();
39 }
40 
41 /// Returns the number of non-range elements within `values`.
42 static unsigned getNumNonRangeValues(ValueRange values) {
43  return llvm::count_if(values.getTypes(),
44  [](Type type) { return !type.isa<pdl::RangeType>(); });
45 }
46 
47 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
48  Value val, PredicateBuilder &builder,
50  AttributePosition *pos) {
51  assert(val.getType().isa<pdl::AttributeType>() && "expected attribute type");
52  pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
53  predList.emplace_back(pos, builder.getIsNotNull());
54 
55  // If the attribute has a type or value, add a constraint.
56  if (Value type = attr.type())
57  getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
58  else if (Attribute value = attr.valueAttr())
59  predList.emplace_back(pos, builder.getAttributeConstraint(value));
60 }
61 
62 /// Collect all of the predicates for the given operand position.
63 static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
64  Value val, PredicateBuilder &builder,
66  Position *pos) {
67  Type valueType = val.getType();
68  bool isVariadic = valueType.isa<pdl::RangeType>();
69 
70  // If this is a typed operand, add a type constraint.
72  .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
73  // Prevent traversal into a null value if the operand has a proper
74  // index.
75  if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
76  cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
77  predList.emplace_back(pos, builder.getIsNotNull());
78 
79  if (Value type = op.type())
80  getTreePredicates(predList, type, builder, inputs,
81  builder.getType(pos));
82  })
83  .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
84  Optional<unsigned> index = op.index();
85 
86  // Prevent traversal into a null value if the result has a proper index.
87  if (index)
88  predList.emplace_back(pos, builder.getIsNotNull());
89 
90  // Get the parent operation of this operand.
91  OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
92  predList.emplace_back(parentPos, builder.getIsNotNull());
93 
94  // Ensure that the operands match the corresponding results of the
95  // parent operation.
96  Position *resultPos = nullptr;
97  if (std::is_same<pdl::ResultOp, decltype(op)>::value)
98  resultPos = builder.getResult(parentPos, *index);
99  else
100  resultPos = builder.getResultGroup(parentPos, index, isVariadic);
101  predList.emplace_back(resultPos, builder.getEqualTo(pos));
102 
103  // Collect the predicates of the parent operation.
104  getTreePredicates(predList, op.parent(), builder, inputs,
105  (Position *)parentPos);
106  });
107 }
108 
109 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
110  Value val, PredicateBuilder &builder,
112  OperationPosition *pos,
113  Optional<unsigned> ignoreOperand = llvm::None) {
114  assert(val.getType().isa<pdl::OperationType>() && "expected operation");
115  pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
116  OperationPosition *opPos = cast<OperationPosition>(pos);
117 
118  // Ensure getDefiningOp returns a non-null operation.
119  if (!opPos->isRoot())
120  predList.emplace_back(pos, builder.getIsNotNull());
121 
122  // Check that this is the correct root operation.
123  if (Optional<StringRef> opName = op.name())
124  predList.emplace_back(pos, builder.getOperationName(*opName));
125 
126  // Check that the operation has the proper number of operands. If there are
127  // any variable length operands, we check a minimum instead of an exact count.
128  OperandRange operands = op.operands();
129  unsigned minOperands = getNumNonRangeValues(operands);
130  if (minOperands != operands.size()) {
131  if (minOperands)
132  predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
133  } else {
134  predList.emplace_back(pos, builder.getOperandCount(minOperands));
135  }
136 
137  // Check that the operation has the proper number of results. If there are
138  // any variable length results, we check a minimum instead of an exact count.
139  OperandRange types = op.types();
140  unsigned minResults = getNumNonRangeValues(types);
141  if (minResults == types.size())
142  predList.emplace_back(pos, builder.getResultCount(types.size()));
143  else if (minResults)
144  predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
145 
146  // Recurse into any attributes, operands, or results.
147  for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
149  predList, std::get<1>(it), builder, inputs,
150  builder.getAttribute(opPos,
151  std::get<0>(it).cast<StringAttr>().getValue()));
152  }
153 
154  // Process the operands and results of the operation. For all values up to
155  // the first variable length value, we use the concrete operand/result
156  // number. After that, we use the "group" given that we can't know the
157  // concrete indices until runtime. If there is only one variadic operand
158  // group, we treat it as all of the operands/results of the operation.
159  /// Operands.
160  if (operands.size() == 1 && operands[0].getType().isa<pdl::RangeType>()) {
161  // Ignore the operands if we are performing an upward traversal (in that
162  // case, they have already been visited).
163  if (opPos->isRoot() || opPos->isOperandDefiningOp())
164  getTreePredicates(predList, operands.front(), builder, inputs,
165  builder.getAllOperands(opPos));
166  } else {
167  bool foundVariableLength = false;
168  for (const auto &operandIt : llvm::enumerate(operands)) {
169  bool isVariadic = operandIt.value().getType().isa<pdl::RangeType>();
170  foundVariableLength |= isVariadic;
171 
172  // Ignore the specified operand, usually because this position was
173  // visited in an upward traversal via an iterative choice.
174  if (ignoreOperand && *ignoreOperand == operandIt.index())
175  continue;
176 
177  Position *pos =
178  foundVariableLength
179  ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
180  : builder.getOperand(opPos, operandIt.index());
181  getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
182  }
183  }
184  /// Results.
185  if (types.size() == 1 && types[0].getType().isa<pdl::RangeType>()) {
186  getTreePredicates(predList, types.front(), builder, inputs,
187  builder.getType(builder.getAllResults(opPos)));
188  } else {
189  bool foundVariableLength = false;
190  for (auto &resultIt : llvm::enumerate(types)) {
191  bool isVariadic = resultIt.value().getType().isa<pdl::RangeType>();
192  foundVariableLength |= isVariadic;
193 
194  auto *resultPos =
195  foundVariableLength
196  ? builder.getResultGroup(pos, resultIt.index(), isVariadic)
197  : builder.getResult(pos, resultIt.index());
198  predList.emplace_back(resultPos, builder.getIsNotNull());
199  getTreePredicates(predList, resultIt.value(), builder, inputs,
200  builder.getType(resultPos));
201  }
202  }
203 }
204 
205 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
206  Value val, PredicateBuilder &builder,
208  TypePosition *pos) {
209  // Check for a constraint on a constant type.
210  if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
211  if (Attribute type = typeOp.typeAttr())
212  predList.emplace_back(pos, builder.getTypeConstraint(type));
213  } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
214  if (Attribute typeAttr = typeOp.typesAttr())
215  predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
216  }
217 }
218 
219 /// Collect the tree predicates anchored at the given value.
220 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
221  Value val, PredicateBuilder &builder,
223  Position *pos) {
224  // Make sure this input value is accessible to the rewrite.
225  auto it = inputs.try_emplace(val, pos);
226  if (!it.second) {
227  // If this is an input value that has been visited in the tree, add a
228  // constraint to ensure that both instances refer to the same value.
229  if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
230  pdl::TypeOp>(val.getDefiningOp())) {
231  auto minMaxPositions =
232  std::minmax(pos, it.first->second, comparePosDepth);
233  predList.emplace_back(minMaxPositions.second,
234  builder.getEqualTo(minMaxPositions.first));
235  }
236  return;
237  }
238 
240  .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
241  getTreePredicates(predList, val, builder, inputs, pos);
242  })
243  .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
244  getOperandTreePredicates(predList, val, builder, inputs, pos);
245  })
246  .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
247 }
248 
249 static void getAttributePredicates(pdl::AttributeOp op,
250  std::vector<PositionalPredicate> &predList,
251  PredicateBuilder &builder,
252  DenseMap<Value, Position *> &inputs) {
253  Position *&attrPos = inputs[op];
254  if (attrPos)
255  return;
256  Attribute value = op.valueAttr();
257  assert(value && "expected non-tree `pdl.attribute` to contain a value");
258  attrPos = builder.getAttributeLiteral(value);
259 }
260 
261 static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
262  std::vector<PositionalPredicate> &predList,
263  PredicateBuilder &builder,
264  DenseMap<Value, Position *> &inputs) {
265  OperandRange arguments = op.args();
266 
267  std::vector<Position *> allPositions;
268  allPositions.reserve(arguments.size());
269  for (Value arg : arguments)
270  allPositions.push_back(inputs.lookup(arg));
271 
272  // Push the constraint to the furthest position.
273  Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
276  builder.getConstraint(op.name(), allPositions);
277  predList.emplace_back(pos, pred);
278 }
279 
280 static void getResultPredicates(pdl::ResultOp op,
281  std::vector<PositionalPredicate> &predList,
282  PredicateBuilder &builder,
283  DenseMap<Value, Position *> &inputs) {
284  Position *&resultPos = inputs[op];
285  if (resultPos)
286  return;
287 
288  // Ensure that the result isn't null.
289  auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
290  resultPos = builder.getResult(parentPos, op.index());
291  predList.emplace_back(resultPos, builder.getIsNotNull());
292 }
293 
294 static void getResultPredicates(pdl::ResultsOp op,
295  std::vector<PositionalPredicate> &predList,
296  PredicateBuilder &builder,
297  DenseMap<Value, Position *> &inputs) {
298  Position *&resultPos = inputs[op];
299  if (resultPos)
300  return;
301 
302  // Ensure that the result isn't null if the result has an index.
303  auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
304  bool isVariadic = op.getType().isa<pdl::RangeType>();
305  Optional<unsigned> index = op.index();
306  resultPos = builder.getResultGroup(parentPos, index, isVariadic);
307  if (index)
308  predList.emplace_back(resultPos, builder.getIsNotNull());
309 }
310 
311 static void getTypePredicates(Value typeValue,
312  function_ref<Attribute()> typeAttrFn,
313  PredicateBuilder &builder,
314  DenseMap<Value, Position *> &inputs) {
315  Position *&typePos = inputs[typeValue];
316  if (typePos)
317  return;
318  Attribute typeAttr = typeAttrFn();
319  assert(typeAttr &&
320  "expected non-tree `pdl.type`/`pdl.types` to contain a value");
321  typePos = builder.getTypeLiteral(typeAttr);
322 }
323 
324 /// Collect all of the predicates that cannot be determined via walking the
325 /// tree.
326 static void getNonTreePredicates(pdl::PatternOp pattern,
327  std::vector<PositionalPredicate> &predList,
328  PredicateBuilder &builder,
329  DenseMap<Value, Position *> &inputs) {
330  for (Operation &op : pattern.body().getOps()) {
332  .Case([&](pdl::AttributeOp attrOp) {
333  getAttributePredicates(attrOp, predList, builder, inputs);
334  })
335  .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
336  getConstraintPredicates(constraintOp, predList, builder, inputs);
337  })
338  .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
339  getResultPredicates(resultOp, predList, builder, inputs);
340  })
341  .Case([&](pdl::TypeOp typeOp) {
343  typeOp, [&] { return typeOp.typeAttr(); }, builder, inputs);
344  })
345  .Case([&](pdl::TypesOp typeOp) {
347  typeOp, [&] { return typeOp.typesAttr(); }, builder, inputs);
348  });
349  }
350 }
351 
352 namespace {
353 
354 /// An op accepting a value at an optional index.
355 struct OpIndex {
356  Value parent;
357  Optional<unsigned> index;
358 };
359 
360 /// The parent and operand index of each operation for each root, stored
361 /// as a nested map [root][operation].
362 using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>;
363 
364 } // namespace
365 
366 /// Given a pattern, determines the set of roots present in this pattern.
367 /// These are the operations whose results are not consumed by other operations.
368 static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
369  // First, collect all the operations that are used as operands
370  // to other operations. These are not roots by default.
371  DenseSet<Value> used;
372  for (auto operationOp : pattern.body().getOps<pdl::OperationOp>()) {
373  for (Value operand : operationOp.operands())
374  TypeSwitch<Operation *>(operand.getDefiningOp())
375  .Case<pdl::ResultOp, pdl::ResultsOp>(
376  [&used](auto resultOp) { used.insert(resultOp.parent()); });
377  }
378 
379  // Remove the specified root from the use set, so that we can
380  // always select it as a root, even if it is used by other operations.
381  if (Value root = pattern.getRewriter().root())
382  used.erase(root);
383 
384  // Finally, collect all the unused operations.
385  SmallVector<Value> roots;
386  for (Value operationOp : pattern.body().getOps<pdl::OperationOp>())
387  if (!used.contains(operationOp))
388  roots.push_back(operationOp);
389 
390  return roots;
391 }
392 
393 /// Given a list of candidate roots, builds the cost graph for connecting them.
394 /// The graph is formed by traversing the DAG of operations starting from each
395 /// root and marking the depth of each connector value (operand). Then we join
396 /// the candidate roots based on the common connector values, taking the one
397 /// with the minimum depth. Along the way, we compute, for each candidate root,
398 /// a mapping from each operation (in the DAG underneath this root) to its
399 /// parent operation and the corresponding operand index.
401  ParentMaps &parentMaps) {
402 
403  // The entry of a queue. The entry consists of the following items:
404  // * the value in the DAG underneath the root;
405  // * the parent of the value;
406  // * the operand index of the value in its parent;
407  // * the depth of the visited value.
408  struct Entry {
409  Entry(Value value, Value parent, Optional<unsigned> index, unsigned depth)
410  : value(value), parent(parent), index(index), depth(depth) {}
411 
412  Value value;
413  Value parent;
414  Optional<unsigned> index;
415  unsigned depth;
416  };
417 
418  // A root of a value and its depth (distance from root to the value).
419  struct RootDepth {
420  Value root;
421  unsigned depth = 0;
422  };
423 
424  // Map from candidate connector values to their roots and depths. Using a
425  // small vector with 1 entry because most values belong to a single root.
426  llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
427 
428  // Perform a breadth-first traversal of the op DAG rooted at each root.
429  for (Value root : roots) {
430  // The queue of visited values. A value may be present multiple times in
431  // the queue, for multiple parents. We only accept the first occurrence,
432  // which is guaranteed to have the lowest depth.
433  std::queue<Entry> toVisit;
434  toVisit.emplace(root, Value(), 0, 0);
435 
436  // The map from value to its parent for the current root.
437  DenseMap<Value, OpIndex> &parentMap = parentMaps[root];
438 
439  while (!toVisit.empty()) {
440  Entry entry = toVisit.front();
441  toVisit.pop();
442  // Skip if already visited.
443  if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
444  continue;
445 
446  // Mark the root and depth of the value.
447  connectorsRootsDepths[entry.value].push_back({root, entry.depth});
448 
449  // Traverse the operands of an operation and result ops.
450  // We intentionally do not traverse attributes and types, because those
451  // are expensive to join on.
452  TypeSwitch<Operation *>(entry.value.getDefiningOp())
453  .Case<pdl::OperationOp>([&](auto operationOp) {
454  OperandRange operands = operationOp.operands();
455  // Special case when we pass all the operands in one range.
456  // For those, the index is empty.
457  if (operands.size() == 1 &&
458  operands[0].getType().isa<pdl::RangeType>()) {
459  toVisit.emplace(operands[0], entry.value, llvm::None,
460  entry.depth + 1);
461  return;
462  }
463 
464  // Default case: visit all the operands.
465  for (const auto &p : llvm::enumerate(operationOp.operands()))
466  toVisit.emplace(p.value(), entry.value, p.index(),
467  entry.depth + 1);
468  })
469  .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
470  toVisit.emplace(resultOp.parent(), entry.value, resultOp.index(),
471  entry.depth);
472  });
473  }
474  }
475 
476  // Now build the cost graph.
477  // This is simply a minimum over all depths for the target root.
478  unsigned nextID = 0;
479  for (const auto &connectorRootsDepths : connectorsRootsDepths) {
480  Value value = connectorRootsDepths.first;
481  ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second;
482  // If there is only one root for this value, this will not trigger
483  // any edges in the cost graph (a perf optimization).
484  if (rootsDepths.size() == 1)
485  continue;
486 
487  for (const RootDepth &p : rootsDepths) {
488  for (const RootDepth &q : rootsDepths) {
489  if (&p == &q)
490  continue;
491  // Insert or retrieve the property of edge from p to q.
492  RootOrderingEntry &entry = graph[q.root][p.root];
493  if (!entry.connector /* new edge */ || entry.cost.first > q.depth) {
494  if (!entry.connector)
495  entry.cost.second = nextID++;
496  entry.cost.first = q.depth;
497  entry.connector = value;
498  }
499  }
500  }
501  }
502 
503  assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
504  "the pattern contains a candidate root disconnected from the others");
505 }
506 
507 /// Returns true if the operand at the given index needs to be queried using an
508 /// operand group, i.e., if it is variadic itself or follows a variadic operand.
509 static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
510  OperandRange operands = op.operands();
511  assert(index < operands.size() && "operand index out of range");
512  for (unsigned i = 0; i <= index; ++i)
513  if (operands[i].getType().isa<pdl::RangeType>())
514  return true;
515  return false;
516 }
517 
518 /// Visit a node during upward traversal.
519 static void visitUpward(std::vector<PositionalPredicate> &predList,
520  OpIndex opIndex, PredicateBuilder &builder,
521  DenseMap<Value, Position *> &valueToPosition,
522  Position *&pos, unsigned rootID) {
523  Value value = opIndex.parent;
525  .Case<pdl::OperationOp>([&](auto operationOp) {
526  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
527 
528  // Get users and iterate over them.
529  Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true);
530  Position *foreachPos = builder.getForEach(usersPos, rootID);
531  OperationPosition *opPos = builder.getPassthroughOp(foreachPos);
532 
533  // Compare the operand(s) of the user against the input value(s).
534  Position *operandPos;
535  if (!opIndex.index) {
536  // We are querying all the operands of the operation.
537  operandPos = builder.getAllOperands(opPos);
538  } else if (useOperandGroup(operationOp, *opIndex.index)) {
539  // We are querying an operand group.
540  Type type = operationOp.operands()[*opIndex.index].getType();
541  bool variadic = type.isa<pdl::RangeType>();
542  operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
543  } else {
544  // We are querying an individual operand.
545  operandPos = builder.getOperand(opPos, *opIndex.index);
546  }
547  predList.emplace_back(operandPos, builder.getEqualTo(pos));
548 
549  // Guard against duplicate upward visits. These are not possible,
550  // because if this value was already visited, it would have been
551  // cheaper to start the traversal at this value rather than at the
552  // `connector`, violating the optimality of our spanning tree.
553  bool inserted = valueToPosition.try_emplace(value, opPos).second;
554  (void)inserted;
555  assert(inserted && "duplicate upward visit");
556 
557  // Obtain the tree predicates at the current value.
558  getTreePredicates(predList, value, builder, valueToPosition, opPos,
559  opIndex.index);
560 
561  // Update the position
562  pos = opPos;
563  })
564  .Case<pdl::ResultOp>([&](auto resultOp) {
565  // Traverse up an individual result.
566  auto *opPos = dyn_cast<OperationPosition>(pos);
567  assert(opPos && "operations and results must be interleaved");
568  pos = builder.getResult(opPos, *opIndex.index);
569 
570  // Insert the result position in case we have not visited it yet.
571  valueToPosition.try_emplace(value, pos);
572  })
573  .Case<pdl::ResultsOp>([&](auto resultOp) {
574  // Traverse up a group of results.
575  auto *opPos = dyn_cast<OperationPosition>(pos);
576  assert(opPos && "operations and results must be interleaved");
577  bool isVariadic = value.getType().isa<pdl::RangeType>();
578  if (opIndex.index)
579  pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
580  else
581  pos = builder.getAllResults(opPos);
582 
583  // Insert the result position in case we have not visited it yet.
584  valueToPosition.try_emplace(value, pos);
585  });
586 }
587 
588 /// Given a pattern operation, build the set of matcher predicates necessary to
589 /// match this pattern.
590 static Value buildPredicateList(pdl::PatternOp pattern,
591  PredicateBuilder &builder,
592  std::vector<PositionalPredicate> &predList,
593  DenseMap<Value, Position *> &valueToPosition) {
594  SmallVector<Value> roots = detectRoots(pattern);
595 
596  // Build the root ordering graph and compute the parent maps.
597  RootOrderingGraph graph;
598  ParentMaps parentMaps;
599  buildCostGraph(roots, graph, parentMaps);
600  LLVM_DEBUG({
601  llvm::dbgs() << "Graph:\n";
602  for (auto &target : graph) {
603  llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first
604  << "\n";
605  for (auto &source : target.second) {
606  RootOrderingEntry &entry = source.second;
607  llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first
608  << ":" << entry.cost.second << " via "
609  << entry.connector.getLoc() << "\n";
610  }
611  }
612  });
613 
614  // Solve the optimal branching problem for each candidate root, or use the
615  // provided one.
616  Value bestRoot = pattern.getRewriter().root();
617  OptimalBranching::EdgeList bestEdges;
618  if (!bestRoot) {
619  unsigned bestCost = 0;
620  LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
621  for (Value root : roots) {
622  OptimalBranching solver(graph, root);
623  unsigned cost = solver.solve();
624  LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n");
625  if (!bestRoot || bestCost > cost) {
626  bestCost = cost;
627  bestRoot = root;
628  bestEdges = solver.preOrderTraversal(roots);
629  }
630  }
631  } else {
632  OptimalBranching solver(graph, bestRoot);
633  solver.solve();
634  bestEdges = solver.preOrderTraversal(roots);
635  }
636 
637  // Print the best solution.
638  LLVM_DEBUG({
639  llvm::dbgs() << "Best tree:\n";
640  for (const std::pair<Value, Value> &edge : bestEdges) {
641  llvm::dbgs() << " * " << edge.first;
642  if (edge.second)
643  llvm::dbgs() << " <- " << edge.second;
644  llvm::dbgs() << "\n";
645  }
646  });
647 
648  LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
649  LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n");
650 
651  // The best root is the starting point for the traversal. Get the tree
652  // predicates for the DAG rooted at bestRoot.
653  getTreePredicates(predList, bestRoot, builder, valueToPosition,
654  builder.getRoot());
655 
656  // Traverse the selected optimal branching. For all edges in order, traverse
657  // up starting from the connector, until the candidate root is reached, and
658  // call getTreePredicates at every node along the way.
659  for (const auto &it : llvm::enumerate(bestEdges)) {
660  Value target = it.value().first;
661  Value source = it.value().second;
662 
663  // Check if we already visited the target root. This happens in two cases:
664  // 1) the initial root (bestRoot);
665  // 2) a root that is dominated by (contained in the subtree rooted at) an
666  // already visited root.
667  if (valueToPosition.count(target))
668  continue;
669 
670  // Determine the connector.
671  Value connector = graph[target][source].connector;
672  assert(connector && "invalid edge");
673  LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n");
674  DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
675  Position *pos = valueToPosition.lookup(connector);
676  assert(pos && "connector has not been traversed yet");
677 
678  // Traverse from the connector upwards towards the target root.
679  for (Value value = connector; value != target;) {
680  OpIndex opIndex = parentMap.lookup(value);
681  assert(opIndex.parent && "missing parent");
682  visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
683  value = opIndex.parent;
684  }
685  }
686 
687  getNonTreePredicates(pattern, predList, builder, valueToPosition);
688 
689  return bestRoot;
690 }
691 
692 //===----------------------------------------------------------------------===//
693 // Pattern Predicate Tree Merging
694 //===----------------------------------------------------------------------===//
695 
696 namespace {
697 
698 /// This class represents a specific predicate applied to a position, and
699 /// provides hashing and ordering operators. This class allows for computing a
700 /// frequence sum and ordering predicates based on a cost model.
701 struct OrderedPredicate {
702  OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
703  : position(ip.first), question(ip.second) {}
704  OrderedPredicate(const PositionalPredicate &ip)
705  : position(ip.position), question(ip.question) {}
706 
707  /// The position this predicate is applied to.
708  Position *position;
709 
710  /// The question that is applied by this predicate onto the position.
711  Qualifier *question;
712 
713  /// The first and second order benefit sums.
714  /// The primary sum is the number of occurrences of this predicate among all
715  /// of the patterns.
716  unsigned primary = 0;
717  /// The secondary sum is a squared summation of the primary sum of all of the
718  /// predicates within each pattern that contains this predicate. This allows
719  /// for favoring predicates that are more commonly shared within a pattern, as
720  /// opposed to those shared across patterns.
721  unsigned secondary = 0;
722 
723  /// The tie breaking ID, used to preserve a deterministic (insertion) order
724  /// among all the predicates with the same priority, depth, and position /
725  /// predicate dependency.
726  unsigned id = 0;
727 
728  /// A map between a pattern operation and the answer to the predicate question
729  /// within that pattern.
730  DenseMap<Operation *, Qualifier *> patternToAnswer;
731 
732  /// Returns true if this predicate is ordered before `rhs`, based on the cost
733  /// model.
734  bool operator<(const OrderedPredicate &rhs) const {
735  // Sort by:
736  // * higher first and secondary order sums
737  // * lower depth
738  // * lower position dependency
739  // * lower predicate dependency
740  // * lower tie breaking ID
741  auto *rhsPos = rhs.position;
742  return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
743  rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
744  std::make_tuple(rhs.primary, rhs.secondary,
745  position->getOperationDepth(), position->getKind(),
746  question->getKind(), id);
747  }
748 };
749 
750 /// A DenseMapInfo for OrderedPredicate based solely on the position and
751 /// question.
752 struct OrderedPredicateDenseInfo {
754 
755  static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
756  static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
757  static bool isEqual(const OrderedPredicate &lhs,
758  const OrderedPredicate &rhs) {
759  return lhs.position == rhs.position && lhs.question == rhs.question;
760  }
761  static unsigned getHashValue(const OrderedPredicate &p) {
762  return llvm::hash_combine(p.position, p.question);
763  }
764 };
765 
766 /// This class wraps a set of ordered predicates that are used within a specific
767 /// pattern operation.
768 struct OrderedPredicateList {
769  OrderedPredicateList(pdl::PatternOp pattern, Value root)
770  : pattern(pattern), root(root) {}
771 
772  pdl::PatternOp pattern;
773  Value root;
774  DenseSet<OrderedPredicate *> predicates;
775 };
776 } // namespace
777 
778 /// Returns true if the given matcher refers to the same predicate as the given
779 /// ordered predicate. This means that the position and questions of the two
780 /// match.
781 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
782  return node->getPosition() == predicate->position &&
783  node->getQuestion() == predicate->question;
784 }
785 
786 /// Get or insert a child matcher for the given parent switch node, given a
787 /// predicate and parent pattern.
788 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
789  OrderedPredicate *predicate,
790  pdl::PatternOp pattern) {
791  assert(isSamePredicate(node, predicate) &&
792  "expected matcher to equal the given predicate");
793 
794  auto it = predicate->patternToAnswer.find(pattern);
795  assert(it != predicate->patternToAnswer.end() &&
796  "expected pattern to exist in predicate");
797  return node->getChildren().insert({it->second, nullptr}).first->second;
798 }
799 
800 /// Build the matcher CFG by "pushing" patterns through by sorted predicate
801 /// order. A pattern will traverse as far as possible using common predicates
802 /// and then either diverge from the CFG or reach the end of a branch and start
803 /// creating new nodes.
804 static void propagatePattern(std::unique_ptr<MatcherNode> &node,
805  OrderedPredicateList &list,
806  std::vector<OrderedPredicate *>::iterator current,
807  std::vector<OrderedPredicate *>::iterator end) {
808  if (current == end) {
809  // We've hit the end of a pattern, so create a successful result node.
810  node =
811  std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
812 
813  // If the pattern doesn't contain this predicate, ignore it.
814  } else if (list.predicates.find(*current) == list.predicates.end()) {
815  propagatePattern(node, list, std::next(current), end);
816 
817  // If the current matcher node is invalid, create a new one for this
818  // position and continue propagation.
819  } else if (!node) {
820  // Create a new node at this position and continue
821  node = std::make_unique<SwitchNode>((*current)->position,
822  (*current)->question);
824  getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
825  list, std::next(current), end);
826 
827  // If the matcher has already been created, and it is for this predicate we
828  // continue propagation to the child.
829  } else if (isSamePredicate(node.get(), *current)) {
831  getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
832  list, std::next(current), end);
833 
834  // If the matcher doesn't match the current predicate, insert a branch as
835  // the common set of matchers has diverged.
836  } else {
837  propagatePattern(node->getFailureNode(), list, current, end);
838  }
839 }
840 
841 /// Fold any switch nodes nested under `node` to boolean nodes when possible.
842 /// `node` is updated in-place if it is a switch.
843 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
844  if (!node)
845  return;
846 
847  if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
848  SwitchNode::ChildMapT &children = switchNode->getChildren();
849  for (auto &it : children)
850  foldSwitchToBool(it.second);
851 
852  // If the node only contains one child, collapse it into a boolean predicate
853  // node.
854  if (children.size() == 1) {
855  auto childIt = children.begin();
856  node = std::make_unique<BoolNode>(
857  node->getPosition(), node->getQuestion(), childIt->first,
858  std::move(childIt->second), std::move(node->getFailureNode()));
859  }
860  } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
861  foldSwitchToBool(boolNode->getSuccessNode());
862  }
863 
864  foldSwitchToBool(node->getFailureNode());
865 }
866 
867 /// Insert an exit node at the end of the failure path of the `root`.
868 static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
869  while (*root)
870  root = &(*root)->getFailureNode();
871  *root = std::make_unique<ExitNode>();
872 }
873 
874 /// Given a module containing PDL pattern operations, generate a matcher tree
875 /// using the patterns within the given module and return the root matcher node.
///
/// Steps:
///  1. For every pdl.pattern in `module`, build its predicate list and root.
///  2. Unique the predicates. The first insertion of each unique predicate
///     records its 0-based insertion order in `id`, and each pattern's expected
///     answer is recorded in `patternToAnswer`.
///  3. Compute ranking sums. `primary` is incremented once per occurrence in
///     each pattern's predicate list. For each pattern, the sum of the squared
///     `primary` values of its predicates is then added to every one of those
///     predicates' `secondary`.
///  4. Sort all unique predicates with OrderedPredicate's `operator<`.
///  5. Push each pattern through that order with propagatePattern.
///  6. Fold single-answer switches into BoolNodes and terminate the top-level
///     failure chain with an ExitNode.
///
/// NOTE(review): this listing appears to omit the line carrying the function's
/// qualified name and its leading parameters (`module` and `builder` are used
/// below), as well as the declaration of `uniqued`. Confirm against the source.
876 std::unique_ptr<MatcherNode>
878  DenseMap<Value, Position *> &valueToPosition) {
879  // The set of predicates contained within the pattern operations of the
880  // module.
881  struct PatternPredicates {
882  PatternPredicates(pdl::PatternOp pattern, Value root,
883  std::vector<PositionalPredicate> predicates)
884  : pattern(pattern), root(root), predicates(std::move(predicates)) {}
885 
886  /// A pattern.
887  pdl::PatternOp pattern;
888 
889  /// A root of the pattern chosen among the candidate roots in pdl.rewrite.
890  Value root;
891 
892  /// The extracted predicates for this pattern and root.
893  std::vector<PositionalPredicate> predicates;
894  };
895 
896  SmallVector<PatternPredicates, 16> patternsAndPredicates;
897  for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
898  std::vector<PositionalPredicate> predicateList;
899  Value root =
900  buildPredicateList(pattern, builder, predicateList, valueToPosition);
901  patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
902  }
903 
904  // Associate a pattern result with each unique predicate.
// NOTE(review): `uniqued` is presumably a set of OrderedPredicate keyed on
// (position, question); its declaration is not visible in this listing.
906  for (auto &patternAndPredList : patternsAndPredicates) {
907  for (auto &predicate : patternAndPredList.predicates) {
908  auto it = uniqued.insert(predicate);
909  it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
910  predicate.answer);
911  // Mark the insertion order (0-based indexing).
912  if (it.second)
913  it.first->id = uniqued.size() - 1;
914  }
915  }
916 
917  // Associate each pattern to a set of its ordered predicates for later lookup.
918  std::vector<OrderedPredicateList> lists;
919  lists.reserve(patternsAndPredicates.size());
920  for (auto &patternAndPredList : patternsAndPredicates) {
921  OrderedPredicateList list(patternAndPredList.pattern,
922  patternAndPredList.root);
923  for (auto &predicate : patternAndPredList.predicates) {
// Every predicate was inserted into `uniqued` by the loop above, so this
// lookup always succeeds. The stored pointer stays valid because nothing is
// inserted into `uniqued` after that loop.
924  OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
925  list.predicates.insert(orderedPredicate);
926 
927  // Increment the primary sum for each reference to a particular predicate.
928  ++orderedPredicate->primary;
929  }
930  lists.push_back(std::move(list));
931  }
932 
933  // For a particular pattern, get the total primary sum and add it to the
934  // secondary sum of each predicate. Square the primary sums to emphasize
935  // shared predicates within rather than across patterns.
936  for (auto &list : lists) {
937  unsigned total = 0;
938  for (auto *predicate : list.predicates)
939  total += predicate->primary * predicate->primary;
940  for (auto *predicate : list.predicates)
941  predicate->secondary += total;
942  }
943 
944  // Sort the set of predicates now that the cost primary and secondary sums
945  // have been computed.
946  std::vector<OrderedPredicate *> ordered;
947  ordered.reserve(uniqued.size());
948  for (auto &ip : uniqued)
949  ordered.push_back(&ip);
950  llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
951  return *lhs < *rhs;
952  });
953 
954  // Build the matchers for each of the pattern predicate lists.
// Patterns are pushed in module order. Each one reuses whatever node prefix it
// shares with the patterns pushed before it.
955  std::unique_ptr<MatcherNode> root;
956  for (OrderedPredicateList &list : lists)
957  propagatePattern(root, list, ordered.begin(), ordered.end());
958 
959  // Collapse the graph and insert the exit node.
960  foldSwitchToBool(root);
961  insertExitNode(&root);
962  return root;
963 }
964 
965 //===----------------------------------------------------------------------===//
966 // MatcherNode
967 //===----------------------------------------------------------------------===//
968 
970  std::unique_ptr<MatcherNode> failureNode)
971  : position(p), question(q), failureNode(std::move(failureNode)),
972  matcherTypeID(matcherTypeID) {}
973 
974 //===----------------------------------------------------------------------===//
975 // BoolNode
976 //===----------------------------------------------------------------------===//
977 
978 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
979  std::unique_ptr<MatcherNode> successNode,
980  std::unique_ptr<MatcherNode> failureNode)
981  : MatcherNode(TypeID::get<BoolNode>(), position, question,
982  std::move(failureNode)),
983  answer(answer), successNode(std::move(successNode)) {}
984 
985 //===----------------------------------------------------------------------===//
986 // SuccessNode
987 //===----------------------------------------------------------------------===//
988 
989 SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
990  std::unique_ptr<MatcherNode> failureNode)
991  : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
992  /*question=*/nullptr, std::move(failureNode)),
993  pattern(pattern), root(root) {}
994 
995 //===----------------------------------------------------------------------===//
996 // SwitchNode
997 //===----------------------------------------------------------------------===//
998 
1000  : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
Position * getForEach(Position *p, unsigned id)
Definition: Predicate.h:588
Include the generated interface declarations.
static Value buildPredicateList(pdl::PatternOp pattern, PredicateBuilder &builder, std::vector< PositionalPredicate > &predList, DenseMap< Value, Position *> &valueToPosition)
Given a pattern operation, build the set of matcher predicates necessary to match this pattern...
bool operator<(Fraction x, Fraction y)
Definition: Fraction.h:69
Position * getResultGroup(OperationPosition *p, Optional< unsigned > group, bool isVariadic)
Returns a position for a group of results of the given operation.
Definition: Predicate.h:612
Qualifier * getQuestion() const
Returns the predicate checked on this node.
Definition: PredicateTree.h:66
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
A position describing the result type of an entity, i.e.
Definition: Predicate.h:326
A SwitchNode denotes a question with multiple potential results.
static SmallVector< Value > detectRoots(pdl::PatternOp pattern)
Given a pattern, determines the set of roots present in this pattern.
Predicate getTypeConstraint(Attribute type)
Create a predicate comparing the type of an attribute or value to a known type.
Definition: Predicate.h:704
static void insertExitNode(std::unique_ptr< MatcherNode > *root)
Insert an exit node at the end of the failure path of the root.
Predicate getEqualTo(Position *pos)
Create a predicate checking if two values are equal.
Definition: Predicate.h:653
unsigned getOperationDepth() const
Returns the depth of the first ancestor operation position.
Definition: Predicate.cpp:21
Position * getAllResults(OperationPosition *p)
Definition: Predicate.h:616
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition: Predicate.h:383
Position * position
The position the predicate is applied to.
Definition: PredicateTree.h:36
Predicate getResultCount(unsigned count)
Create a predicate comparing the number of results of an operation to a known value.
Definition: Predicate.h:692
static void visitUpward(std::vector< PositionalPredicate > &predList, OpIndex opIndex, PredicateBuilder &builder, DenseMap< Value, Position *> &valueToPosition, Position *&pos, unsigned rootID)
Visit a node during upward traversal.
This class represents the base of a predicate matcher node.
Definition: PredicateTree.h:50
static unsigned getNumNonRangeValues(ValueRange values)
Returns the number of non-range elements within values.
llvm::MapVector< Qualifier *, std::unique_ptr< MatcherNode > > ChildMapT
Returns the children of this switch node.
OperationPosition * getPassthroughOp(Position *p)
Returns the operation position equivalent to the given position.
Definition: Predicate.h:573
Position * getOperand(OperationPosition *p, unsigned operand)
Returns an operand position for an operand of the given operation.
Definition: Predicate.h:593
SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr< MatcherNode > failureNode)
std::vector< std::pair< Value, Value > > EdgeList
A list of edges (child, parent).
Definition: RootOrdering.h:93
static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate)
Returns true if the given matcher refers to the same predicate as the given ordered predicate...
static bool useOperandGroup(pdl::OperationOp op, unsigned index)
Returns true if the operand at the given index needs to be queried using an operand group...
A position describing an attribute of an operation.
Definition: Predicate.h:170
static void getResultPredicates(pdl::ResultOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position *> &inputs)
Position * getAllOperands(OperationPosition *p)
Definition: Predicate.h:602
The information associated with an edge in the cost graph.
Definition: RootOrdering.h:59
static constexpr const bool value
Predicate getOperandCount(unsigned count)
Create a predicate comparing the number of operands of an operation to a known value.
Definition: Predicate.h:675
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
A position describes a value on the input IR on which a predicate may be applied, such as an operatio...
Definition: Predicate.h:143
static void getOperandTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position *> &inputs, Position *pos)
Collect all of the predicates for the given operand position.
Position * getAttribute(OperationPosition *p, StringRef name)
Returns an attribute position for an attribute of the given operation.
Definition: Predicate.h:579
Position * getRoot()
Returns the root operation position.
Definition: Predicate.h:563
A PositionalPredicate is a predicate that is associated with a specific positional value...
Definition: PredicateTree.h:30
An operation position describes an operation node in the IR.
Definition: Predicate.h:248
OperationPosition * getOperandDefiningOp(Position *p)
Returns the parent position defining the value held by the given operand.
Definition: Predicate.h:566
static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position *> &inputs)
std::pair< unsigned, unsigned > cost
The depth of the connector Value w.r.t.
Definition: RootOrdering.h:65
Position * getAttributeLiteral(Attribute attr)
Returns an attribute position for the given attribute.
Definition: Predicate.h:584
EdgeList preOrderTraversal(ArrayRef< Value > nodes) const
Returns the computed edges as visited in the preorder traversal.
Value connector
The connector value in the intersection of the two subtrees rooted at the source and target root that...
Definition: RootOrdering.h:70
Predicate getOperationName(StringRef name)
Create a predicate comparing the name of an operation to a known value.
Definition: Predicate.h:685
std::unique_ptr< MatcherNode > & getOrCreateChild(SwitchNode *node, OrderedPredicate *predicate, pdl::PatternOp pattern)
Get or insert a child matcher for the given parent switch node, given a predicate and parent pattern...
Attributes are known-constant values of operations.
Definition: Attributes.h:24
SwitchNode(Position *position, Qualifier *question)
type_range getTypes() const
Definition: ValueRange.cpp:44
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
Predicate getOperandCountAtLeast(unsigned count)
Definition: Predicate.h:679
Predicate getIsNotNull()
Create a predicate comparing a value with null.
Definition: Predicate.h:669
static void buildCostGraph(ArrayRef< Value > roots, RootOrderingGraph &graph, ParentMaps &parentMaps)
Given a list of candidate roots, builds the cost graph for connecting them.
MatcherNode(TypeID matcherTypeID, Position *position=nullptr, Qualifier *question=nullptr, std::unique_ptr< MatcherNode > failureNode=nullptr)
A BoolNode denotes a question with a boolean-like result.
static bool comparePosDepth(Position *lhs, Position *rhs)
Compares the depths of two positions.
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Predicate getAttributeConstraint(Attribute attr)
Create a predicate comparing an attribute to a known value.
Definition: Predicate.h:647
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
Predicates::Kind getKind() const
Returns the kind of this position.
Definition: Predicate.h:155
Position * getTypeLiteral(Attribute attr)
Returns a type position for the given type value.
Definition: Predicate.h:625
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Position * getPosition() const
Returns the position on which the question predicate should be checked.
Definition: PredicateTree.h:63
The optimal branching algorithm solver.
Definition: RootOrdering.h:90
Qualifier * question
The question that the predicate applies.
Definition: PredicateTree.h:39
Position * getOperandGroup(OperationPosition *p, Optional< unsigned > group, bool isVariadic)
Returns a position for a group of operands of the given operation.
Definition: Predicate.h:598
Type getType() const
Return the type of this value.
Definition: Value.h:118
static void getTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position *> &inputs, Position *pos)
Collect the tree predicates anchored at the given value.
UsersPosition * getUsers(Position *p, bool useRepresentative)
Returns the users of a position using the value at the given operand.
Definition: Predicate.h:630
static void getNonTreePredicates(pdl::PatternOp pattern, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position *> &inputs)
Collect all of the predicates that cannot be determined via walking the tree.
type_range getType() const
Definition: ValueRange.cpp:30
static void getAttributePredicates(pdl::AttributeOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position *> &inputs)
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static void propagatePattern(std::unique_ptr< MatcherNode > &node, OrderedPredicateList &list, std::vector< OrderedPredicate *>::iterator current, std::vector< OrderedPredicate *>::iterator end)
Build the matcher CFG by "pushing" patterns through by sorted predicate order.
Predicate getConstraint(StringRef name, ArrayRef< Position *> pos)
Create a predicate that applies a generic constraint.
Definition: Predicate.h:663
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:40
std::pair< Qualifier *, Qualifier * > Predicate
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition: Predicate.h:644
A SuccessNode denotes that a given high level pattern has successfully been matched.
static void foldSwitchToBool(std::unique_ptr< MatcherNode > &node)
Fold any switch nodes nested under node to boolean nodes when possible.
bool isa() const
Definition: Types.h:254
static void getTypePredicates(Value typeValue, function_ref< Attribute()> typeAttrFn, PredicateBuilder &builder, DenseMap< Value, Position *> &inputs)
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
unsigned solve()
Runs the Edmonds' algorithm for the current graph, returning the total cost of the minimum-weight spa...
Position * getType(Position *p)
Returns a type position for the given entity.
Definition: Predicate.h:621
Predicate getResultCountAtLeast(unsigned count)
Definition: Predicate.h:696
This class provides utilities for constructing predicates.
Definition: Predicate.h:553
Predicates::Kind getKind() const
Returns the kind of this qualifier.
Definition: Predicate.h:388
static std::unique_ptr< MatcherNode > generateMatcherTree(ModuleOp module, PredicateBuilder &builder, DenseMap< Value, Position *> &valueToPosition)
Given a module containing PDL pattern operations, generate a matcher tree using the patterns within t...
BoolNode(Position *position, Qualifier *question, Qualifier *answer, std::unique_ptr< MatcherNode > successNode, std::unique_ptr< MatcherNode > failureNode=nullptr)
Position * getResult(OperationPosition *p, unsigned result)
Returns a result position for a result of the given operation.
Definition: Predicate.h:607