MLIR  22.0.0git
PredicateTree.cpp
Go to the documentation of this file.
1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PredicateTree.h"
10 #include "RootOrdering.h"
11 
13 #include "mlir/IR/BuiltinOps.h"
14 #include "llvm/ADT/MapVector.h"
15 #include "llvm/ADT/SmallPtrSet.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 #include "llvm/Support/Debug.h"
18 #include "llvm/Support/DebugLog.h"
19 #include <queue>
20 
21 #define DEBUG_TYPE "pdl-predicate-tree"
22 
23 using namespace mlir;
24 using namespace mlir::pdl_to_pdl_interp;
25 
26 //===----------------------------------------------------------------------===//
27 // Predicate List Building
28 //===----------------------------------------------------------------------===//
29 
30 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
31  Value val, PredicateBuilder &builder,
33  Position *pos);
34 
35 /// Compares the depths of two positions.
36 static bool comparePosDepth(Position *lhs, Position *rhs) {
37  return lhs->getOperationDepth() < rhs->getOperationDepth();
38 }
39 
40 /// Returns the number of non-range elements within `values`.
41 static unsigned getNumNonRangeValues(ValueRange values) {
42  return llvm::count_if(values.getTypes(),
43  [](Type type) { return !isa<pdl::RangeType>(type); });
44 }
45 
46 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
47  Value val, PredicateBuilder &builder,
49  AttributePosition *pos) {
50  assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
51  predList.emplace_back(pos, builder.getIsNotNull());
52 
53  if (auto attr = val.getDefiningOp<pdl::AttributeOp>()) {
54  // If the attribute has a type or value, add a constraint.
55  if (Value type = attr.getValueType())
56  getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
57  else if (Attribute value = attr.getValueAttr())
58  predList.emplace_back(pos, builder.getAttributeConstraint(value));
59  }
60 }
61 
62 /// Collect all of the predicates for the given operand position.
63 static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
64  Value val, PredicateBuilder &builder,
66  Position *pos) {
67  Type valueType = val.getType();
68  bool isVariadic = isa<pdl::RangeType>(valueType);
69 
70  // If this is a typed operand, add a type constraint.
72  .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
73  // Prevent traversal into a null value if the operand has a proper
74  // index.
75  if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
76  cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
77  predList.emplace_back(pos, builder.getIsNotNull());
78 
79  if (Value type = op.getValueType())
80  getTreePredicates(predList, type, builder, inputs,
81  builder.getType(pos));
82  })
83  .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
84  std::optional<unsigned> index = op.getIndex();
85 
86  // Prevent traversal into a null value if the result has a proper index.
87  if (index)
88  predList.emplace_back(pos, builder.getIsNotNull());
89 
90  // Get the parent operation of this operand.
91  OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
92  predList.emplace_back(parentPos, builder.getIsNotNull());
93 
94  // Ensure that the operands match the corresponding results of the
95  // parent operation.
96  Position *resultPos = nullptr;
97  if (std::is_same<pdl::ResultOp, decltype(op)>::value)
98  resultPos = builder.getResult(parentPos, *index);
99  else
100  resultPos = builder.getResultGroup(parentPos, index, isVariadic);
101  predList.emplace_back(resultPos, builder.getEqualTo(pos));
102 
103  // Collect the predicates of the parent operation.
104  getTreePredicates(predList, op.getParent(), builder, inputs,
105  (Position *)parentPos);
106  });
107 }
108 
109 static void
110 getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
111  PredicateBuilder &builder,
113  std::optional<unsigned> ignoreOperand = std::nullopt) {
114  assert(isa<pdl::OperationType>(val.getType()) && "expected operation");
115  pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
116  OperationPosition *opPos = cast<OperationPosition>(pos);
117 
118  // Ensure getDefiningOp returns a non-null operation.
119  if (!opPos->isRoot())
120  predList.emplace_back(pos, builder.getIsNotNull());
121 
122  // Check that this is the correct root operation.
123  if (std::optional<StringRef> opName = op.getOpName())
124  predList.emplace_back(pos, builder.getOperationName(*opName));
125 
126  // Check that the operation has the proper number of operands. If there are
127  // any variable length operands, we check a minimum instead of an exact count.
128  OperandRange operands = op.getOperandValues();
129  unsigned minOperands = getNumNonRangeValues(operands);
130  if (minOperands != operands.size()) {
131  if (minOperands)
132  predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
133  } else {
134  predList.emplace_back(pos, builder.getOperandCount(minOperands));
135  }
136 
137  // Check that the operation has the proper number of results. If there are
138  // any variable length results, we check a minimum instead of an exact count.
139  OperandRange types = op.getTypeValues();
140  unsigned minResults = getNumNonRangeValues(types);
141  if (minResults == types.size())
142  predList.emplace_back(pos, builder.getResultCount(types.size()));
143  else if (minResults)
144  predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
145 
146  // Recurse into any attributes, operands, or results.
147  for (auto [attrName, attr] :
148  llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
150  predList, attr, builder, inputs,
151  builder.getAttribute(opPos, cast<StringAttr>(attrName).getValue()));
152  }
153 
154  // Process the operands and results of the operation. For all values up to
155  // the first variable length value, we use the concrete operand/result
156  // number. After that, we use the "group" given that we can't know the
157  // concrete indices until runtime. If there is only one variadic operand
158  // group, we treat it as all of the operands/results of the operation.
159  /// Operands.
160  if (operands.size() == 1 && isa<pdl::RangeType>(operands[0].getType())) {
161  // Ignore the operands if we are performing an upward traversal (in that
162  // case, they have already been visited).
163  if (opPos->isRoot() || opPos->isOperandDefiningOp())
164  getTreePredicates(predList, operands.front(), builder, inputs,
165  builder.getAllOperands(opPos));
166  } else {
167  bool foundVariableLength = false;
168  for (const auto &operandIt : llvm::enumerate(operands)) {
169  bool isVariadic = isa<pdl::RangeType>(operandIt.value().getType());
170  foundVariableLength |= isVariadic;
171 
172  // Ignore the specified operand, usually because this position was
173  // visited in an upward traversal via an iterative choice.
174  if (ignoreOperand == operandIt.index())
175  continue;
176 
177  Position *pos =
178  foundVariableLength
179  ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
180  : builder.getOperand(opPos, operandIt.index());
181  getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
182  }
183  }
184  /// Results.
185  if (types.size() == 1 && isa<pdl::RangeType>(types[0].getType())) {
186  getTreePredicates(predList, types.front(), builder, inputs,
187  builder.getType(builder.getAllResults(opPos)));
188  return;
189  }
190 
191  bool foundVariableLength = false;
192  for (auto [idx, typeValue] : llvm::enumerate(types)) {
193  bool isVariadic = isa<pdl::RangeType>(typeValue.getType());
194  foundVariableLength |= isVariadic;
195 
196  auto *resultPos = foundVariableLength
197  ? builder.getResultGroup(pos, idx, isVariadic)
198  : builder.getResult(pos, idx);
199  predList.emplace_back(resultPos, builder.getIsNotNull());
200  getTreePredicates(predList, typeValue, builder, inputs,
201  builder.getType(resultPos));
202  }
203 }
204 
205 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
206  Value val, PredicateBuilder &builder,
208  TypePosition *pos) {
209  // Check for a constraint on a constant type.
210  if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
211  if (Attribute type = typeOp.getConstantTypeAttr())
212  predList.emplace_back(pos, builder.getTypeConstraint(type));
213  } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
214  if (Attribute typeAttr = typeOp.getConstantTypesAttr())
215  predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
216  }
217 }
218 
219 /// Collect the tree predicates anchored at the given value.
220 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
221  Value val, PredicateBuilder &builder,
223  Position *pos) {
224  // Make sure this input value is accessible to the rewrite.
225  auto it = inputs.try_emplace(val, pos);
226  if (!it.second) {
227  // If this is an input value that has been visited in the tree, add a
228  // constraint to ensure that both instances refer to the same value.
229  if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
230  pdl::TypeOp>(val.getDefiningOp())) {
231  auto minMaxPositions =
232  std::minmax(pos, it.first->second, comparePosDepth);
233  predList.emplace_back(minMaxPositions.second,
234  builder.getEqualTo(minMaxPositions.first));
235  }
236  return;
237  }
238 
240  .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
241  getTreePredicates(predList, val, builder, inputs, pos);
242  })
243  .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
244  getOperandTreePredicates(predList, val, builder, inputs, pos);
245  })
246  .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
247 }
248 
249 static void getAttributePredicates(pdl::AttributeOp op,
250  std::vector<PositionalPredicate> &predList,
251  PredicateBuilder &builder,
252  DenseMap<Value, Position *> &inputs) {
253  Position *&attrPos = inputs[op];
254  if (attrPos)
255  return;
256  Attribute value = op.getValueAttr();
257  assert(value && "expected non-tree `pdl.attribute` to contain a value");
258  attrPos = builder.getAttributeLiteral(value);
259 }
260 
261 static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
262  std::vector<PositionalPredicate> &predList,
263  PredicateBuilder &builder,
264  DenseMap<Value, Position *> &inputs) {
265  OperandRange arguments = op.getArgs();
266 
267  std::vector<Position *> allPositions;
268  allPositions.reserve(arguments.size());
269  for (Value arg : arguments)
270  allPositions.push_back(inputs.lookup(arg));
271 
272  // Push the constraint to the furthest position.
273  Position *pos = *llvm::max_element(allPositions, comparePosDepth);
274  ResultRange results = op.getResults();
276  op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
277  op.getIsNegated());
278 
279  // For each result register a position so it can be used later
280  for (auto [i, result] : llvm::enumerate(results)) {
281  ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
282  ConstraintPosition *pos = builder.getConstraintPosition(q, i);
283  auto [it, inserted] = inputs.try_emplace(result, pos);
284  // If this is an input value that has been visited in the tree, add a
285  // constraint to ensure that both instances refer to the same value.
286  if (!inserted) {
287  Position *first = pos;
288  Position *second = it->second;
289  if (comparePosDepth(second, first))
290  std::tie(second, first) = std::make_pair(first, second);
291 
292  predList.emplace_back(second, builder.getEqualTo(first));
293  }
294  }
295  predList.emplace_back(pos, pred);
296 }
297 
298 static void getResultPredicates(pdl::ResultOp op,
299  std::vector<PositionalPredicate> &predList,
300  PredicateBuilder &builder,
301  DenseMap<Value, Position *> &inputs) {
302  Position *&resultPos = inputs[op];
303  if (resultPos)
304  return;
305 
306  // Ensure that the result isn't null.
307  auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
308  resultPos = builder.getResult(parentPos, op.getIndex());
309  predList.emplace_back(resultPos, builder.getIsNotNull());
310 }
311 
312 static void getResultPredicates(pdl::ResultsOp op,
313  std::vector<PositionalPredicate> &predList,
314  PredicateBuilder &builder,
315  DenseMap<Value, Position *> &inputs) {
316  Position *&resultPos = inputs[op];
317  if (resultPos)
318  return;
319 
320  // Ensure that the result isn't null if the result has an index.
321  auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
322  bool isVariadic = isa<pdl::RangeType>(op.getType());
323  std::optional<unsigned> index = op.getIndex();
324  resultPos = builder.getResultGroup(parentPos, index, isVariadic);
325  if (index)
326  predList.emplace_back(resultPos, builder.getIsNotNull());
327 }
328 
329 static void getTypePredicates(Value typeValue,
330  function_ref<Attribute()> typeAttrFn,
331  PredicateBuilder &builder,
332  DenseMap<Value, Position *> &inputs) {
333  Position *&typePos = inputs[typeValue];
334  if (typePos)
335  return;
336  Attribute typeAttr = typeAttrFn();
337  assert(typeAttr &&
338  "expected non-tree `pdl.type`/`pdl.types` to contain a value");
339  typePos = builder.getTypeLiteral(typeAttr);
340 }
341 
342 /// Collect all of the predicates that cannot be determined via walking the
343 /// tree.
344 static void getNonTreePredicates(pdl::PatternOp pattern,
345  std::vector<PositionalPredicate> &predList,
346  PredicateBuilder &builder,
347  DenseMap<Value, Position *> &inputs) {
348  for (Operation &op : pattern.getBodyRegion().getOps()) {
350  .Case([&](pdl::AttributeOp attrOp) {
351  getAttributePredicates(attrOp, predList, builder, inputs);
352  })
353  .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
354  getConstraintPredicates(constraintOp, predList, builder, inputs);
355  })
356  .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
357  getResultPredicates(resultOp, predList, builder, inputs);
358  })
359  .Case([&](pdl::TypeOp typeOp) {
361  typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder,
362  inputs);
363  })
364  .Case([&](pdl::TypesOp typeOp) {
366  typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder,
367  inputs);
368  });
369  }
370 }
371 
372 namespace {
373 
374 /// An op accepting a value at an optional index.
375 struct OpIndex {
376  Value parent;
377  std::optional<unsigned> index;
378 };
379 
380 /// The parent and operand index of each operation for each root, stored
381 /// as a nested map [root][operation].
382 using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>;
383 
384 } // namespace
385 
386 /// Given a pattern, determines the set of roots present in this pattern.
387 /// These are the operations whose results are not consumed by other operations.
388 static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
389  // First, collect all the operations that are used as operands
390  // to other operations. These are not roots by default.
391  DenseSet<Value> used;
392  for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
393  for (Value operand : operationOp.getOperandValues())
394  TypeSwitch<Operation *>(operand.getDefiningOp())
395  .Case<pdl::ResultOp, pdl::ResultsOp>(
396  [&used](auto resultOp) { used.insert(resultOp.getParent()); });
397  }
398 
399  // Remove the specified root from the use set, so that we can
400  // always select it as a root, even if it is used by other operations.
401  if (Value root = pattern.getRewriter().getRoot())
402  used.erase(root);
403 
404  // Finally, collect all the unused operations.
405  SmallVector<Value> roots;
406  for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>())
407  if (!used.contains(operationOp))
408  roots.push_back(operationOp);
409 
410  return roots;
411 }
412 
413 /// Given a list of candidate roots, builds the cost graph for connecting them.
414 /// The graph is formed by traversing the DAG of operations starting from each
415 /// root and marking the depth of each connector value (operand). Then we join
416 /// the candidate roots based on the common connector values, taking the one
417 /// with the minimum depth. Along the way, we compute, for each candidate root,
418 /// a mapping from each operation (in the DAG underneath this root) to its
419 /// parent operation and the corresponding operand index.
421  ParentMaps &parentMaps) {
422 
423  // The entry of a queue. The entry consists of the following items:
424  // * the value in the DAG underneath the root;
425  // * the parent of the value;
426  // * the operand index of the value in its parent;
427  // * the depth of the visited value.
428  struct Entry {
429  Entry(Value value, Value parent, std::optional<unsigned> index,
430  unsigned depth)
431  : value(value), parent(parent), index(index), depth(depth) {}
432 
433  Value value;
434  Value parent;
435  std::optional<unsigned> index;
436  unsigned depth;
437  };
438 
439  // A root of a value and its depth (distance from root to the value).
440  struct RootDepth {
441  Value root;
442  unsigned depth = 0;
443  };
444 
445  // Map from candidate connector values to their roots and depths. Using a
446  // small vector with 1 entry because most values belong to a single root.
447  llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
448 
449  // Perform a breadth-first traversal of the op DAG rooted at each root.
450  for (Value root : roots) {
451  // The queue of visited values. A value may be present multiple times in
452  // the queue, for multiple parents. We only accept the first occurrence,
453  // which is guaranteed to have the lowest depth.
454  std::queue<Entry> toVisit;
455  toVisit.emplace(root, Value(), 0, 0);
456 
457  // The map from value to its parent for the current root.
458  DenseMap<Value, OpIndex> &parentMap = parentMaps[root];
459 
460  while (!toVisit.empty()) {
461  Entry entry = toVisit.front();
462  toVisit.pop();
463  // Skip if already visited.
464  if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
465  continue;
466 
467  // Mark the root and depth of the value.
468  connectorsRootsDepths[entry.value].push_back({root, entry.depth});
469 
470  // Traverse the operands of an operation and result ops.
471  // We intentionally do not traverse attributes and types, because those
472  // are expensive to join on.
473  TypeSwitch<Operation *>(entry.value.getDefiningOp())
474  .Case<pdl::OperationOp>([&](auto operationOp) {
475  OperandRange operands = operationOp.getOperandValues();
476  // Special case when we pass all the operands in one range.
477  // For those, the index is empty.
478  if (operands.size() == 1 &&
479  isa<pdl::RangeType>(operands[0].getType())) {
480  toVisit.emplace(operands[0], entry.value, std::nullopt,
481  entry.depth + 1);
482  return;
483  }
484 
485  // Default case: visit all the operands.
486  for (const auto &p :
487  llvm::enumerate(operationOp.getOperandValues()))
488  toVisit.emplace(p.value(), entry.value, p.index(),
489  entry.depth + 1);
490  })
491  .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
492  toVisit.emplace(resultOp.getParent(), entry.value,
493  resultOp.getIndex(), entry.depth);
494  });
495  }
496  }
497 
498  // Now build the cost graph.
499  // This is simply a minimum over all depths for the target root.
500  unsigned nextID = 0;
501  for (const auto &connectorRootsDepths : connectorsRootsDepths) {
502  Value value = connectorRootsDepths.first;
503  ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second;
504  // If there is only one root for this value, this will not trigger
505  // any edges in the cost graph (a perf optimization).
506  if (rootsDepths.size() == 1)
507  continue;
508 
509  for (const RootDepth &p : rootsDepths) {
510  for (const RootDepth &q : rootsDepths) {
511  if (&p == &q)
512  continue;
513  // Insert or retrieve the property of edge from p to q.
514  RootOrderingEntry &entry = graph[q.root][p.root];
515  if (!entry.connector /* new edge */ || entry.cost.first > q.depth) {
516  if (!entry.connector)
517  entry.cost.second = nextID++;
518  entry.cost.first = q.depth;
519  entry.connector = value;
520  }
521  }
522  }
523  }
524 
525  assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
526  "the pattern contains a candidate root disconnected from the others");
527 }
528 
529 /// Returns true if the operand at the given index needs to be queried using an
530 /// operand group, i.e., if it is variadic itself or follows a variadic operand.
531 static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
532  OperandRange operands = op.getOperandValues();
533  assert(index < operands.size() && "operand index out of range");
534  for (unsigned i = 0; i <= index; ++i)
535  if (isa<pdl::RangeType>(operands[i].getType()))
536  return true;
537  return false;
538 }
539 
540 /// Visit a node during upward traversal.
541 static void visitUpward(std::vector<PositionalPredicate> &predList,
542  OpIndex opIndex, PredicateBuilder &builder,
543  DenseMap<Value, Position *> &valueToPosition,
544  Position *&pos, unsigned rootID) {
545  Value value = opIndex.parent;
547  .Case<pdl::OperationOp>([&](auto operationOp) {
548  LDBG() << " * Value: " << value;
549 
550  // Get users and iterate over them.
551  Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true);
552  Position *foreachPos = builder.getForEach(usersPos, rootID);
553  OperationPosition *opPos = builder.getPassthroughOp(foreachPos);
554 
555  // Compare the operand(s) of the user against the input value(s).
556  Position *operandPos;
557  if (!opIndex.index) {
558  // We are querying all the operands of the operation.
559  operandPos = builder.getAllOperands(opPos);
560  } else if (useOperandGroup(operationOp, *opIndex.index)) {
561  // We are querying an operand group.
562  Type type = operationOp.getOperandValues()[*opIndex.index].getType();
563  bool variadic = isa<pdl::RangeType>(type);
564  operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
565  } else {
566  // We are querying an individual operand.
567  operandPos = builder.getOperand(opPos, *opIndex.index);
568  }
569  predList.emplace_back(operandPos, builder.getEqualTo(pos));
570 
571  // Guard against duplicate upward visits. These are not possible,
572  // because if this value was already visited, it would have been
573  // cheaper to start the traversal at this value rather than at the
574  // `connector`, violating the optimality of our spanning tree.
575  bool inserted = valueToPosition.try_emplace(value, opPos).second;
576  (void)inserted;
577  assert(inserted && "duplicate upward visit");
578 
579  // Obtain the tree predicates at the current value.
580  getTreePredicates(predList, value, builder, valueToPosition, opPos,
581  opIndex.index);
582 
583  // Update the position
584  pos = opPos;
585  })
586  .Case<pdl::ResultOp>([&](auto resultOp) {
587  // Traverse up an individual result.
588  auto *opPos = dyn_cast<OperationPosition>(pos);
589  assert(opPos && "operations and results must be interleaved");
590  pos = builder.getResult(opPos, *opIndex.index);
591 
592  // Insert the result position in case we have not visited it yet.
593  valueToPosition.try_emplace(value, pos);
594  })
595  .Case<pdl::ResultsOp>([&](auto resultOp) {
596  // Traverse up a group of results.
597  auto *opPos = dyn_cast<OperationPosition>(pos);
598  assert(opPos && "operations and results must be interleaved");
599  bool isVariadic = isa<pdl::RangeType>(value.getType());
600  if (opIndex.index)
601  pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
602  else
603  pos = builder.getAllResults(opPos);
604 
605  // Insert the result position in case we have not visited it yet.
606  valueToPosition.try_emplace(value, pos);
607  });
608 }
609 
610 /// Given a pattern operation, build the set of matcher predicates necessary to
611 /// match this pattern.
612 static Value buildPredicateList(pdl::PatternOp pattern,
613  PredicateBuilder &builder,
614  std::vector<PositionalPredicate> &predList,
615  DenseMap<Value, Position *> &valueToPosition) {
616  SmallVector<Value> roots = detectRoots(pattern);
617 
618  // Build the root ordering graph and compute the parent maps.
619  RootOrderingGraph graph;
620  ParentMaps parentMaps;
621  buildCostGraph(roots, graph, parentMaps);
622  LDBG() << "Graph:";
623  for (auto &target : graph) {
624  LDBG() << " * " << target.first.getLoc() << " " << target.first;
625  for (auto &source : target.second) {
626  RootOrderingEntry &entry = source.second;
627  LDBG() << " <- " << source.first << ": " << entry.cost.first << ":"
628  << entry.cost.second << " via " << entry.connector.getLoc();
629  }
630  }
631 
632  // Solve the optimal branching problem for each candidate root, or use the
633  // provided one.
634  Value bestRoot = pattern.getRewriter().getRoot();
635  OptimalBranching::EdgeList bestEdges;
636  if (!bestRoot) {
637  unsigned bestCost = 0;
638  LDBG() << "Candidate roots:";
639  for (Value root : roots) {
640  OptimalBranching solver(graph, root);
641  unsigned cost = solver.solve();
642  LDBG() << " * " << root << ": " << cost;
643  if (!bestRoot || bestCost > cost) {
644  bestCost = cost;
645  bestRoot = root;
646  bestEdges = solver.preOrderTraversal(roots);
647  }
648  }
649  } else {
650  OptimalBranching solver(graph, bestRoot);
651  solver.solve();
652  bestEdges = solver.preOrderTraversal(roots);
653  }
654 
655  // Print the best solution.
656  LDBG() << "Best tree:";
657  for (const std::pair<Value, Value> &edge : bestEdges) {
658  if (edge.second)
659  LDBG() << " * " << edge.first << " <- " << edge.second;
660  else
661  LDBG() << " * " << edge.first;
662  }
663 
664  LDBG() << "Calling key getTreePredicates (Value: " << bestRoot << ")";
665 
666  // The best root is the starting point for the traversal. Get the tree
667  // predicates for the DAG rooted at bestRoot.
668  getTreePredicates(predList, bestRoot, builder, valueToPosition,
669  builder.getRoot());
670 
671  // Traverse the selected optimal branching. For all edges in order, traverse
672  // up starting from the connector, until the candidate root is reached, and
673  // call getTreePredicates at every node along the way.
674  for (const auto &it : llvm::enumerate(bestEdges)) {
675  Value target = it.value().first;
676  Value source = it.value().second;
677 
678  // Check if we already visited the target root. This happens in two cases:
679  // 1) the initial root (bestRoot);
680  // 2) a root that is dominated by (contained in the subtree rooted at) an
681  // already visited root.
682  if (valueToPosition.count(target))
683  continue;
684 
685  // Determine the connector.
686  Value connector = graph[target][source].connector;
687  assert(connector && "invalid edge");
688  LDBG() << " * Connector: " << connector.getLoc();
689  DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
690  Position *pos = valueToPosition.lookup(connector);
691  assert(pos && "connector has not been traversed yet");
692 
693  // Traverse from the connector upwards towards the target root.
694  for (Value value = connector; value != target;) {
695  OpIndex opIndex = parentMap.lookup(value);
696  assert(opIndex.parent && "missing parent");
697  visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
698  value = opIndex.parent;
699  }
700  }
701 
702  getNonTreePredicates(pattern, predList, builder, valueToPosition);
703 
704  return bestRoot;
705 }
706 
707 //===----------------------------------------------------------------------===//
708 // Pattern Predicate Tree Merging
709 //===----------------------------------------------------------------------===//
710 
711 namespace {
712 
713 /// This class represents a specific predicate applied to a position, and
714 /// provides hashing and ordering operators. This class allows for computing a
715 /// frequence sum and ordering predicates based on a cost model.
716 struct OrderedPredicate {
717  OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
718  : position(ip.first), question(ip.second) {}
719  OrderedPredicate(const PositionalPredicate &ip)
720  : position(ip.position), question(ip.question) {}
721 
722  /// The position this predicate is applied to.
723  Position *position;
724 
725  /// The question that is applied by this predicate onto the position.
726  Qualifier *question;
727 
728  /// The first and second order benefit sums.
729  /// The primary sum is the number of occurrences of this predicate among all
730  /// of the patterns.
731  unsigned primary = 0;
732  /// The secondary sum is a squared summation of the primary sum of all of the
733  /// predicates within each pattern that contains this predicate. This allows
734  /// for favoring predicates that are more commonly shared within a pattern, as
735  /// opposed to those shared across patterns.
736  unsigned secondary = 0;
737 
738  /// The tie breaking ID, used to preserve a deterministic (insertion) order
739  /// among all the predicates with the same priority, depth, and position /
740  /// predicate dependency.
741  unsigned id = 0;
742 
743  /// A map between a pattern operation and the answer to the predicate question
744  /// within that pattern.
745  DenseMap<Operation *, Qualifier *> patternToAnswer;
746 
747  /// Returns true if this predicate is ordered before `rhs`, based on the cost
748  /// model.
749  bool operator<(const OrderedPredicate &rhs) const {
750  // Sort by:
751  // * higher first and secondary order sums
752  // * lower depth
753  // * lower position dependency
754  // * lower predicate dependency
755  // * lower tie breaking ID
756  auto *rhsPos = rhs.position;
757  return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
758  rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
759  std::make_tuple(rhs.primary, rhs.secondary,
760  position->getOperationDepth(), position->getKind(),
761  question->getKind(), id);
762  }
763 };
764 
765 /// A DenseMapInfo for OrderedPredicate based solely on the position and
766 /// question.
767 struct OrderedPredicateDenseInfo {
769 
770  static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
771  static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
772  static bool isEqual(const OrderedPredicate &lhs,
773  const OrderedPredicate &rhs) {
774  return lhs.position == rhs.position && lhs.question == rhs.question;
775  }
776  static unsigned getHashValue(const OrderedPredicate &p) {
777  return llvm::hash_combine(p.position, p.question);
778  }
779 };
780 
781 /// This class wraps a set of ordered predicates that are used within a specific
782 /// pattern operation.
783 struct OrderedPredicateList {
784  OrderedPredicateList(pdl::PatternOp pattern, Value root)
785  : pattern(pattern), root(root) {}
786 
787  pdl::PatternOp pattern;
788  Value root;
789  DenseSet<OrderedPredicate *> predicates;
790 };
791 } // namespace
792 
793 /// Returns true if the given matcher refers to the same predicate as the given
794 /// ordered predicate. This means that the position and questions of the two
795 /// match.
796 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
797  return node->getPosition() == predicate->position &&
798  node->getQuestion() == predicate->question;
799 }
800 
801 /// Get or insert a child matcher for the given parent switch node, given a
802 /// predicate and parent pattern.
803 static std::unique_ptr<MatcherNode> &
804 getOrCreateChild(SwitchNode *node, OrderedPredicate *predicate,
805  pdl::PatternOp pattern) {
806  assert(isSamePredicate(node, predicate) &&
807  "expected matcher to equal the given predicate");
808 
809  auto it = predicate->patternToAnswer.find(pattern);
810  assert(it != predicate->patternToAnswer.end() &&
811  "expected pattern to exist in predicate");
812  return node->getChildren()[it->second];
813 }
814 
815 /// Build the matcher CFG by "pushing" patterns through by sorted predicate
816 /// order. A pattern will traverse as far as possible using common predicates
817 /// and then either diverge from the CFG or reach the end of a branch and start
818 /// creating new nodes.
819 static void propagatePattern(std::unique_ptr<MatcherNode> &node,
820  OrderedPredicateList &list,
821  std::vector<OrderedPredicate *>::iterator current,
822  std::vector<OrderedPredicate *>::iterator end) {
823  if (current == end) {
824  // We've hit the end of a pattern, so create a successful result node.
825  node =
826  std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
827 
828  // If the pattern doesn't contain this predicate, ignore it.
829  } else if (!list.predicates.contains(*current)) {
830  propagatePattern(node, list, std::next(current), end);
831 
832  // If the current matcher node is invalid, create a new one for this
833  // position and continue propagation.
834  } else if (!node) {
835  // Create a new node at this position and continue
836  node = std::make_unique<SwitchNode>((*current)->position,
837  (*current)->question);
839  getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
840  list, std::next(current), end);
841 
842  // If the matcher has already been created, and it is for this predicate we
843  // continue propagation to the child.
844  } else if (isSamePredicate(node.get(), *current)) {
846  getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
847  list, std::next(current), end);
848 
849  // If the matcher doesn't match the current predicate, insert a branch as
850  // the common set of matchers has diverged.
851  } else {
852  propagatePattern(node->getFailureNode(), list, current, end);
853  }
854 }
855 
856 /// Fold any switch nodes nested under `node` to boolean nodes when possible.
857 /// `node` is updated in-place if it is a switch.
858 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
859  if (!node)
860  return;
861 
862  if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
863  SwitchNode::ChildMapT &children = switchNode->getChildren();
864  for (auto &it : children)
865  foldSwitchToBool(it.second);
866 
867  // If the node only contains one child, collapse it into a boolean predicate
868  // node.
869  if (children.size() == 1) {
870  auto *childIt = children.begin();
871  node = std::make_unique<BoolNode>(
872  node->getPosition(), node->getQuestion(), childIt->first,
873  std::move(childIt->second), std::move(node->getFailureNode()));
874  }
875  } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
876  foldSwitchToBool(boolNode->getSuccessNode());
877  }
878 
879  foldSwitchToBool(node->getFailureNode());
880 }
881 
882 /// Insert an exit node at the end of the failure path of the `root`.
883 static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
884  while (*root)
885  root = &(*root)->getFailureNode();
886  *root = std::make_unique<ExitNode>();
887 }
888 
/// Sorts the range [begin, end) with the partial order given by `cmp`, where
/// `cmp(a, b)` returns true if `a` must be placed before `b`.
///
/// The sort proceeds in rounds. Each round moves every element that has no
/// predecessor among the not-yet-placed elements to the front of the remaining
/// range, keeping those elements in their current relative order. The other
/// elements follow, also in their current relative order. Elements are
/// classified by position against the unmodified range, so the value type
/// only needs to be copyable; no pointer or hashing requirement applies.
///
/// If `cmp` is not a partial order (a round finds no element without a
/// predecessor), this asserts in debug builds. In release builds the remaining
/// elements are left in their current order instead of looping forever.
template <typename Iterator, typename Compare>
static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
  using ValueT = typename std::iterator_traits<Iterator>::value_type;
  // Hoisted out of the loop so their capacity is reused across rounds.
  std::vector<ValueT> ready;
  std::vector<ValueT> blocked;
  while (begin != end) {
    ready.clear();
    blocked.clear();

    // Classify every element against the complete, unmodified [begin, end)
    // range before anything is moved.
    for (Iterator it = begin; it != end; ++it) {
      bool hasPredecessor = std::any_of(
          begin, end, [&](const ValueT &other) { return cmp(other, *it); });
      (hasPredecessor ? blocked : ready).push_back(*it);
    }

    assert(!ready.empty() && "not a partial ordering");
    if (ready.empty())
      return;

    // Stable partition: ready elements first, then the blocked ones.
    Iterator next = std::copy(ready.begin(), ready.end(), begin);
    std::copy(blocked.begin(), blocked.end(), next);
    begin = next;
  }
}
909 
910 /// Returns true if 'b' depends on a result of 'a'.
911 static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
912  auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
913  if (!cqa)
914  return false;
915 
916  auto positionDependsOnA = [&](Position *p) {
917  auto *cp = dyn_cast<ConstraintPosition>(p);
918  return cp && cp->getQuestion() == cqa;
919  };
920 
921  if (auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) {
922  // Does any argument of b use a?
923  return llvm::any_of(cqb->getArgs(), positionDependsOnA);
924  }
925  if (auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) {
926  return positionDependsOnA(b->position) ||
927  positionDependsOnA(equalTo->getValue());
928  }
929  return positionDependsOnA(b->position);
930 }
931 
932 /// Given a module containing PDL pattern operations, generate a matcher tree
933 /// using the patterns within the given module and return the root matcher node.
934 std::unique_ptr<MatcherNode>
935 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
936  DenseMap<Value, Position *> &valueToPosition) {
937  // The set of predicates contained within the pattern operations of the
938  // module.
939  struct PatternPredicates {
940  PatternPredicates(pdl::PatternOp pattern, Value root,
941  std::vector<PositionalPredicate> predicates)
942  : pattern(pattern), root(root), predicates(std::move(predicates)) {}
943 
944  /// A pattern.
945  pdl::PatternOp pattern;
946 
947  /// A root of the pattern chosen among the candidate roots in pdl.rewrite.
948  Value root;
949 
950  /// The extracted predicates for this pattern and root.
951  std::vector<PositionalPredicate> predicates;
952  };
953 
954  SmallVector<PatternPredicates, 16> patternsAndPredicates;
955  for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
956  std::vector<PositionalPredicate> predicateList;
957  Value root =
958  buildPredicateList(pattern, builder, predicateList, valueToPosition);
959  patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
960  }
961 
962  // Associate a pattern result with each unique predicate.
964  for (auto &patternAndPredList : patternsAndPredicates) {
965  for (auto &predicate : patternAndPredList.predicates) {
966  auto it = uniqued.insert(predicate);
967  it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
968  predicate.answer);
969  // Mark the insertion order (0-based indexing).
970  if (it.second)
971  it.first->id = uniqued.size() - 1;
972  }
973  }
974 
975  // Associate each pattern to a set of its ordered predicates for later lookup.
976  std::vector<OrderedPredicateList> lists;
977  lists.reserve(patternsAndPredicates.size());
978  for (auto &patternAndPredList : patternsAndPredicates) {
979  OrderedPredicateList list(patternAndPredList.pattern,
980  patternAndPredList.root);
981  for (auto &predicate : patternAndPredList.predicates) {
982  OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
983  list.predicates.insert(orderedPredicate);
984 
985  // Increment the primary sum for each reference to a particular predicate.
986  ++orderedPredicate->primary;
987  }
988  lists.push_back(std::move(list));
989  }
990 
991  // For a particular pattern, get the total primary sum and add it to the
992  // secondary sum of each predicate. Square the primary sums to emphasize
993  // shared predicates within rather than across patterns.
994  for (auto &list : lists) {
995  unsigned total = 0;
996  for (auto *predicate : list.predicates)
997  total += predicate->primary * predicate->primary;
998  for (auto *predicate : list.predicates)
999  predicate->secondary += total;
1000  }
1001 
1002  // Sort the set of predicates now that the cost primary and secondary sums
1003  // have been computed.
1004  std::vector<OrderedPredicate *> ordered;
1005  ordered.reserve(uniqued.size());
1006  for (auto &ip : uniqued)
1007  ordered.push_back(&ip);
1008  llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
1009  return *lhs < *rhs;
1010  });
1011 
1012  // Mostly keep the now established order, but also ensure that
1013  // ConstraintQuestions come after the results they use.
1014  stableTopologicalSort(ordered.begin(), ordered.end(), dependsOn);
1015 
1016  // Build the matchers for each of the pattern predicate lists.
1017  std::unique_ptr<MatcherNode> root;
1018  for (OrderedPredicateList &list : lists)
1019  propagatePattern(root, list, ordered.begin(), ordered.end());
1020 
1021  // Collapse the graph and insert the exit node.
1022  foldSwitchToBool(root);
1023  insertExitNode(&root);
1024  return root;
1025 }
1026 
1027 //===----------------------------------------------------------------------===//
1028 // MatcherNode
1029 //===----------------------------------------------------------------------===//
1030 
1031 MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
1032  std::unique_ptr<MatcherNode> failureNode)
1033  : position(p), question(q), failureNode(std::move(failureNode)),
1034  matcherTypeID(matcherTypeID) {}
1035 
1036 //===----------------------------------------------------------------------===//
1037 // BoolNode
1038 //===----------------------------------------------------------------------===//
1039 
1040 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
1041  std::unique_ptr<MatcherNode> successNode,
1042  std::unique_ptr<MatcherNode> failureNode)
1043  : MatcherNode(TypeID::get<BoolNode>(), position, question,
1044  std::move(failureNode)),
1045  answer(answer), successNode(std::move(successNode)) {}
1046 
1047 //===----------------------------------------------------------------------===//
1048 // SuccessNode
1049 //===----------------------------------------------------------------------===//
1050 
1051 SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
1052  std::unique_ptr<MatcherNode> failureNode)
1053  : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
1054  /*question=*/nullptr, std::move(failureNode)),
1055  pattern(pattern), root(root) {}
1056 
1057 //===----------------------------------------------------------------------===//
1058 // SwitchNode
1059 //===----------------------------------------------------------------------===//
1060 
1062  : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
static Value buildPredicateList(pdl::PatternOp pattern, PredicateBuilder &builder, std::vector< PositionalPredicate > &predList, DenseMap< Value, Position * > &valueToPosition)
Given a pattern operation, build the set of matcher predicates necessary to match this pattern.
static std::unique_ptr< MatcherNode > & getOrCreateChild(SwitchNode *node, OrderedPredicate *predicate, pdl::PatternOp pattern)
Get or insert a child matcher for the given parent switch node, given a predicate and parent pattern.
static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b)
Returns true if 'b' depends on a result of 'a'.
static void getTypePredicates(Value typeValue, function_ref< Attribute()> typeAttrFn, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getNonTreePredicates(pdl::PatternOp pattern, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
Collect all of the predicates that cannot be determined via walking the tree.
static bool useOperandGroup(pdl::OperationOp op, unsigned index)
Returns true if the operand at the given index needs to be queried using an operand group,...
static void getTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)
Collect the tree predicates anchored at the given value.
static void getResultPredicates(pdl::ResultOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getOperandTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)
Collect all of the predicates for the given operand position.
static bool comparePosDepth(Position *lhs, Position *rhs)
Compares the depths of two positions.
static void visitUpward(std::vector< PositionalPredicate > &predList, OpIndex opIndex, PredicateBuilder &builder, DenseMap< Value, Position * > &valueToPosition, Position *&pos, unsigned rootID)
Visit a node during upward traversal.
static unsigned getNumNonRangeValues(ValueRange values)
Returns the number of non-range elements within values.
static SmallVector< Value > detectRoots(pdl::PatternOp pattern)
Given a pattern, determines the set of roots present in this pattern.
static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp)
Sorts the range begin/end with the partial order given by cmp.
static void insertExitNode(std::unique_ptr< MatcherNode > *root)
Insert an exit node at the end of the failure path of the root.
static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate)
Returns true if the given matcher refers to the same predicate as the given ordered predicate.
static void foldSwitchToBool(std::unique_ptr< MatcherNode > &node)
Fold any switch nodes nested under node to boolean nodes when possible.
static void buildCostGraph(ArrayRef< Value > roots, RootOrderingGraph &graph, ParentMaps &parentMaps)
Given a list of candidate roots, builds the cost graph for connecting them.
static void getAttributePredicates(pdl::AttributeOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void propagatePattern(std::unique_ptr< MatcherNode > &node, OrderedPredicateList &list, std::vector< OrderedPredicate * >::iterator current, std::vector< OrderedPredicate * >::iterator end)
Build the matcher CFG by "pushing" patterns through by sorted predicate order.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
type_range getType() const
Definition: ValueRange.cpp:32
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:247
type_range getTypes() const
Definition: ValueRange.cpp:38
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
This class represents the base of a predicate matcher node.
Definition: PredicateTree.h:50
Position * getPosition() const
Returns the position on which the question predicate should be checked.
Definition: PredicateTree.h:63
Qualifier * getQuestion() const
Returns the predicate checked on this node.
Definition: PredicateTree.h:66
The optimal branching algorithm solver.
Definition: RootOrdering.h:90
unsigned solve()
Runs the Edmonds' algorithm for the current graph, returning the total cost of the minimum-weight spa...
std::vector< std::pair< Value, Value > > EdgeList
A list of edges (child, parent).
Definition: RootOrdering.h:93
EdgeList preOrderTraversal(ArrayRef< Value > nodes) const
Returns the computed edges as visited in the preorder traversal.
A position describes a value on the input IR on which a predicate may be applied, such as an operatio...
Definition: Predicate.h:144
unsigned getOperationDepth() const
Returns the depth of the first ancestor operation position.
Definition: Predicate.cpp:21
Predicates::Kind getKind() const
Returns the kind of this position.
Definition: Predicate.h:156
const KeyTy & getValue() const
Return the key value of this predicate.
Definition: Predicate.h:111
This class provides utilities for constructing predicates.
Definition: Predicate.h:610
ConstraintPosition * getConstraintPosition(ConstraintQuestion *q, unsigned index)
Definition: Predicate.h:636
Position * getTypeLiteral(Attribute attr)
Returns a type position for the given type value.
Definition: Predicate.h:688
Predicate getOperandCount(unsigned count)
Create a predicate comparing the number of operands of an operation to a known value.
Definition: Predicate.h:740
OperationPosition * getPassthroughOp(Position *p)
Returns the operation position equivalent to the given position.
Definition: Predicate.h:630
Predicate getIsNotNull()
Create a predicate comparing a value with null.
Definition: Predicate.h:734
Predicate getOperandCountAtLeast(unsigned count)
Definition: Predicate.h:744
Predicate getResultCountAtLeast(unsigned count)
Definition: Predicate.h:761
Position * getType(Position *p)
Returns a type position for the given entity.
Definition: Predicate.h:684
Position * getAttribute(OperationPosition *p, StringRef name)
Returns an attribute position for an attribute of the given operation.
Definition: Predicate.h:642
Position * getOperandGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of operands of the given operation.
Definition: Predicate.h:661
Position * getForEach(Position *p, unsigned id)
Definition: Predicate.h:651
Position * getOperand(OperationPosition *p, unsigned operand)
Returns an operand position for an operand of the given operation.
Definition: Predicate.h:656
Position * getResult(OperationPosition *p, unsigned result)
Returns a result position for a result of the given operation.
Definition: Predicate.h:670
Position * getRoot()
Returns the root operation position.
Definition: Predicate.h:620
Predicate getAttributeConstraint(Attribute attr)
Create a predicate comparing an attribute to a known value.
Definition: Predicate.h:710
Position * getResultGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of results of the given operation.
Definition: Predicate.h:675
Position * getAllResults(OperationPosition *p)
Definition: Predicate.h:679
UsersPosition * getUsers(Position *p, bool useRepresentative)
Returns the users of a position using the value at the given operand.
Definition: Predicate.h:693
Predicate getTypeConstraint(Attribute type)
Create a predicate comparing the type of an attribute or value to a known type.
Definition: Predicate.h:769
OperationPosition * getOperandDefiningOp(Position *p)
Returns the parent position defining the value held by the given operand.
Definition: Predicate.h:623
Predicate getResultCount(unsigned count)
Create a predicate comparing the number of results of an operation to a known value.
Definition: Predicate.h:757
std::pair< Qualifier *, Qualifier * > Predicate
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition: Predicate.h:707
Predicate getEqualTo(Position *pos)
Create a predicate checking if two values are equal.
Definition: Predicate.h:716
Position * getAllOperands(OperationPosition *p)
Definition: Predicate.h:665
Position * getAttributeLiteral(Attribute attr)
Returns an attribute position for the given attribute.
Definition: Predicate.h:647
Predicate getConstraint(StringRef name, ArrayRef< Position * > args, ArrayRef< Type > resultTypes, bool isNegated)
Create a predicate that applies a generic constraint.
Definition: Predicate.h:726
Predicate getOperationName(StringRef name)
Create a predicate comparing the name of an operation to a known value.
Definition: Predicate.h:750
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition: Predicate.h:422
Predicates::Kind getKind() const
Returns the kind of this qualifier.
Definition: Predicate.h:427
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
bool operator<(const Fraction &x, const Fraction &y)
Definition: Fraction.h:83
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A position describing an attribute of an operation.
Definition: Predicate.h:175
A BoolNode denotes a question with a boolean-like result.
BoolNode(Position *position, Qualifier *question, Qualifier *answer, std::unique_ptr< MatcherNode > successNode, std::unique_ptr< MatcherNode > failureNode=nullptr)
A position describing the result of a native constraint.
Definition: Predicate.h:301
Apply a parameterized constraint to multiple position values and possibly produce results.
Definition: Predicate.h:493
An operation position describes an operation node in the IR.
Definition: Predicate.h:259
bool isRoot() const
Returns if this operation position corresponds to the root.
Definition: Predicate.h:283
bool isOperandDefiningOp() const
Returns if this operation represents an operand defining op.
Definition: Predicate.cpp:55
A PositionalPredicate is a predicate that is associated with a specific positional value.
Definition: PredicateTree.h:30
The information associated with an edge in the cost graph.
Definition: RootOrdering.h:59
Value connector
The connector value in the intersection of the two subtrees rooted at the source and target root that...
Definition: RootOrdering.h:70
std::pair< unsigned, unsigned > cost
The depth of the connector Value w.r.t.
Definition: RootOrdering.h:65
A SuccessNode denotes that a given high level pattern has successfully been matched.
SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr< MatcherNode > failureNode)
A SwitchNode denotes a question with multiple potential results.
llvm::MapVector< Qualifier *, std::unique_ptr< MatcherNode > > ChildMapT
Returns the children of this switch node.
SwitchNode(Position *position, Qualifier *question)
A position describing the result type of an entity, i.e.
Definition: Predicate.h:364