MLIR  21.0.0git
PredicateTree.cpp
Go to the documentation of this file.
1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PredicateTree.h"
10 #include "RootOrdering.h"
11 
13 #include "mlir/IR/BuiltinOps.h"
14 #include "llvm/ADT/MapVector.h"
15 #include "llvm/ADT/SmallPtrSet.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 #include "llvm/Support/Debug.h"
18 #include <queue>
19 
20 #define DEBUG_TYPE "pdl-predicate-tree"
21 
22 using namespace mlir;
23 using namespace mlir::pdl_to_pdl_interp;
24 
25 //===----------------------------------------------------------------------===//
26 // Predicate List Building
27 //===----------------------------------------------------------------------===//
28 
29 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
30  Value val, PredicateBuilder &builder,
32  Position *pos);
33 
34 /// Compares the depths of two positions.
35 static bool comparePosDepth(Position *lhs, Position *rhs) {
36  return lhs->getOperationDepth() < rhs->getOperationDepth();
37 }
38 
39 /// Returns the number of non-range elements within `values`.
40 static unsigned getNumNonRangeValues(ValueRange values) {
41  return llvm::count_if(values.getTypes(),
42  [](Type type) { return !isa<pdl::RangeType>(type); });
43 }
44 
45 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
46  Value val, PredicateBuilder &builder,
48  AttributePosition *pos) {
49  assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
50  predList.emplace_back(pos, builder.getIsNotNull());
51 
52  if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
53  // If the attribute has a type or value, add a constraint.
54  if (Value type = attr.getValueType())
55  getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
56  else if (Attribute value = attr.getValueAttr())
57  predList.emplace_back(pos, builder.getAttributeConstraint(value));
58  }
59 }
60 
61 /// Collect all of the predicates for the given operand position.
62 static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
63  Value val, PredicateBuilder &builder,
65  Position *pos) {
66  Type valueType = val.getType();
67  bool isVariadic = isa<pdl::RangeType>(valueType);
68 
69  // If this is a typed operand, add a type constraint.
71  .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
72  // Prevent traversal into a null value if the operand has a proper
73  // index.
74  if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
75  cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
76  predList.emplace_back(pos, builder.getIsNotNull());
77 
78  if (Value type = op.getValueType())
79  getTreePredicates(predList, type, builder, inputs,
80  builder.getType(pos));
81  })
82  .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
83  std::optional<unsigned> index = op.getIndex();
84 
85  // Prevent traversal into a null value if the result has a proper index.
86  if (index)
87  predList.emplace_back(pos, builder.getIsNotNull());
88 
89  // Get the parent operation of this operand.
90  OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
91  predList.emplace_back(parentPos, builder.getIsNotNull());
92 
93  // Ensure that the operands match the corresponding results of the
94  // parent operation.
95  Position *resultPos = nullptr;
96  if (std::is_same<pdl::ResultOp, decltype(op)>::value)
97  resultPos = builder.getResult(parentPos, *index);
98  else
99  resultPos = builder.getResultGroup(parentPos, index, isVariadic);
100  predList.emplace_back(resultPos, builder.getEqualTo(pos));
101 
102  // Collect the predicates of the parent operation.
103  getTreePredicates(predList, op.getParent(), builder, inputs,
104  (Position *)parentPos);
105  });
106 }
107 
108 static void
109 getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
110  PredicateBuilder &builder,
112  std::optional<unsigned> ignoreOperand = std::nullopt) {
113  assert(isa<pdl::OperationType>(val.getType()) && "expected operation");
114  pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
115  OperationPosition *opPos = cast<OperationPosition>(pos);
116 
117  // Ensure getDefiningOp returns a non-null operation.
118  if (!opPos->isRoot())
119  predList.emplace_back(pos, builder.getIsNotNull());
120 
121  // Check that this is the correct root operation.
122  if (std::optional<StringRef> opName = op.getOpName())
123  predList.emplace_back(pos, builder.getOperationName(*opName));
124 
125  // Check that the operation has the proper number of operands. If there are
126  // any variable length operands, we check a minimum instead of an exact count.
127  OperandRange operands = op.getOperandValues();
128  unsigned minOperands = getNumNonRangeValues(operands);
129  if (minOperands != operands.size()) {
130  if (minOperands)
131  predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
132  } else {
133  predList.emplace_back(pos, builder.getOperandCount(minOperands));
134  }
135 
136  // Check that the operation has the proper number of results. If there are
137  // any variable length results, we check a minimum instead of an exact count.
138  OperandRange types = op.getTypeValues();
139  unsigned minResults = getNumNonRangeValues(types);
140  if (minResults == types.size())
141  predList.emplace_back(pos, builder.getResultCount(types.size()));
142  else if (minResults)
143  predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
144 
145  // Recurse into any attributes, operands, or results.
146  for (auto [attrName, attr] :
147  llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
149  predList, attr, builder, inputs,
150  builder.getAttribute(opPos, cast<StringAttr>(attrName).getValue()));
151  }
152 
153  // Process the operands and results of the operation. For all values up to
154  // the first variable length value, we use the concrete operand/result
155  // number. After that, we use the "group" given that we can't know the
156  // concrete indices until runtime. If there is only one variadic operand
157  // group, we treat it as all of the operands/results of the operation.
158  /// Operands.
159  if (operands.size() == 1 && isa<pdl::RangeType>(operands[0].getType())) {
160  // Ignore the operands if we are performing an upward traversal (in that
161  // case, they have already been visited).
162  if (opPos->isRoot() || opPos->isOperandDefiningOp())
163  getTreePredicates(predList, operands.front(), builder, inputs,
164  builder.getAllOperands(opPos));
165  } else {
166  bool foundVariableLength = false;
167  for (const auto &operandIt : llvm::enumerate(operands)) {
168  bool isVariadic = isa<pdl::RangeType>(operandIt.value().getType());
169  foundVariableLength |= isVariadic;
170 
171  // Ignore the specified operand, usually because this position was
172  // visited in an upward traversal via an iterative choice.
173  if (ignoreOperand == operandIt.index())
174  continue;
175 
176  Position *pos =
177  foundVariableLength
178  ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
179  : builder.getOperand(opPos, operandIt.index());
180  getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
181  }
182  }
183  /// Results.
184  if (types.size() == 1 && isa<pdl::RangeType>(types[0].getType())) {
185  getTreePredicates(predList, types.front(), builder, inputs,
186  builder.getType(builder.getAllResults(opPos)));
187  return;
188  }
189 
190  bool foundVariableLength = false;
191  for (auto [idx, typeValue] : llvm::enumerate(types)) {
192  bool isVariadic = isa<pdl::RangeType>(typeValue.getType());
193  foundVariableLength |= isVariadic;
194 
195  auto *resultPos = foundVariableLength
196  ? builder.getResultGroup(pos, idx, isVariadic)
197  : builder.getResult(pos, idx);
198  predList.emplace_back(resultPos, builder.getIsNotNull());
199  getTreePredicates(predList, typeValue, builder, inputs,
200  builder.getType(resultPos));
201  }
202 }
203 
204 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
205  Value val, PredicateBuilder &builder,
207  TypePosition *pos) {
208  // Check for a constraint on a constant type.
209  if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
210  if (Attribute type = typeOp.getConstantTypeAttr())
211  predList.emplace_back(pos, builder.getTypeConstraint(type));
212  } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
213  if (Attribute typeAttr = typeOp.getConstantTypesAttr())
214  predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
215  }
216 }
217 
218 /// Collect the tree predicates anchored at the given value.
219 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
220  Value val, PredicateBuilder &builder,
222  Position *pos) {
223  // Make sure this input value is accessible to the rewrite.
224  auto it = inputs.try_emplace(val, pos);
225  if (!it.second) {
226  // If this is an input value that has been visited in the tree, add a
227  // constraint to ensure that both instances refer to the same value.
228  if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
229  pdl::TypeOp>(val.getDefiningOp())) {
230  auto minMaxPositions =
231  std::minmax(pos, it.first->second, comparePosDepth);
232  predList.emplace_back(minMaxPositions.second,
233  builder.getEqualTo(minMaxPositions.first));
234  }
235  return;
236  }
237 
239  .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
240  getTreePredicates(predList, val, builder, inputs, pos);
241  })
242  .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
243  getOperandTreePredicates(predList, val, builder, inputs, pos);
244  })
245  .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
246 }
247 
248 static void getAttributePredicates(pdl::AttributeOp op,
249  std::vector<PositionalPredicate> &predList,
250  PredicateBuilder &builder,
251  DenseMap<Value, Position *> &inputs) {
252  Position *&attrPos = inputs[op];
253  if (attrPos)
254  return;
255  Attribute value = op.getValueAttr();
256  assert(value && "expected non-tree `pdl.attribute` to contain a value");
257  attrPos = builder.getAttributeLiteral(value);
258 }
259 
260 static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
261  std::vector<PositionalPredicate> &predList,
262  PredicateBuilder &builder,
263  DenseMap<Value, Position *> &inputs) {
264  OperandRange arguments = op.getArgs();
265 
266  std::vector<Position *> allPositions;
267  allPositions.reserve(arguments.size());
268  for (Value arg : arguments)
269  allPositions.push_back(inputs.lookup(arg));
270 
271  // Push the constraint to the furthest position.
272  Position *pos = *llvm::max_element(allPositions, comparePosDepth);
273  ResultRange results = op.getResults();
275  op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
276  op.getIsNegated());
277 
278  // For each result register a position so it can be used later
279  for (auto [i, result] : llvm::enumerate(results)) {
280  ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
281  ConstraintPosition *pos = builder.getConstraintPosition(q, i);
282  auto [it, inserted] = inputs.try_emplace(result, pos);
283  // If this is an input value that has been visited in the tree, add a
284  // constraint to ensure that both instances refer to the same value.
285  if (!inserted) {
286  Position *first = pos;
287  Position *second = it->second;
288  if (comparePosDepth(second, first))
289  std::tie(second, first) = std::make_pair(first, second);
290 
291  predList.emplace_back(second, builder.getEqualTo(first));
292  }
293  }
294  predList.emplace_back(pos, pred);
295 }
296 
297 static void getResultPredicates(pdl::ResultOp op,
298  std::vector<PositionalPredicate> &predList,
299  PredicateBuilder &builder,
300  DenseMap<Value, Position *> &inputs) {
301  Position *&resultPos = inputs[op];
302  if (resultPos)
303  return;
304 
305  // Ensure that the result isn't null.
306  auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
307  resultPos = builder.getResult(parentPos, op.getIndex());
308  predList.emplace_back(resultPos, builder.getIsNotNull());
309 }
310 
311 static void getResultPredicates(pdl::ResultsOp op,
312  std::vector<PositionalPredicate> &predList,
313  PredicateBuilder &builder,
314  DenseMap<Value, Position *> &inputs) {
315  Position *&resultPos = inputs[op];
316  if (resultPos)
317  return;
318 
319  // Ensure that the result isn't null if the result has an index.
320  auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
321  bool isVariadic = isa<pdl::RangeType>(op.getType());
322  std::optional<unsigned> index = op.getIndex();
323  resultPos = builder.getResultGroup(parentPos, index, isVariadic);
324  if (index)
325  predList.emplace_back(resultPos, builder.getIsNotNull());
326 }
327 
328 static void getTypePredicates(Value typeValue,
329  function_ref<Attribute()> typeAttrFn,
330  PredicateBuilder &builder,
331  DenseMap<Value, Position *> &inputs) {
332  Position *&typePos = inputs[typeValue];
333  if (typePos)
334  return;
335  Attribute typeAttr = typeAttrFn();
336  assert(typeAttr &&
337  "expected non-tree `pdl.type`/`pdl.types` to contain a value");
338  typePos = builder.getTypeLiteral(typeAttr);
339 }
340 
341 /// Collect all of the predicates that cannot be determined via walking the
342 /// tree.
343 static void getNonTreePredicates(pdl::PatternOp pattern,
344  std::vector<PositionalPredicate> &predList,
345  PredicateBuilder &builder,
346  DenseMap<Value, Position *> &inputs) {
347  for (Operation &op : pattern.getBodyRegion().getOps()) {
349  .Case([&](pdl::AttributeOp attrOp) {
350  getAttributePredicates(attrOp, predList, builder, inputs);
351  })
352  .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
353  getConstraintPredicates(constraintOp, predList, builder, inputs);
354  })
355  .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
356  getResultPredicates(resultOp, predList, builder, inputs);
357  })
358  .Case([&](pdl::TypeOp typeOp) {
360  typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder,
361  inputs);
362  })
363  .Case([&](pdl::TypesOp typeOp) {
365  typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder,
366  inputs);
367  });
368  }
369 }
370 
371 namespace {
372 
373 /// An op accepting a value at an optional index.
374 struct OpIndex {
375  Value parent;
376  std::optional<unsigned> index;
377 };
378 
379 /// The parent and operand index of each operation for each root, stored
380 /// as a nested map [root][operation].
381 using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>;
382 
383 } // namespace
384 
385 /// Given a pattern, determines the set of roots present in this pattern.
386 /// These are the operations whose results are not consumed by other operations.
387 static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
388  // First, collect all the operations that are used as operands
389  // to other operations. These are not roots by default.
390  DenseSet<Value> used;
391  for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
392  for (Value operand : operationOp.getOperandValues())
393  TypeSwitch<Operation *>(operand.getDefiningOp())
394  .Case<pdl::ResultOp, pdl::ResultsOp>(
395  [&used](auto resultOp) { used.insert(resultOp.getParent()); });
396  }
397 
398  // Remove the specified root from the use set, so that we can
399  // always select it as a root, even if it is used by other operations.
400  if (Value root = pattern.getRewriter().getRoot())
401  used.erase(root);
402 
403  // Finally, collect all the unused operations.
404  SmallVector<Value> roots;
405  for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>())
406  if (!used.contains(operationOp))
407  roots.push_back(operationOp);
408 
409  return roots;
410 }
411 
412 /// Given a list of candidate roots, builds the cost graph for connecting them.
413 /// The graph is formed by traversing the DAG of operations starting from each
414 /// root and marking the depth of each connector value (operand). Then we join
415 /// the candidate roots based on the common connector values, taking the one
416 /// with the minimum depth. Along the way, we compute, for each candidate root,
417 /// a mapping from each operation (in the DAG underneath this root) to its
418 /// parent operation and the corresponding operand index.
420  ParentMaps &parentMaps) {
421 
422  // The entry of a queue. The entry consists of the following items:
423  // * the value in the DAG underneath the root;
424  // * the parent of the value;
425  // * the operand index of the value in its parent;
426  // * the depth of the visited value.
427  struct Entry {
428  Entry(Value value, Value parent, std::optional<unsigned> index,
429  unsigned depth)
430  : value(value), parent(parent), index(index), depth(depth) {}
431 
432  Value value;
433  Value parent;
434  std::optional<unsigned> index;
435  unsigned depth;
436  };
437 
438  // A root of a value and its depth (distance from root to the value).
439  struct RootDepth {
440  Value root;
441  unsigned depth = 0;
442  };
443 
444  // Map from candidate connector values to their roots and depths. Using a
445  // small vector with 1 entry because most values belong to a single root.
446  llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
447 
448  // Perform a breadth-first traversal of the op DAG rooted at each root.
449  for (Value root : roots) {
450  // The queue of visited values. A value may be present multiple times in
451  // the queue, for multiple parents. We only accept the first occurrence,
452  // which is guaranteed to have the lowest depth.
453  std::queue<Entry> toVisit;
454  toVisit.emplace(root, Value(), 0, 0);
455 
456  // The map from value to its parent for the current root.
457  DenseMap<Value, OpIndex> &parentMap = parentMaps[root];
458 
459  while (!toVisit.empty()) {
460  Entry entry = toVisit.front();
461  toVisit.pop();
462  // Skip if already visited.
463  if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
464  continue;
465 
466  // Mark the root and depth of the value.
467  connectorsRootsDepths[entry.value].push_back({root, entry.depth});
468 
469  // Traverse the operands of an operation and result ops.
470  // We intentionally do not traverse attributes and types, because those
471  // are expensive to join on.
472  TypeSwitch<Operation *>(entry.value.getDefiningOp())
473  .Case<pdl::OperationOp>([&](auto operationOp) {
474  OperandRange operands = operationOp.getOperandValues();
475  // Special case when we pass all the operands in one range.
476  // For those, the index is empty.
477  if (operands.size() == 1 &&
478  isa<pdl::RangeType>(operands[0].getType())) {
479  toVisit.emplace(operands[0], entry.value, std::nullopt,
480  entry.depth + 1);
481  return;
482  }
483 
484  // Default case: visit all the operands.
485  for (const auto &p :
486  llvm::enumerate(operationOp.getOperandValues()))
487  toVisit.emplace(p.value(), entry.value, p.index(),
488  entry.depth + 1);
489  })
490  .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
491  toVisit.emplace(resultOp.getParent(), entry.value,
492  resultOp.getIndex(), entry.depth);
493  });
494  }
495  }
496 
497  // Now build the cost graph.
498  // This is simply a minimum over all depths for the target root.
499  unsigned nextID = 0;
500  for (const auto &connectorRootsDepths : connectorsRootsDepths) {
501  Value value = connectorRootsDepths.first;
502  ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second;
503  // If there is only one root for this value, this will not trigger
504  // any edges in the cost graph (a perf optimization).
505  if (rootsDepths.size() == 1)
506  continue;
507 
508  for (const RootDepth &p : rootsDepths) {
509  for (const RootDepth &q : rootsDepths) {
510  if (&p == &q)
511  continue;
512  // Insert or retrieve the property of edge from p to q.
513  RootOrderingEntry &entry = graph[q.root][p.root];
514  if (!entry.connector /* new edge */ || entry.cost.first > q.depth) {
515  if (!entry.connector)
516  entry.cost.second = nextID++;
517  entry.cost.first = q.depth;
518  entry.connector = value;
519  }
520  }
521  }
522  }
523 
524  assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
525  "the pattern contains a candidate root disconnected from the others");
526 }
527 
528 /// Returns true if the operand at the given index needs to be queried using an
529 /// operand group, i.e., if it is variadic itself or follows a variadic operand.
530 static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
531  OperandRange operands = op.getOperandValues();
532  assert(index < operands.size() && "operand index out of range");
533  for (unsigned i = 0; i <= index; ++i)
534  if (isa<pdl::RangeType>(operands[i].getType()))
535  return true;
536  return false;
537 }
538 
539 /// Visit a node during upward traversal.
540 static void visitUpward(std::vector<PositionalPredicate> &predList,
541  OpIndex opIndex, PredicateBuilder &builder,
542  DenseMap<Value, Position *> &valueToPosition,
543  Position *&pos, unsigned rootID) {
544  Value value = opIndex.parent;
546  .Case<pdl::OperationOp>([&](auto operationOp) {
547  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
548 
549  // Get users and iterate over them.
550  Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true);
551  Position *foreachPos = builder.getForEach(usersPos, rootID);
552  OperationPosition *opPos = builder.getPassthroughOp(foreachPos);
553 
554  // Compare the operand(s) of the user against the input value(s).
555  Position *operandPos;
556  if (!opIndex.index) {
557  // We are querying all the operands of the operation.
558  operandPos = builder.getAllOperands(opPos);
559  } else if (useOperandGroup(operationOp, *opIndex.index)) {
560  // We are querying an operand group.
561  Type type = operationOp.getOperandValues()[*opIndex.index].getType();
562  bool variadic = isa<pdl::RangeType>(type);
563  operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
564  } else {
565  // We are querying an individual operand.
566  operandPos = builder.getOperand(opPos, *opIndex.index);
567  }
568  predList.emplace_back(operandPos, builder.getEqualTo(pos));
569 
570  // Guard against duplicate upward visits. These are not possible,
571  // because if this value was already visited, it would have been
572  // cheaper to start the traversal at this value rather than at the
573  // `connector`, violating the optimality of our spanning tree.
574  bool inserted = valueToPosition.try_emplace(value, opPos).second;
575  (void)inserted;
576  assert(inserted && "duplicate upward visit");
577 
578  // Obtain the tree predicates at the current value.
579  getTreePredicates(predList, value, builder, valueToPosition, opPos,
580  opIndex.index);
581 
582  // Update the position
583  pos = opPos;
584  })
585  .Case<pdl::ResultOp>([&](auto resultOp) {
586  // Traverse up an individual result.
587  auto *opPos = dyn_cast<OperationPosition>(pos);
588  assert(opPos && "operations and results must be interleaved");
589  pos = builder.getResult(opPos, *opIndex.index);
590 
591  // Insert the result position in case we have not visited it yet.
592  valueToPosition.try_emplace(value, pos);
593  })
594  .Case<pdl::ResultsOp>([&](auto resultOp) {
595  // Traverse up a group of results.
596  auto *opPos = dyn_cast<OperationPosition>(pos);
597  assert(opPos && "operations and results must be interleaved");
598  bool isVariadic = isa<pdl::RangeType>(value.getType());
599  if (opIndex.index)
600  pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
601  else
602  pos = builder.getAllResults(opPos);
603 
604  // Insert the result position in case we have not visited it yet.
605  valueToPosition.try_emplace(value, pos);
606  });
607 }
608 
609 /// Given a pattern operation, build the set of matcher predicates necessary to
610 /// match this pattern.
611 static Value buildPredicateList(pdl::PatternOp pattern,
612  PredicateBuilder &builder,
613  std::vector<PositionalPredicate> &predList,
614  DenseMap<Value, Position *> &valueToPosition) {
615  SmallVector<Value> roots = detectRoots(pattern);
616 
617  // Build the root ordering graph and compute the parent maps.
618  RootOrderingGraph graph;
619  ParentMaps parentMaps;
620  buildCostGraph(roots, graph, parentMaps);
621  LLVM_DEBUG({
622  llvm::dbgs() << "Graph:\n";
623  for (auto &target : graph) {
624  llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first
625  << "\n";
626  for (auto &source : target.second) {
627  RootOrderingEntry &entry = source.second;
628  llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first
629  << ":" << entry.cost.second << " via "
630  << entry.connector.getLoc() << "\n";
631  }
632  }
633  });
634 
635  // Solve the optimal branching problem for each candidate root, or use the
636  // provided one.
637  Value bestRoot = pattern.getRewriter().getRoot();
638  OptimalBranching::EdgeList bestEdges;
639  if (!bestRoot) {
640  unsigned bestCost = 0;
641  LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
642  for (Value root : roots) {
643  OptimalBranching solver(graph, root);
644  unsigned cost = solver.solve();
645  LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n");
646  if (!bestRoot || bestCost > cost) {
647  bestCost = cost;
648  bestRoot = root;
649  bestEdges = solver.preOrderTraversal(roots);
650  }
651  }
652  } else {
653  OptimalBranching solver(graph, bestRoot);
654  solver.solve();
655  bestEdges = solver.preOrderTraversal(roots);
656  }
657 
658  // Print the best solution.
659  LLVM_DEBUG({
660  llvm::dbgs() << "Best tree:\n";
661  for (const std::pair<Value, Value> &edge : bestEdges) {
662  llvm::dbgs() << " * " << edge.first;
663  if (edge.second)
664  llvm::dbgs() << " <- " << edge.second;
665  llvm::dbgs() << "\n";
666  }
667  });
668 
669  LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
670  LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n");
671 
672  // The best root is the starting point for the traversal. Get the tree
673  // predicates for the DAG rooted at bestRoot.
674  getTreePredicates(predList, bestRoot, builder, valueToPosition,
675  builder.getRoot());
676 
677  // Traverse the selected optimal branching. For all edges in order, traverse
678  // up starting from the connector, until the candidate root is reached, and
679  // call getTreePredicates at every node along the way.
680  for (const auto &it : llvm::enumerate(bestEdges)) {
681  Value target = it.value().first;
682  Value source = it.value().second;
683 
684  // Check if we already visited the target root. This happens in two cases:
685  // 1) the initial root (bestRoot);
686  // 2) a root that is dominated by (contained in the subtree rooted at) an
687  // already visited root.
688  if (valueToPosition.count(target))
689  continue;
690 
691  // Determine the connector.
692  Value connector = graph[target][source].connector;
693  assert(connector && "invalid edge");
694  LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n");
695  DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
696  Position *pos = valueToPosition.lookup(connector);
697  assert(pos && "connector has not been traversed yet");
698 
699  // Traverse from the connector upwards towards the target root.
700  for (Value value = connector; value != target;) {
701  OpIndex opIndex = parentMap.lookup(value);
702  assert(opIndex.parent && "missing parent");
703  visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
704  value = opIndex.parent;
705  }
706  }
707 
708  getNonTreePredicates(pattern, predList, builder, valueToPosition);
709 
710  return bestRoot;
711 }
712 
713 //===----------------------------------------------------------------------===//
714 // Pattern Predicate Tree Merging
715 //===----------------------------------------------------------------------===//
716 
717 namespace {
718 
719 /// This class represents a specific predicate applied to a position, and
720 /// provides hashing and ordering operators. This class allows for computing a
721 /// frequence sum and ordering predicates based on a cost model.
722 struct OrderedPredicate {
723  OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
724  : position(ip.first), question(ip.second) {}
725  OrderedPredicate(const PositionalPredicate &ip)
726  : position(ip.position), question(ip.question) {}
727 
728  /// The position this predicate is applied to.
729  Position *position;
730 
731  /// The question that is applied by this predicate onto the position.
732  Qualifier *question;
733 
734  /// The first and second order benefit sums.
735  /// The primary sum is the number of occurrences of this predicate among all
736  /// of the patterns.
737  unsigned primary = 0;
738  /// The secondary sum is a squared summation of the primary sum of all of the
739  /// predicates within each pattern that contains this predicate. This allows
740  /// for favoring predicates that are more commonly shared within a pattern, as
741  /// opposed to those shared across patterns.
742  unsigned secondary = 0;
743 
744  /// The tie breaking ID, used to preserve a deterministic (insertion) order
745  /// among all the predicates with the same priority, depth, and position /
746  /// predicate dependency.
747  unsigned id = 0;
748 
749  /// A map between a pattern operation and the answer to the predicate question
750  /// within that pattern.
751  DenseMap<Operation *, Qualifier *> patternToAnswer;
752 
753  /// Returns true if this predicate is ordered before `rhs`, based on the cost
754  /// model.
755  bool operator<(const OrderedPredicate &rhs) const {
756  // Sort by:
757  // * higher first and secondary order sums
758  // * lower depth
759  // * lower position dependency
760  // * lower predicate dependency
761  // * lower tie breaking ID
762  auto *rhsPos = rhs.position;
763  return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
764  rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
765  std::make_tuple(rhs.primary, rhs.secondary,
766  position->getOperationDepth(), position->getKind(),
767  question->getKind(), id);
768  }
769 };
770 
771 /// A DenseMapInfo for OrderedPredicate based solely on the position and
772 /// question.
773 struct OrderedPredicateDenseInfo {
775 
776  static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
777  static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
778  static bool isEqual(const OrderedPredicate &lhs,
779  const OrderedPredicate &rhs) {
780  return lhs.position == rhs.position && lhs.question == rhs.question;
781  }
782  static unsigned getHashValue(const OrderedPredicate &p) {
783  return llvm::hash_combine(p.position, p.question);
784  }
785 };
786 
787 /// This class wraps a set of ordered predicates that are used within a specific
788 /// pattern operation.
789 struct OrderedPredicateList {
790  OrderedPredicateList(pdl::PatternOp pattern, Value root)
791  : pattern(pattern), root(root) {}
792 
793  pdl::PatternOp pattern;
794  Value root;
795  DenseSet<OrderedPredicate *> predicates;
796 };
797 } // namespace
798 
799 /// Returns true if the given matcher refers to the same predicate as the given
800 /// ordered predicate. This means that the position and questions of the two
801 /// match.
802 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
803  return node->getPosition() == predicate->position &&
804  node->getQuestion() == predicate->question;
805 }
806 
807 /// Get or insert a child matcher for the given parent switch node, given a
808 /// predicate and parent pattern.
809 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
810  OrderedPredicate *predicate,
811  pdl::PatternOp pattern) {
812  assert(isSamePredicate(node, predicate) &&
813  "expected matcher to equal the given predicate");
814 
815  auto it = predicate->patternToAnswer.find(pattern);
816  assert(it != predicate->patternToAnswer.end() &&
817  "expected pattern to exist in predicate");
818  return node->getChildren()[it->second];
819 }
820 
821 /// Build the matcher CFG by "pushing" patterns through by sorted predicate
822 /// order. A pattern will traverse as far as possible using common predicates
823 /// and then either diverge from the CFG or reach the end of a branch and start
824 /// creating new nodes.
825 static void propagatePattern(std::unique_ptr<MatcherNode> &node,
826  OrderedPredicateList &list,
827  std::vector<OrderedPredicate *>::iterator current,
828  std::vector<OrderedPredicate *>::iterator end) {
829  if (current == end) {
830  // We've hit the end of a pattern, so create a successful result node.
831  node =
832  std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
833 
834  // If the pattern doesn't contain this predicate, ignore it.
835  } else if (!list.predicates.contains(*current)) {
836  propagatePattern(node, list, std::next(current), end);
837 
838  // If the current matcher node is invalid, create a new one for this
839  // position and continue propagation.
840  } else if (!node) {
841  // Create a new node at this position and continue
842  node = std::make_unique<SwitchNode>((*current)->position,
843  (*current)->question);
845  getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
846  list, std::next(current), end);
847 
848  // If the matcher has already been created, and it is for this predicate we
849  // continue propagation to the child.
850  } else if (isSamePredicate(node.get(), *current)) {
852  getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
853  list, std::next(current), end);
854 
855  // If the matcher doesn't match the current predicate, insert a branch as
856  // the common set of matchers has diverged.
857  } else {
858  propagatePattern(node->getFailureNode(), list, current, end);
859  }
860 }
861 
862 /// Fold any switch nodes nested under `node` to boolean nodes when possible.
863 /// `node` is updated in-place if it is a switch.
864 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
865  if (!node)
866  return;
867 
868  if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
869  SwitchNode::ChildMapT &children = switchNode->getChildren();
870  for (auto &it : children)
871  foldSwitchToBool(it.second);
872 
873  // If the node only contains one child, collapse it into a boolean predicate
874  // node.
875  if (children.size() == 1) {
876  auto *childIt = children.begin();
877  node = std::make_unique<BoolNode>(
878  node->getPosition(), node->getQuestion(), childIt->first,
879  std::move(childIt->second), std::move(node->getFailureNode()));
880  }
881  } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
882  foldSwitchToBool(boolNode->getSuccessNode());
883  }
884 
885  foldSwitchToBool(node->getFailureNode());
886 }
887 
888 /// Insert an exit node at the end of the failure path of the `root`.
889 static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
890  while (*root)
891  root = &(*root)->getFailureNode();
892  *root = std::make_unique<ExitNode>();
893 }
894 
/// Sorts the range [begin, end) with the partial order given by `cmp`,
/// keeping the existing relative order of elements wherever `cmp` allows.
///
/// `cmp(a, b)` returns true when `a` must be placed before `b`. Each round
/// finds every element that no remaining element must precede, moves those
/// elements to the front of the unsorted range (stably), and continues with
/// the rest.
///
/// If `cmp` contains a cycle, some round finds no placeable element. This
/// asserts in debug builds. In release builds the remaining elements are left
/// in their current order, instead of looping forever.
template <typename Iterator, typename Compare>
static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
  using ValueT = typename std::iterator_traits<Iterator>::value_type;
  while (begin != end) {
    // Cannot compute sortBeforeOthers in the predicate of stable_partition
    // because stable_partition will not keep the [begin, end) range intact
    // while it runs.
    std::vector<ValueT> sortBeforeOthers;
    for (auto i = begin; i != end; ++i) {
      if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
        sortBeforeOthers.push_back(*i);
    }

    // The membership scan is linear, but so is the none_of above for every
    // element, so a round stays O(n^2) overall.
    auto const next = std::stable_partition(begin, end, [&](auto const &a) {
      return std::find(sortBeforeOthers.begin(), sortBeforeOthers.end(), a) !=
             sortBeforeOthers.end();
    });
    assert(next != begin && "not a partial ordering");
    if (next == begin)
      return; // No progress possible (cycle); avoid an infinite loop.
    begin = next;
  }
}
915 
916 /// Returns true if 'b' depends on a result of 'a'.
917 static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
918  auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
919  if (!cqa)
920  return false;
921 
922  auto positionDependsOnA = [&](Position *p) {
923  auto *cp = dyn_cast<ConstraintPosition>(p);
924  return cp && cp->getQuestion() == cqa;
925  };
926 
927  if (auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) {
928  // Does any argument of b use a?
929  return llvm::any_of(cqb->getArgs(), positionDependsOnA);
930  }
931  if (auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) {
932  return positionDependsOnA(b->position) ||
933  positionDependsOnA(equalTo->getValue());
934  }
935  return positionDependsOnA(b->position);
936 }
937 
938 /// Given a module containing PDL pattern operations, generate a matcher tree
939 /// using the patterns within the given module and return the root matcher node.
940 std::unique_ptr<MatcherNode>
941 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
942  DenseMap<Value, Position *> &valueToPosition) {
943  // The set of predicates contained within the pattern operations of the
944  // module.
945  struct PatternPredicates {
946  PatternPredicates(pdl::PatternOp pattern, Value root,
947  std::vector<PositionalPredicate> predicates)
948  : pattern(pattern), root(root), predicates(std::move(predicates)) {}
949 
950  /// A pattern.
951  pdl::PatternOp pattern;
952 
953  /// A root of the pattern chosen among the candidate roots in pdl.rewrite.
954  Value root;
955 
956  /// The extracted predicates for this pattern and root.
957  std::vector<PositionalPredicate> predicates;
958  };
959 
960  SmallVector<PatternPredicates, 16> patternsAndPredicates;
961  for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
962  std::vector<PositionalPredicate> predicateList;
963  Value root =
964  buildPredicateList(pattern, builder, predicateList, valueToPosition);
965  patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
966  }
967 
968  // Associate a pattern result with each unique predicate.
970  for (auto &patternAndPredList : patternsAndPredicates) {
971  for (auto &predicate : patternAndPredList.predicates) {
972  auto it = uniqued.insert(predicate);
973  it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
974  predicate.answer);
975  // Mark the insertion order (0-based indexing).
976  if (it.second)
977  it.first->id = uniqued.size() - 1;
978  }
979  }
980 
981  // Associate each pattern to a set of its ordered predicates for later lookup.
982  std::vector<OrderedPredicateList> lists;
983  lists.reserve(patternsAndPredicates.size());
984  for (auto &patternAndPredList : patternsAndPredicates) {
985  OrderedPredicateList list(patternAndPredList.pattern,
986  patternAndPredList.root);
987  for (auto &predicate : patternAndPredList.predicates) {
988  OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
989  list.predicates.insert(orderedPredicate);
990 
991  // Increment the primary sum for each reference to a particular predicate.
992  ++orderedPredicate->primary;
993  }
994  lists.push_back(std::move(list));
995  }
996 
997  // For a particular pattern, get the total primary sum and add it to the
998  // secondary sum of each predicate. Square the primary sums to emphasize
999  // shared predicates within rather than across patterns.
1000  for (auto &list : lists) {
1001  unsigned total = 0;
1002  for (auto *predicate : list.predicates)
1003  total += predicate->primary * predicate->primary;
1004  for (auto *predicate : list.predicates)
1005  predicate->secondary += total;
1006  }
1007 
1008  // Sort the set of predicates now that the cost primary and secondary sums
1009  // have been computed.
1010  std::vector<OrderedPredicate *> ordered;
1011  ordered.reserve(uniqued.size());
1012  for (auto &ip : uniqued)
1013  ordered.push_back(&ip);
1014  llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
1015  return *lhs < *rhs;
1016  });
1017 
1018  // Mostly keep the now established order, but also ensure that
1019  // ConstraintQuestions come after the results they use.
1020  stableTopologicalSort(ordered.begin(), ordered.end(), dependsOn);
1021 
1022  // Build the matchers for each of the pattern predicate lists.
1023  std::unique_ptr<MatcherNode> root;
1024  for (OrderedPredicateList &list : lists)
1025  propagatePattern(root, list, ordered.begin(), ordered.end());
1026 
1027  // Collapse the graph and insert the exit node.
1028  foldSwitchToBool(root);
1029  insertExitNode(&root);
1030  return root;
1031 }
1032 
1033 //===----------------------------------------------------------------------===//
1034 // MatcherNode
1035 //===----------------------------------------------------------------------===//
1036 
1037 MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
1038  std::unique_ptr<MatcherNode> failureNode)
1039  : position(p), question(q), failureNode(std::move(failureNode)),
1040  matcherTypeID(matcherTypeID) {}
1041 
1042 //===----------------------------------------------------------------------===//
1043 // BoolNode
1044 //===----------------------------------------------------------------------===//
1045 
1046 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
1047  std::unique_ptr<MatcherNode> successNode,
1048  std::unique_ptr<MatcherNode> failureNode)
1049  : MatcherNode(TypeID::get<BoolNode>(), position, question,
1050  std::move(failureNode)),
1051  answer(answer), successNode(std::move(successNode)) {}
1052 
1053 //===----------------------------------------------------------------------===//
1054 // SuccessNode
1055 //===----------------------------------------------------------------------===//
1056 
1057 SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
1058  std::unique_ptr<MatcherNode> failureNode)
1059  : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
1060  /*question=*/nullptr, std::move(failureNode)),
1061  pattern(pattern), root(root) {}
1062 
1063 //===----------------------------------------------------------------------===//
1064 // SwitchNode
1065 //===----------------------------------------------------------------------===//
1066 
1068  : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
static Value buildPredicateList(pdl::PatternOp pattern, PredicateBuilder &builder, std::vector< PositionalPredicate > &predList, DenseMap< Value, Position * > &valueToPosition)
Given a pattern operation, build the set of matcher predicates necessary to match this pattern.
static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b)
Returns true if 'b' depends on a result of 'a'.
static void getTypePredicates(Value typeValue, function_ref< Attribute()> typeAttrFn, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getNonTreePredicates(pdl::PatternOp pattern, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
Collect all of the predicates that cannot be determined via walking the tree.
static bool useOperandGroup(pdl::OperationOp op, unsigned index)
Returns true if the operand at the given index needs to be queried using an operand group,...
static void getTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)
Collect the tree predicates anchored at the given value.
static void getResultPredicates(pdl::ResultOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getOperandTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)
Collect all of the predicates for the given operand position.
std::unique_ptr< MatcherNode > & getOrCreateChild(SwitchNode *node, OrderedPredicate *predicate, pdl::PatternOp pattern)
Get or insert a child matcher for the given parent switch node, given a predicate and parent pattern.
static bool comparePosDepth(Position *lhs, Position *rhs)
Compares the depths of two positions.
static void visitUpward(std::vector< PositionalPredicate > &predList, OpIndex opIndex, PredicateBuilder &builder, DenseMap< Value, Position * > &valueToPosition, Position *&pos, unsigned rootID)
Visit a node during upward traversal.
static unsigned getNumNonRangeValues(ValueRange values)
Returns the number of non-range elements within values.
static SmallVector< Value > detectRoots(pdl::PatternOp pattern)
Given a pattern, determines the set of roots present in this pattern.
static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp)
Sorts the range begin/end with the partial order given by cmp.
static void insertExitNode(std::unique_ptr< MatcherNode > *root)
Insert an exit node at the end of the failure path of the root.
static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate)
Returns true if the given matcher refers to the same predicate as the given ordered predicate.
static void foldSwitchToBool(std::unique_ptr< MatcherNode > &node)
Fold any switch nodes nested under node to boolean nodes when possible.
static void buildCostGraph(ArrayRef< Value > roots, RootOrderingGraph &graph, ParentMaps &parentMaps)
Given a list of candidate roots, builds the cost graph for connecting them.
static void getAttributePredicates(pdl::AttributeOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void propagatePattern(std::unique_ptr< MatcherNode > &node, OrderedPredicateList &list, std::vector< OrderedPredicate * >::iterator current, std::vector< OrderedPredicate * >::iterator end)
Build the matcher CFG by "pushing" patterns through by sorted predicate order.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
type_range getType() const
Definition: ValueRange.cpp:32
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:247
type_range getTypes() const
Definition: ValueRange.cpp:38
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This class represents the base of a predicate matcher node.
Definition: PredicateTree.h:50
Position * getPosition() const
Returns the position on which the question predicate should be checked.
Definition: PredicateTree.h:63
Qualifier * getQuestion() const
Returns the predicate checked on this node.
Definition: PredicateTree.h:66
The optimal branching algorithm solver.
Definition: RootOrdering.h:90
unsigned solve()
Runs the Edmonds' algorithm for the current graph, returning the total cost of the minimum-weight spa...
std::vector< std::pair< Value, Value > > EdgeList
A list of edges (child, parent).
Definition: RootOrdering.h:93
EdgeList preOrderTraversal(ArrayRef< Value > nodes) const
Returns the computed edges as visited in the preorder traversal.
A position describes a value on the input IR on which a predicate may be applied, such as an operatio...
Definition: Predicate.h:144
unsigned getOperationDepth() const
Returns the depth of the first ancestor operation position.
Definition: Predicate.cpp:21
Predicates::Kind getKind() const
Returns the kind of this position.
Definition: Predicate.h:156
const KeyTy & getValue() const
Return the key value of this predicate.
Definition: Predicate.h:111
This class provides utilities for constructing predicates.
Definition: Predicate.h:610
ConstraintPosition * getConstraintPosition(ConstraintQuestion *q, unsigned index)
Definition: Predicate.h:636
Position * getTypeLiteral(Attribute attr)
Returns a type position for the given type value.
Definition: Predicate.h:688
Predicate getOperandCount(unsigned count)
Create a predicate comparing the number of operands of an operation to a known value.
Definition: Predicate.h:740
OperationPosition * getPassthroughOp(Position *p)
Returns the operation position equivalent to the given position.
Definition: Predicate.h:630
Predicate getIsNotNull()
Create a predicate comparing a value with null.
Definition: Predicate.h:734
Predicate getOperandCountAtLeast(unsigned count)
Definition: Predicate.h:744
Predicate getResultCountAtLeast(unsigned count)
Definition: Predicate.h:761
Position * getType(Position *p)
Returns a type position for the given entity.
Definition: Predicate.h:684
Position * getAttribute(OperationPosition *p, StringRef name)
Returns an attribute position for an attribute of the given operation.
Definition: Predicate.h:642
Position * getOperandGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of operands of the given operation.
Definition: Predicate.h:661
Position * getForEach(Position *p, unsigned id)
Definition: Predicate.h:651
Position * getOperand(OperationPosition *p, unsigned operand)
Returns an operand position for an operand of the given operation.
Definition: Predicate.h:656
Position * getResult(OperationPosition *p, unsigned result)
Returns a result position for a result of the given operation.
Definition: Predicate.h:670
Position * getRoot()
Returns the root operation position.
Definition: Predicate.h:620
Predicate getAttributeConstraint(Attribute attr)
Create a predicate comparing an attribute to a known value.
Definition: Predicate.h:710
Position * getResultGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of results of the given operation.
Definition: Predicate.h:675
Position * getAllResults(OperationPosition *p)
Definition: Predicate.h:679
UsersPosition * getUsers(Position *p, bool useRepresentative)
Returns the users of a position using the value at the given operand.
Definition: Predicate.h:693
Predicate getTypeConstraint(Attribute type)
Create a predicate comparing the type of an attribute or value to a known type.
Definition: Predicate.h:769
OperationPosition * getOperandDefiningOp(Position *p)
Returns the parent position defining the value held by the given operand.
Definition: Predicate.h:623
Predicate getResultCount(unsigned count)
Create a predicate comparing the number of results of an operation to a known value.
Definition: Predicate.h:757
std::pair< Qualifier *, Qualifier * > Predicate
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition: Predicate.h:707
Predicate getEqualTo(Position *pos)
Create a predicate checking if two values are equal.
Definition: Predicate.h:716
Position * getAllOperands(OperationPosition *p)
Definition: Predicate.h:665
Position * getAttributeLiteral(Attribute attr)
Returns an attribute position for the given attribute.
Definition: Predicate.h:647
Predicate getConstraint(StringRef name, ArrayRef< Position * > args, ArrayRef< Type > resultTypes, bool isNegated)
Create a predicate that applies a generic constraint.
Definition: Predicate.h:726
Predicate getOperationName(StringRef name)
Create a predicate comparing the name of an operation to a known value.
Definition: Predicate.h:750
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition: Predicate.h:422
Predicates::Kind getKind() const
Returns the kind of this qualifier.
Definition: Predicate.h:427
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
bool operator<(const Fraction &x, const Fraction &y)
Definition: Fraction.h:83
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A position describing an attribute of an operation.
Definition: Predicate.h:175
A BoolNode denotes a question with a boolean-like result.
BoolNode(Position *position, Qualifier *question, Qualifier *answer, std::unique_ptr< MatcherNode > successNode, std::unique_ptr< MatcherNode > failureNode=nullptr)
A position describing the result of a native constraint.
Definition: Predicate.h:301
Apply a parameterized constraint to multiple position values and possibly produce results.
Definition: Predicate.h:493
An operation position describes an operation node in the IR.
Definition: Predicate.h:259
bool isRoot() const
Returns if this operation position corresponds to the root.
Definition: Predicate.h:283
bool isOperandDefiningOp() const
Returns if this operation represents an operand defining op.
Definition: Predicate.cpp:55
A PositionalPredicate is a predicate that is associated with a specific positional value.
Definition: PredicateTree.h:30
The information associated with an edge in the cost graph.
Definition: RootOrdering.h:59
Value connector
The connector value in the intersection of the two subtrees rooted at the source and target root that...
Definition: RootOrdering.h:70
std::pair< unsigned, unsigned > cost
The depth of the connector Value w.r.t.
Definition: RootOrdering.h:65
A SuccessNode denotes that a given high level pattern has successfully been matched.
SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr< MatcherNode > failureNode)
A SwitchNode denotes a question with multiple potential results.
llvm::MapVector< Qualifier *, std::unique_ptr< MatcherNode > > ChildMapT
Returns the children of this switch node.
SwitchNode(Position *position, Qualifier *question)
A position describing the result type of an entity, i.e.
Definition: Predicate.h:364