MLIR  20.0.0git
PredicateTree.cpp
Go to the documentation of this file.
1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PredicateTree.h"
10 #include "RootOrdering.h"
11 
15 #include "mlir/IR/BuiltinOps.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/Support/Debug.h"
21 #include <queue>
22 
23 #define DEBUG_TYPE "pdl-predicate-tree"
24 
25 using namespace mlir;
26 using namespace mlir::pdl_to_pdl_interp;
27 
28 //===----------------------------------------------------------------------===//
29 // Predicate List Building
30 //===----------------------------------------------------------------------===//
31 
/// Forward declaration of the value-anchored tree-predicate collector defined
/// further below; the position-specific overloads recurse through it.
// NOTE(review): listing line 34 is missing (numbering jumps 33 -> 35). Every
// call site passes five arguments with `inputs` in fourth place, so the dropped
// line presumably declared `DenseMap<Value, Position *> &inputs,` — confirm.
32 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
33  Value val, PredicateBuilder &builder,
35  Position *pos);
36 
37 /// Compares the depths of two positions.
38 static bool comparePosDepth(Position *lhs, Position *rhs) {
39  return lhs->getOperationDepth() < rhs->getOperationDepth();
40 }
41 
42 /// Returns the number of non-range elements within `values`.
43 static unsigned getNumNonRangeValues(ValueRange values) {
44  return llvm::count_if(values.getTypes(),
45  [](Type type) { return !isa<pdl::RangeType>(type); });
46 }
47 
/// Collects the predicates for an attribute value at `pos`. The attribute must
/// be non-null. When it is defined by a `pdl.attribute`, either its type value
/// is matched at the attribute's type position or, if there is no type value,
/// its constant value is required.
// NOTE(review): listing line 50 is missing (numbering jumps 49 -> 51). The
// body uses `inputs`, so the dropped line presumably declared
// `DenseMap<Value, Position *> &inputs,` — confirm.
48 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
49  Value val, PredicateBuilder &builder,
51  AttributePosition *pos) {
52  assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
53  predList.emplace_back(pos, builder.getIsNotNull());
54 
55  if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
56  // If the attribute has a type or value, add a constraint.
// Only one of the two is applied: a type value takes precedence over a
// constant attribute value (`if` / `else if`).
57  if (Value type = attr.getValueType())
58  getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
59  else if (Attribute value = attr.getValueAttr())
60  predList.emplace_back(pos, builder.getAttributeConstraint(value));
61  }
62 }
63 
64 /// Collect all of the predicates for the given operand position.
/// `pos` is an operand or operand-group position.
/// - If `val` comes from `pdl.operand`/`pdl.operands`, its optional type value
///   is matched at the operand's type position.
/// - If it comes from `pdl.result`/`pdl.results`, the defining operation is
///   located through `getOperandDefiningOp` and required to be non-null. Its
///   result(s) are required to equal the operand(s), and the parent
///   operation's own predicates are collected.
// NOTE(review): listing lines 67 and 73 are missing. The body uses `inputs`
// (presumably `DenseMap<Value, Position *> &inputs,` on line 67). The `.Case`
// chain at line 74 needs a dispatcher on line 73, presumably a `TypeSwitch`
// over `val.getDefiningOp()`. Confirm both.
65 static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
66  Value val, PredicateBuilder &builder,
68  Position *pos) {
69  Type valueType = val.getType();
70  bool isVariadic = isa<pdl::RangeType>(valueType);
71 
72  // If this is a typed operand, add a type constraint.
74  .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
75  // Prevent traversal into a null value if the operand has a proper
76  // index.
// `std::is_same<...>::value` is a compile-time constant. For `pdl.operand` the
// `||` short-circuits before the `cast`. For `pdl.operands` the check is only
// added when the group position's `getOperandGroupNumber()` is truthy.
77  if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
78  cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
79  predList.emplace_back(pos, builder.getIsNotNull());
80 
81  if (Value type = op.getValueType())
82  getTreePredicates(predList, type, builder, inputs,
83  builder.getType(pos));
84  })
85  .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
86  std::optional<unsigned> index = op.getIndex();
87 
88  // Prevent traversal into a null value if the result has a proper index.
89  if (index)
90  predList.emplace_back(pos, builder.getIsNotNull());
91 
92  // Get the parent operation of this operand.
93  OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
94  predList.emplace_back(parentPos, builder.getIsNotNull());
95 
96  // Ensure that the operands match the corresponding results of the
97  // parent operation.
// The `pdl.result` branch dereferences `index`; this relies on `pdl.result`
// always carrying an index.
98  Position *resultPos = nullptr;
99  if (std::is_same<pdl::ResultOp, decltype(op)>::value)
100  resultPos = builder.getResult(parentPos, *index);
101  else
102  resultPos = builder.getResultGroup(parentPos, index, isVariadic);
103  predList.emplace_back(resultPos, builder.getEqualTo(pos));
104 
105  // Collect the predicates of the parent operation.
// The cast presumably selects the generic `Position *` overload, which records
// the value in `inputs` before dispatching on the position kind.
106  getTreePredicates(predList, op.getParent(), builder, inputs,
107  (Position *)parentPos);
108  });
109 }
110 
/// Collects the predicates for the `pdl.operation` value `val` at an operation
/// position, in this order:
/// - non-null (skipped for the root);
/// - operation name;
/// - operand and result counts;
/// - recursively, the predicates of its attributes, operands and result types.
/// When `ignoreOperand` is set, the operand at that index is not visited.
// NOTE(review): listing lines 114 and 151 are missing. The body uses `inputs`
// and `pos`, so line 114 presumably declared both parameters. Line 151
// presumably opened the recursive `getTreePredicates(` call whose arguments
// follow. Confirm both.
111 static void
112 getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
113  PredicateBuilder &builder,
115  std::optional<unsigned> ignoreOperand = std::nullopt) {
116  assert(isa<pdl::OperationType>(val.getType()) && "expected operation");
117  pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
118  OperationPosition *opPos = cast<OperationPosition>(pos);
119 
120  // Ensure getDefiningOp returns a non-null operation.
121  if (!opPos->isRoot())
122  predList.emplace_back(pos, builder.getIsNotNull());
123 
124  // Check the operation name, if the pattern constrains it.
125  if (std::optional<StringRef> opName = op.getOpName())
126  predList.emplace_back(pos, builder.getOperationName(*opName));
127 
128  // Check that the operation has the proper number of operands. If there are
129  // any variable length operands, we check a minimum instead of an exact count.
// When every operand is a range (`minOperands == 0`), no count check is added.
130  OperandRange operands = op.getOperandValues();
131  unsigned minOperands = getNumNonRangeValues(operands);
132  if (minOperands != operands.size()) {
133  if (minOperands)
134  predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
135  } else {
136  predList.emplace_back(pos, builder.getOperandCount(minOperands));
137  }
138 
139  // Check that the operation has the proper number of results. If there are
140  // any variable length results, we check a minimum instead of an exact count.
// Same scheme as for operands: exact, lower bound, or nothing if all ranges.
141  OperandRange types = op.getTypeValues();
142  unsigned minResults = getNumNonRangeValues(types);
143  if (minResults == types.size())
144  predList.emplace_back(pos, builder.getResultCount(types.size()));
145  else if (minResults)
146  predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
147 
148  // Recurse into any attributes, operands, or results.
149  for (auto [attrName, attr] :
150  llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
152  predList, attr, builder, inputs,
153  builder.getAttribute(opPos, cast<StringAttr>(attrName).getValue()));
154  }
155 
156  // Process the operands and results of the operation. For all values up to
157  // the first variable length value, we use the concrete operand/result
158  // number. After that, we use the "group" given that we can't know the
159  // concrete indices until runtime. If there is only one variadic operand
160  // group, we treat it as all of the operands/results of the operation.
161  /// Operands.
162  if (operands.size() == 1 && isa<pdl::RangeType>(operands[0].getType())) {
163  // Ignore the operands if we are performing an upward traversal (in that
164  // case, they have already been visited).
165  if (opPos->isRoot() || opPos->isOperandDefiningOp())
166  getTreePredicates(predList, operands.front(), builder, inputs,
167  builder.getAllOperands(opPos));
168  } else {
169  bool foundVariableLength = false;
170  for (const auto &operandIt : llvm::enumerate(operands)) {
171  bool isVariadic = isa<pdl::RangeType>(operandIt.value().getType());
172  foundVariableLength |= isVariadic;
173 
174  // Ignore the specified operand, usually because this position was
175  // visited in an upward traversal via an iterative choice.
176  if (ignoreOperand && *ignoreOperand == operandIt.index())
177  continue;
178 
// This local `pos` shadows the parameter of the same name for the rest of the
// loop body only; the results loop below uses the parameter again.
179  Position *pos =
180  foundVariableLength
181  ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
182  : builder.getOperand(opPos, operandIt.index());
183  getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
184  }
185  }
186  /// Results.
187  if (types.size() == 1 && isa<pdl::RangeType>(types[0].getType())) {
188  getTreePredicates(predList, types.front(), builder, inputs,
189  builder.getType(builder.getAllResults(opPos)));
190  return;
191  }
192 
193  bool foundVariableLength = false;
194  for (auto [idx, typeValue] : llvm::enumerate(types)) {
195  bool isVariadic = isa<pdl::RangeType>(typeValue.getType());
196  foundVariableLength |= isVariadic;
197 
// Result positions are built from the parameter `pos`, while operand positions
// above use `opPos`. Both denote the same position (`opPos` is
// `cast<OperationPosition>(pos)`).
198  auto *resultPos = foundVariableLength
199  ? builder.getResultGroup(pos, idx, isVariadic)
200  : builder.getResult(pos, idx);
201  predList.emplace_back(resultPos, builder.getIsNotNull());
202  getTreePredicates(predList, typeValue, builder, inputs,
203  builder.getType(resultPos));
204  }
205 }
206 
/// Collects the predicates for a type (or type range) value at `pos`. When the
/// value is defined by a `pdl.type`/`pdl.types` carrying a constant, the
/// position must match that constant; other definitions add no predicate.
// NOTE(review): listing line 209 is missing (numbering jumps 208 -> 210). By
// analogy with the sibling overloads it presumably declared
// `DenseMap<Value, Position *> &inputs,` — confirm.
207 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
208  Value val, PredicateBuilder &builder,
210  TypePosition *pos) {
211  // Check for a constraint on a constant type.
212  if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
213  if (Attribute type = typeOp.getConstantTypeAttr())
214  predList.emplace_back(pos, builder.getTypeConstraint(type));
215  } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
216  if (Attribute typeAttr = typeOp.getConstantTypesAttr())
217  predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
218  }
219 }
220 
221 /// Collect the tree predicates anchored at the given value.
/// The first time `val` is reached, its position is recorded in `inputs` and
/// the position-specific collector is invoked. On later visits there is no
/// recursion. For values defined by `pdl.attribute`, `pdl.operand`,
/// `pdl.operands`, `pdl.operation` or `pdl.type`, an equality predicate ties
/// the two positions together: the deeper position is compared against the
/// shallower one.
// NOTE(review): listing lines 224 and 241 are missing. The body uses `inputs`
// (presumably `DenseMap<Value, Position *> &inputs,` on line 224). The `.Case`
// chain at line 242 needs a dispatcher on line 241, presumably a `TypeSwitch`
// over `pos`. Confirm both.
222 static void getTreePredicates(std::vector<PositionalPredicate> &predList,
223  Value val, PredicateBuilder &builder,
225  Position *pos) {
226  // Make sure this input value is accessible to the rewrite.
227  auto it = inputs.try_emplace(val, pos);
228  if (!it.second) {
229  // If this is an input value that has been visited in the tree, add a
230  // constraint to ensure that both instances refer to the same value.
231  if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
232  pdl::TypeOp>(val.getDefiningOp())) {
// `std::minmax` returns references to `pos` and to the existing map entry.
// They are consumed immediately below, before `inputs` is modified again.
233  auto minMaxPositions =
234  std::minmax(pos, it.first->second, comparePosDepth);
235  predList.emplace_back(minMaxPositions.second,
236  builder.getEqualTo(minMaxPositions.first));
237  }
238  return;
239  }
240 
242  .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
243  getTreePredicates(predList, val, builder, inputs, pos);
244  })
245  .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
246  getOperandTreePredicates(predList, val, builder, inputs, pos);
247  })
248  .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
249 }
250 
251 static void getAttributePredicates(pdl::AttributeOp op,
252  std::vector<PositionalPredicate> &predList,
253  PredicateBuilder &builder,
254  DenseMap<Value, Position *> &inputs) {
255  Position *&attrPos = inputs[op];
256  if (attrPos)
257  return;
258  Attribute value = op.getValueAttr();
259  assert(value && "expected non-tree `pdl.attribute` to contain a value");
260  attrPos = builder.getAttributeLiteral(value);
261 }
262 
/// Collects the predicate for a `pdl.apply_native_constraint`. The constraint
/// is anchored at the deepest of its argument positions. Each of its results
/// gets a `ConstraintPosition` so later uses can refer to it; if a result
/// already had a position, an equality predicate ties the two together.
// NOTE(review): listing line 277 is missing. `pred` is used below
// (`pred.first`, `predList.emplace_back(pos, pred)`), so the dropped line
// presumably declared it from a `builder` call whose arguments (name, argument
// positions, result types, negation flag) are on lines 278-279. Confirm.
263 static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
264  std::vector<PositionalPredicate> &predList,
265  PredicateBuilder &builder,
266  DenseMap<Value, Position *> &inputs) {
267  OperandRange arguments = op.getArgs();
268 
269  std::vector<Position *> allPositions;
270  allPositions.reserve(arguments.size());
271  for (Value arg : arguments)
272  allPositions.push_back(inputs.lookup(arg));
273 
274  // Push the constraint to the furthest position.
// NOTE(review): this dereferences the result of `max_element`, so it assumes
// `arguments` is non-empty. It also assumes every argument already has a
// position: `lookup` yields nullptr otherwise, and `comparePosDepth` would
// dereference it.
275  Position *pos = *llvm::max_element(allPositions, comparePosDepth);
276  ResultRange results = op.getResults();
278  op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
279  op.getIsNegated());
280 
281  // For each result register a position so it can be used later
282  for (auto [i, result] : llvm::enumerate(results)) {
283  ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
// This `pos` shadows the anchor `pos` above only inside the loop; the final
// `emplace_back` after the loop uses the anchor.
284  ConstraintPosition *pos = builder.getConstraintPosition(q, i);
285  auto [it, inserted] = inputs.try_emplace(result, pos);
286  // If this is an input value that has been visited in the tree, add a
287  // constraint to ensure that both instances refer to the same value.
288  if (!inserted) {
// Swap if needed so that `first` is the shallower position and `second` the
// deeper one; the equality check is placed on the deeper position.
289  Position *first = pos;
290  Position *second = it->second;
291  if (comparePosDepth(second, first))
292  std::tie(second, first) = std::make_pair(first, second);
293 
294  predList.emplace_back(second, builder.getEqualTo(first));
295  }
296  }
297  predList.emplace_back(pos, pred);
298 }
299 
300 static void getResultPredicates(pdl::ResultOp op,
301  std::vector<PositionalPredicate> &predList,
302  PredicateBuilder &builder,
303  DenseMap<Value, Position *> &inputs) {
304  Position *&resultPos = inputs[op];
305  if (resultPos)
306  return;
307 
308  // Ensure that the result isn't null.
309  auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
310  resultPos = builder.getResult(parentPos, op.getIndex());
311  predList.emplace_back(resultPos, builder.getIsNotNull());
312 }
313 
314 static void getResultPredicates(pdl::ResultsOp op,
315  std::vector<PositionalPredicate> &predList,
316  PredicateBuilder &builder,
317  DenseMap<Value, Position *> &inputs) {
318  Position *&resultPos = inputs[op];
319  if (resultPos)
320  return;
321 
322  // Ensure that the result isn't null if the result has an index.
323  auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
324  bool isVariadic = isa<pdl::RangeType>(op.getType());
325  std::optional<unsigned> index = op.getIndex();
326  resultPos = builder.getResultGroup(parentPos, index, isVariadic);
327  if (index)
328  predList.emplace_back(resultPos, builder.getIsNotNull());
329 }
330 
331 static void getTypePredicates(Value typeValue,
332  function_ref<Attribute()> typeAttrFn,
333  PredicateBuilder &builder,
334  DenseMap<Value, Position *> &inputs) {
335  Position *&typePos = inputs[typeValue];
336  if (typePos)
337  return;
338  Attribute typeAttr = typeAttrFn();
339  assert(typeAttr &&
340  "expected non-tree `pdl.type`/`pdl.types` to contain a value");
341  typePos = builder.getTypeLiteral(typeAttr);
342 }
343 
344 /// Collect all of the predicates that cannot be determined via walking the
345 /// tree.
/// Visits the pattern body in order and hands each of the following to its
/// helper: attribute literals, native constraints, `pdl.result(s)` ops, and
/// constant `pdl.type`/`pdl.types` ops. The result and type helpers return
/// early for values that already have a position.
// NOTE(review): listing lines 351, 362 and 367 are missing. Line 351
// presumably opened a `TypeSwitch` over `&op` for the `.Case` chain. Lines 362
// and 367 presumably opened the `getTypePredicates(` calls whose arguments
// follow. Confirm.
346 static void getNonTreePredicates(pdl::PatternOp pattern,
347  std::vector<PositionalPredicate> &predList,
348  PredicateBuilder &builder,
349  DenseMap<Value, Position *> &inputs) {
350  for (Operation &op : pattern.getBodyRegion().getOps()) {
352  .Case([&](pdl::AttributeOp attrOp) {
353  getAttributePredicates(attrOp, predList, builder, inputs);
354  })
355  .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
356  getConstraintPredicates(constraintOp, predList, builder, inputs);
357  })
358  .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
359  getResultPredicates(resultOp, predList, builder, inputs);
360  })
361  .Case([&](pdl::TypeOp typeOp) {
363  typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder,
364  inputs);
365  })
366  .Case([&](pdl::TypesOp typeOp) {
368  typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder,
369  inputs);
370  });
371  }
372 }
373 
374 namespace {
375 
376 /// An op accepting a value at an optional index.
377 struct OpIndex {
378  Value parent;
379  std::optional<unsigned> index;
380 };
381 
382 /// The parent and operand index of each operation for each root, stored
383 /// as a nested map [root][operation].
384 using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>;
385 
386 } // namespace
387 
388 /// Given a pattern, determines the set of roots present in this pattern.
389 /// These are the operations whose results are not consumed by other operations.
390 static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
391  // First, collect all the operations that are used as operands
392  // to other operations. These are not roots by default.
393  DenseSet<Value> used;
394  for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
395  for (Value operand : operationOp.getOperandValues())
396  TypeSwitch<Operation *>(operand.getDefiningOp())
397  .Case<pdl::ResultOp, pdl::ResultsOp>(
398  [&used](auto resultOp) { used.insert(resultOp.getParent()); });
399  }
400 
401  // Remove the specified root from the use set, so that we can
402  // always select it as a root, even if it is used by other operations.
403  if (Value root = pattern.getRewriter().getRoot())
404  used.erase(root);
405 
406  // Finally, collect all the unused operations.
407  SmallVector<Value> roots;
408  for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>())
409  if (!used.contains(operationOp))
410  roots.push_back(operationOp);
411 
412  return roots;
413 }
414 
415 /// Given a list of candidate roots, builds the cost graph for connecting them.
416 /// The graph is formed by traversing the DAG of operations starting from each
417 /// root and marking the depth of each connector value (operand). Then we join
418 /// the candidate roots based on the common connector values, taking the one
419 /// with the minimum depth. Along the way, we compute, for each candidate root,
420 /// a mapping from each operation (in the DAG underneath this root) to its
421 /// parent operation and the corresponding operand index.
// NOTE(review): listing line 422, which held the function name and its first
// parameters, is missing. The call site passes `(roots, graph, parentMaps)` and
// the body reads `roots` and writes `graph`, so the line presumably started
// `static void buildCostGraph(` and declared `roots` and
// `RootOrderingGraph &graph`. Confirm the exact types against the source.
423  ParentMaps &parentMaps) {
424 
425  // The entry of a queue. The entry consists of the following items:
426  // * the value in the DAG underneath the root;
427  // * the parent of the value;
428  // * the operand index of the value in its parent;
429  // * the depth of the visited value.
430  struct Entry {
431  Entry(Value value, Value parent, std::optional<unsigned> index,
432  unsigned depth)
433  : value(value), parent(parent), index(index), depth(depth) {}
434 
435  Value value;
436  Value parent;
437  std::optional<unsigned> index;
438  unsigned depth;
439  };
440 
441  // A root of a value and its depth (distance from root to the value).
442  struct RootDepth {
443  Value root;
444  unsigned depth = 0;
445  };
446 
447  // Map from candidate connector values to their roots and depths. Using a
448  // small vector with 1 entry because most values belong to a single root.
// MapVector iterates in insertion order, which keeps the edge ID assignment
// below deterministic.
449  llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
450 
451  // Perform a breadth-first traversal of the op DAG rooted at each root.
452  for (Value root : roots) {
453  // The queue of visited values. A value may be present multiple times in
454  // the queue, for multiple parents. We only accept the first occurrence,
455  // which is guaranteed to have the lowest depth.
// The root itself is seeded with a null parent, index 0 and depth 0.
456  std::queue<Entry> toVisit;
457  toVisit.emplace(root, Value(), 0, 0);
458 
459  // The map from value to its parent for the current root.
460  DenseMap<Value, OpIndex> &parentMap = parentMaps[root];
461 
462  while (!toVisit.empty()) {
463  Entry entry = toVisit.front();
464  toVisit.pop();
465  // Skip if already visited.
466  if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
467  continue;
468 
469  // Mark the root and depth of the value.
470  connectorsRootsDepths[entry.value].push_back({root, entry.depth});
471 
472  // Traverse the operands of an operation and result ops.
473  // We intentionally do not traverse attributes and types, because those
474  // are expensive to join on.
475  TypeSwitch<Operation *>(entry.value.getDefiningOp())
476  .Case<pdl::OperationOp>([&](auto operationOp) {
477  OperandRange operands = operationOp.getOperandValues();
478  // Special case when we pass all the operands in one range.
479  // For those, the index is empty.
480  if (operands.size() == 1 &&
481  isa<pdl::RangeType>(operands[0].getType())) {
482  toVisit.emplace(operands[0], entry.value, std::nullopt,
483  entry.depth + 1);
484  return;
485  }
486 
487  // Default case: visit all the operands.
488  for (const auto &p :
489  llvm::enumerate(operationOp.getOperandValues()))
490  toVisit.emplace(p.value(), entry.value, p.index(),
491  entry.depth + 1);
492  })
// Stepping from a result op to its parent operation does not increase the
// depth: operands are one level deeper, a result and its operation are not.
493  .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
494  toVisit.emplace(resultOp.getParent(), entry.value,
495  resultOp.getIndex(), entry.depth);
496  });
497  }
498  }
499 
500  // Now build the cost graph.
501  // This is simply a minimum over all depths for the target root.
// Each edge cost is a pair: `first` is the connector's depth below the target
// root; `second` is an ID assigned when the edge is first created (presumably
// a deterministic tie-breaker).
502  unsigned nextID = 0;
503  for (const auto &connectorRootsDepths : connectorsRootsDepths) {
504  Value value = connectorRootsDepths.first;
505  ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second;
506  // If there is only one root for this value, this will not trigger
507  // any edges in the cost graph (a perf optimization).
508  if (rootsDepths.size() == 1)
509  continue;
510 
511  for (const RootDepth &p : rootsDepths) {
512  for (const RootDepth &q : rootsDepths) {
513  if (&p == &q)
514  continue;
515  // Insert or retrieve the property of edge from p to q.
516  RootOrderingEntry &entry = graph[q.root][p.root];
517  if (!entry.connector /* new edge */ || entry.cost.first > q.depth) {
518  if (!entry.connector)
519  entry.cost.second = nextID++;
520  entry.cost.first = q.depth;
521  entry.connector = value;
522  }
523  }
524  }
525  }
526 
527  assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
528  "the pattern contains a candidate root disconnected from the others");
529 }
530 
531 /// Returns true if the operand at the given index needs to be queried using an
532 /// operand group, i.e., if it is variadic itself or follows a variadic operand.
533 static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
534  OperandRange operands = op.getOperandValues();
535  assert(index < operands.size() && "operand index out of range");
536  for (unsigned i = 0; i <= index; ++i)
537  if (isa<pdl::RangeType>(operands[i].getType()))
538  return true;
539  return false;
540 }
541 
542 /// Visit a node during upward traversal.
///
/// Moves `pos` (updated in place) one step up, from the current value to
/// `opIndex.parent`, and records the predicates needed to get there.
/// - For a `pdl.operation` parent: iterates over the users of the value at
///   `pos`, passing `rootID` to the for-each position, and requires the
///   selected operand(s) of the user to equal `pos`. It then records the
///   user's position in `valueToPosition` and collects its tree predicates,
///   passing `opIndex.index` as the operand to ignore.
/// - For a `pdl.result`/`pdl.results` parent: steps from the operation
///   position to the corresponding result position(s).
// NOTE(review): listing line 548 is missing. The `.Case` chain needs a
// dispatcher there, presumably a `TypeSwitch` over `value.getDefiningOp()`.
// Confirm.
543 static void visitUpward(std::vector<PositionalPredicate> &predList,
544  OpIndex opIndex, PredicateBuilder &builder,
545  DenseMap<Value, Position *> &valueToPosition,
546  Position *&pos, unsigned rootID) {
547  Value value = opIndex.parent;
549  .Case<pdl::OperationOp>([&](auto operationOp) {
550  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
551 
552  // Get users and iterate over them.
553  Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true);
554  Position *foreachPos = builder.getForEach(usersPos, rootID);
555  OperationPosition *opPos = builder.getPassthroughOp(foreachPos);
556 
557  // Compare the operand(s) of the user against the input value(s).
558  Position *operandPos;
559  if (!opIndex.index) {
560  // We are querying all the operands of the operation.
561  operandPos = builder.getAllOperands(opPos);
562  } else if (useOperandGroup(operationOp, *opIndex.index)) {
563  // We are querying an operand group.
564  Type type = operationOp.getOperandValues()[*opIndex.index].getType();
565  bool variadic = isa<pdl::RangeType>(type);
566  operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
567  } else {
568  // We are querying an individual operand.
569  operandPos = builder.getOperand(opPos, *opIndex.index);
570  }
571  predList.emplace_back(operandPos, builder.getEqualTo(pos));
572 
573  // Guard against duplicate upward visits. These are not possible,
574  // because if this value was already visited, it would have been
575  // cheaper to start the traversal at this value rather than at the
576  // `connector`, violating the optimality of our spanning tree.
577  bool inserted = valueToPosition.try_emplace(value, opPos).second;
578  (void)inserted;
579  assert(inserted && "duplicate upward visit");
580 
581  // Obtain the tree predicates at the current value.
582  getTreePredicates(predList, value, builder, valueToPosition, opPos,
583  opIndex.index);
584 
585  // Update the position
586  pos = opPos;
587  })
588  .Case<pdl::ResultOp>([&](auto resultOp) {
589  // Traverse up an individual result.
590  auto *opPos = dyn_cast<OperationPosition>(pos);
591  assert(opPos && "operations and results must be interleaved");
592  pos = builder.getResult(opPos, *opIndex.index);
593 
594  // Insert the result position in case we have not visited it yet.
595  valueToPosition.try_emplace(value, pos);
596  })
597  .Case<pdl::ResultsOp>([&](auto resultOp) {
598  // Traverse up a group of results.
599  auto *opPos = dyn_cast<OperationPosition>(pos);
600  assert(opPos && "operations and results must be interleaved");
601  bool isVariadic = isa<pdl::RangeType>(value.getType());
602  if (opIndex.index)
603  pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
604  else
605  pos = builder.getAllResults(opPos);
606 
607  // Insert the result position in case we have not visited it yet.
608  valueToPosition.try_emplace(value, pos);
609  });
610 }
611 
612 /// Given a pattern operation, build the set of matcher predicates necessary to
613 /// match this pattern.
614 static Value buildPredicateList(pdl::PatternOp pattern,
615  PredicateBuilder &builder,
616  std::vector<PositionalPredicate> &predList,
617  DenseMap<Value, Position *> &valueToPosition) {
618  SmallVector<Value> roots = detectRoots(pattern);
619 
620  // Build the root ordering graph and compute the parent maps.
621  RootOrderingGraph graph;
622  ParentMaps parentMaps;
623  buildCostGraph(roots, graph, parentMaps);
624  LLVM_DEBUG({
625  llvm::dbgs() << "Graph:\n";
626  for (auto &target : graph) {
627  llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first
628  << "\n";
629  for (auto &source : target.second) {
630  RootOrderingEntry &entry = source.second;
631  llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first
632  << ":" << entry.cost.second << " via "
633  << entry.connector.getLoc() << "\n";
634  }
635  }
636  });
637 
638  // Solve the optimal branching problem for each candidate root, or use the
639  // provided one.
640  Value bestRoot = pattern.getRewriter().getRoot();
641  OptimalBranching::EdgeList bestEdges;
642  if (!bestRoot) {
643  unsigned bestCost = 0;
644  LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
645  for (Value root : roots) {
646  OptimalBranching solver(graph, root);
647  unsigned cost = solver.solve();
648  LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n");
649  if (!bestRoot || bestCost > cost) {
650  bestCost = cost;
651  bestRoot = root;
652  bestEdges = solver.preOrderTraversal(roots);
653  }
654  }
655  } else {
656  OptimalBranching solver(graph, bestRoot);
657  solver.solve();
658  bestEdges = solver.preOrderTraversal(roots);
659  }
660 
661  // Print the best solution.
662  LLVM_DEBUG({
663  llvm::dbgs() << "Best tree:\n";
664  for (const std::pair<Value, Value> &edge : bestEdges) {
665  llvm::dbgs() << " * " << edge.first;
666  if (edge.second)
667  llvm::dbgs() << " <- " << edge.second;
668  llvm::dbgs() << "\n";
669  }
670  });
671 
672  LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
673  LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n");
674 
675  // The best root is the starting point for the traversal. Get the tree
676  // predicates for the DAG rooted at bestRoot.
677  getTreePredicates(predList, bestRoot, builder, valueToPosition,
678  builder.getRoot());
679 
680  // Traverse the selected optimal branching. For all edges in order, traverse
681  // up starting from the connector, until the candidate root is reached, and
682  // call getTreePredicates at every node along the way.
683  for (const auto &it : llvm::enumerate(bestEdges)) {
684  Value target = it.value().first;
685  Value source = it.value().second;
686 
687  // Check if we already visited the target root. This happens in two cases:
688  // 1) the initial root (bestRoot);
689  // 2) a root that is dominated by (contained in the subtree rooted at) an
690  // already visited root.
691  if (valueToPosition.count(target))
692  continue;
693 
694  // Determine the connector.
695  Value connector = graph[target][source].connector;
696  assert(connector && "invalid edge");
697  LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n");
698  DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
699  Position *pos = valueToPosition.lookup(connector);
700  assert(pos && "connector has not been traversed yet");
701 
702  // Traverse from the connector upwards towards the target root.
703  for (Value value = connector; value != target;) {
704  OpIndex opIndex = parentMap.lookup(value);
705  assert(opIndex.parent && "missing parent");
706  visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
707  value = opIndex.parent;
708  }
709  }
710 
711  getNonTreePredicates(pattern, predList, builder, valueToPosition);
712 
713  return bestRoot;
714 }
715 
716 //===----------------------------------------------------------------------===//
717 // Pattern Predicate Tree Merging
718 //===----------------------------------------------------------------------===//
719 
720 namespace {
721 
722 /// This class represents a specific predicate applied to a position, and
723 /// provides hashing and ordering operators. This class allows for computing a
724 /// frequency sum and ordering predicates based on a cost model.
725 struct OrderedPredicate {
// These converting constructors are not `explicit`. `OrderedPredicateDenseInfo`
// below relies on that when it returns `Base` keys as `OrderedPredicate`.
726  OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
727  : position(ip.first), question(ip.second) {}
728  OrderedPredicate(const PositionalPredicate &ip)
729  : position(ip.position), question(ip.question) {}
730 
731  /// The position this predicate is applied to.
732  Position *position;
733 
734  /// The question that is applied by this predicate onto the position.
735  Qualifier *question;
736 
737  /// The first and second order benefit sums.
738  /// The primary sum is the number of occurrences of this predicate among all
739  /// of the patterns.
740  unsigned primary = 0;
741  /// The secondary sum is a squared summation of the primary sum of all of the
742  /// predicates within each pattern that contains this predicate. This allows
743  /// for favoring predicates that are more commonly shared within a pattern, as
744  /// opposed to those shared across patterns.
745  unsigned secondary = 0;
746 
747  /// The tie breaking ID, used to preserve a deterministic (insertion) order
748  /// among all the predicates with the same priority, depth, and position /
749  /// predicate dependency.
750  unsigned id = 0;
751 
752  /// A map between a pattern operation and the answer to the predicate question
753  /// within that pattern.
754  DenseMap<Operation *, Qualifier *> patternToAnswer;
755 
756  /// Returns true if this predicate is ordered before `rhs`, based on the cost
757  /// model.
758  bool operator<(const OrderedPredicate &rhs) const {
759  // Sort by:
760  // * higher primary and secondary order sums
761  // * lower depth
762  // * lower position dependency
763  // * lower predicate dependency
764  // * lower tie breaking ID
// Both tuples are compared with `>`. The primary/secondary sums come from
// `this` on the left, so higher sums rank first. The depth, kind and id
// fields are swapped (`rhs` on the left), so lower values on `this` rank
// first.
765  auto *rhsPos = rhs.position;
766  return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
767  rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
768  std::make_tuple(rhs.primary, rhs.secondary,
769  position->getOperationDepth(), position->getKind(),
770  question->getKind(), id);
771  }
772 };
773 
774 /// A DenseMapInfo for OrderedPredicate based solely on the position and
775 /// question.
776 struct OrderedPredicateDenseInfo {
// NOTE(review): listing line 777 is missing. `Base` is used below but not
// declared in this listing. Since OrderedPredicate converts from
// `std::pair<Position *, Qualifier *>`, it was presumably an alias for a
// DenseMapInfo over that pair type. Confirm.
778 
779  static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
780  static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
// Equality and hashing both look only at (position, question); the sums, id
// and answer map do not participate.
781  static bool isEqual(const OrderedPredicate &lhs,
782  const OrderedPredicate &rhs) {
783  return lhs.position == rhs.position && lhs.question == rhs.question;
784  }
785  static unsigned getHashValue(const OrderedPredicate &p) {
786  return llvm::hash_combine(p.position, p.question);
787  }
788 };
789 
790 /// This class wraps a set of ordered predicates that are used within a specific
791 /// pattern operation.
792 struct OrderedPredicateList {
793  OrderedPredicateList(pdl::PatternOp pattern, Value root)
794  : pattern(pattern), root(root) {}
795 
796  pdl::PatternOp pattern;
797  Value root;
// Non-owning pointers to the predicates used by `pattern`.
798  DenseSet<OrderedPredicate *> predicates;
799 };
800 } // namespace
801 
802 /// Returns true if the given matcher refers to the same predicate as the given
803 /// ordered predicate. This means that the position and questions of the two
804 /// match.
805 static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
806  return node->getPosition() == predicate->position &&
807  node->getQuestion() == predicate->question;
808 }
809 
810 /// Get or insert a child matcher for the given parent switch node, given a
811 /// predicate and parent pattern.
812 std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
813  OrderedPredicate *predicate,
814  pdl::PatternOp pattern) {
815  assert(isSamePredicate(node, predicate) &&
816  "expected matcher to equal the given predicate");
817 
818  auto it = predicate->patternToAnswer.find(pattern);
819  assert(it != predicate->patternToAnswer.end() &&
820  "expected pattern to exist in predicate");
821  return node->getChildren().insert({it->second, nullptr}).first->second;
822 }
823 
824 /// Build the matcher CFG by "pushing" patterns through by sorted predicate
825 /// order. A pattern will traverse as far as possible using common predicates
826 /// and then either diverge from the CFG or reach the end of a branch and start
827 /// creating new nodes.
828 static void propagatePattern(std::unique_ptr<MatcherNode> &node,
829  OrderedPredicateList &list,
830  std::vector<OrderedPredicate *>::iterator current,
831  std::vector<OrderedPredicate *>::iterator end) {
832  if (current == end) {
833  // We've hit the end of a pattern, so create a successful result node.
834  node =
835  std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
836 
837  // If the pattern doesn't contain this predicate, ignore it.
838  } else if (!list.predicates.contains(*current)) {
839  propagatePattern(node, list, std::next(current), end);
840 
841  // If the current matcher node is invalid, create a new one for this
842  // position and continue propagation.
843  } else if (!node) {
844  // Create a new node at this position and continue
845  node = std::make_unique<SwitchNode>((*current)->position,
846  (*current)->question);
848  getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
849  list, std::next(current), end);
850 
851  // If the matcher has already been created, and it is for this predicate we
852  // continue propagation to the child.
853  } else if (isSamePredicate(node.get(), *current)) {
855  getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
856  list, std::next(current), end);
857 
858  // If the matcher doesn't match the current predicate, insert a branch as
859  // the common set of matchers has diverged.
860  } else {
861  propagatePattern(node->getFailureNode(), list, current, end);
862  }
863 }
864 
865 /// Fold any switch nodes nested under `node` to boolean nodes when possible.
866 /// `node` is updated in-place if it is a switch.
867 static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
868  if (!node)
869  return;
870 
871  if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
872  SwitchNode::ChildMapT &children = switchNode->getChildren();
873  for (auto &it : children)
874  foldSwitchToBool(it.second);
875 
876  // If the node only contains one child, collapse it into a boolean predicate
877  // node.
878  if (children.size() == 1) {
879  auto *childIt = children.begin();
880  node = std::make_unique<BoolNode>(
881  node->getPosition(), node->getQuestion(), childIt->first,
882  std::move(childIt->second), std::move(node->getFailureNode()));
883  }
884  } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
885  foldSwitchToBool(boolNode->getSuccessNode());
886  }
887 
888  foldSwitchToBool(node->getFailureNode());
889 }
890 
891 /// Insert an exit node at the end of the failure path of the `root`.
892 static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
893  while (*root)
894  root = &(*root)->getFailureNode();
895  *root = std::make_unique<ExitNode>();
896 }
897 
898 /// Sorts the range begin/end with the partial order given by cmp.
899 template <typename Iterator, typename Compare>
900 static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
901  while (begin != end) {
902  // Cannot compute sortBeforeOthers in the predicate of stable_partition
903  // because stable_partition will not keep the [begin, end) range intact
904  // while it runs.
906  for (auto i = begin; i != end; ++i) {
907  if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
908  sortBeforeOthers.insert(*i);
909  }
910 
911  auto const next = std::stable_partition(begin, end, [&](auto const &a) {
912  return sortBeforeOthers.contains(a);
913  });
914  assert(next != begin && "not a partial ordering");
915  begin = next;
916  }
917 }
918 
919 /// Returns true if 'b' depends on a result of 'a'.
920 static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
921  auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
922  if (!cqa)
923  return false;
924 
925  auto positionDependsOnA = [&](Position *p) {
926  auto *cp = dyn_cast<ConstraintPosition>(p);
927  return cp && cp->getQuestion() == cqa;
928  };
929 
930  if (auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) {
931  // Does any argument of b use a?
932  return llvm::any_of(cqb->getArgs(), positionDependsOnA);
933  }
934  if (auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) {
935  return positionDependsOnA(b->position) ||
936  positionDependsOnA(equalTo->getValue());
937  }
938  return positionDependsOnA(b->position);
939 }
940 
941 /// Given a module containing PDL pattern operations, generate a matcher tree
942 /// using the patterns within the given module and return the root matcher node.
943 std::unique_ptr<MatcherNode>
944 MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
945  DenseMap<Value, Position *> &valueToPosition) {
946  // The set of predicates contained within the pattern operations of the
947  // module.
948  struct PatternPredicates {
949  PatternPredicates(pdl::PatternOp pattern, Value root,
950  std::vector<PositionalPredicate> predicates)
951  : pattern(pattern), root(root), predicates(std::move(predicates)) {}
952 
953  /// A pattern.
954  pdl::PatternOp pattern;
955 
956  /// A root of the pattern chosen among the candidate roots in pdl.rewrite.
957  Value root;
958 
959  /// The extracted predicates for this pattern and root.
960  std::vector<PositionalPredicate> predicates;
961  };
962 
963  SmallVector<PatternPredicates, 16> patternsAndPredicates;
964  for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
965  std::vector<PositionalPredicate> predicateList;
966  Value root =
967  buildPredicateList(pattern, builder, predicateList, valueToPosition);
968  patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
969  }
970 
971  // Associate a pattern result with each unique predicate.
973  for (auto &patternAndPredList : patternsAndPredicates) {
974  for (auto &predicate : patternAndPredList.predicates) {
975  auto it = uniqued.insert(predicate);
976  it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
977  predicate.answer);
978  // Mark the insertion order (0-based indexing).
979  if (it.second)
980  it.first->id = uniqued.size() - 1;
981  }
982  }
983 
984  // Associate each pattern to a set of its ordered predicates for later lookup.
985  std::vector<OrderedPredicateList> lists;
986  lists.reserve(patternsAndPredicates.size());
987  for (auto &patternAndPredList : patternsAndPredicates) {
988  OrderedPredicateList list(patternAndPredList.pattern,
989  patternAndPredList.root);
990  for (auto &predicate : patternAndPredList.predicates) {
991  OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
992  list.predicates.insert(orderedPredicate);
993 
994  // Increment the primary sum for each reference to a particular predicate.
995  ++orderedPredicate->primary;
996  }
997  lists.push_back(std::move(list));
998  }
999 
1000  // For a particular pattern, get the total primary sum and add it to the
1001  // secondary sum of each predicate. Square the primary sums to emphasize
1002  // shared predicates within rather than across patterns.
1003  for (auto &list : lists) {
1004  unsigned total = 0;
1005  for (auto *predicate : list.predicates)
1006  total += predicate->primary * predicate->primary;
1007  for (auto *predicate : list.predicates)
1008  predicate->secondary += total;
1009  }
1010 
1011  // Sort the set of predicates now that the cost primary and secondary sums
1012  // have been computed.
1013  std::vector<OrderedPredicate *> ordered;
1014  ordered.reserve(uniqued.size());
1015  for (auto &ip : uniqued)
1016  ordered.push_back(&ip);
1017  llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
1018  return *lhs < *rhs;
1019  });
1020 
1021  // Mostly keep the now established order, but also ensure that
1022  // ConstraintQuestions come after the results they use.
1023  stableTopologicalSort(ordered.begin(), ordered.end(), dependsOn);
1024 
1025  // Build the matchers for each of the pattern predicate lists.
1026  std::unique_ptr<MatcherNode> root;
1027  for (OrderedPredicateList &list : lists)
1028  propagatePattern(root, list, ordered.begin(), ordered.end());
1029 
1030  // Collapse the graph and insert the exit node.
1031  foldSwitchToBool(root);
1032  insertExitNode(&root);
1033  return root;
1034 }
1035 
1036 //===----------------------------------------------------------------------===//
1037 // MatcherNode
1038 //===----------------------------------------------------------------------===//
1039 
1040 MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
1041  std::unique_ptr<MatcherNode> failureNode)
1042  : position(p), question(q), failureNode(std::move(failureNode)),
1043  matcherTypeID(matcherTypeID) {}
1044 
1045 //===----------------------------------------------------------------------===//
1046 // BoolNode
1047 //===----------------------------------------------------------------------===//
1048 
1049 BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
1050  std::unique_ptr<MatcherNode> successNode,
1051  std::unique_ptr<MatcherNode> failureNode)
1052  : MatcherNode(TypeID::get<BoolNode>(), position, question,
1053  std::move(failureNode)),
1054  answer(answer), successNode(std::move(successNode)) {}
1055 
1056 //===----------------------------------------------------------------------===//
1057 // SuccessNode
1058 //===----------------------------------------------------------------------===//
1059 
1060 SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
1061  std::unique_ptr<MatcherNode> failureNode)
1062  : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
1063  /*question=*/nullptr, std::move(failureNode)),
1064  pattern(pattern), root(root) {}
1065 
1066 //===----------------------------------------------------------------------===//
1067 // SwitchNode
1068 //===----------------------------------------------------------------------===//
1069 
1071  : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
static Value buildPredicateList(pdl::PatternOp pattern, PredicateBuilder &builder, std::vector< PositionalPredicate > &predList, DenseMap< Value, Position * > &valueToPosition)
Given a pattern operation, build the set of matcher predicates necessary to match this pattern.
static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b)
Returns true if 'b' depends on a result of 'a'.
static void getTypePredicates(Value typeValue, function_ref< Attribute()> typeAttrFn, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getNonTreePredicates(pdl::PatternOp pattern, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
Collect all of the predicates that cannot be determined via walking the tree.
static bool useOperandGroup(pdl::OperationOp op, unsigned index)
Returns true if the operand at the given index needs to be queried using an operand group,...
static void getTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)
Collect the tree predicates anchored at the given value.
static void getResultPredicates(pdl::ResultOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void getOperandTreePredicates(std::vector< PositionalPredicate > &predList, Value val, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs, Position *pos)
Collect all of the predicates for the given operand position.
std::unique_ptr< MatcherNode > & getOrCreateChild(SwitchNode *node, OrderedPredicate *predicate, pdl::PatternOp pattern)
Get or insert a child matcher for the given parent switch node, given a predicate and parent pattern.
static bool comparePosDepth(Position *lhs, Position *rhs)
Compares the depths of two positions.
static void visitUpward(std::vector< PositionalPredicate > &predList, OpIndex opIndex, PredicateBuilder &builder, DenseMap< Value, Position * > &valueToPosition, Position *&pos, unsigned rootID)
Visit a node during upward traversal.
static unsigned getNumNonRangeValues(ValueRange values)
Returns the number of non-range elements within values.
static SmallVector< Value > detectRoots(pdl::PatternOp pattern)
Given a pattern, determines the set of roots present in this pattern.
static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp)
Sorts the range begin/end with the partial order given by cmp.
static void insertExitNode(std::unique_ptr< MatcherNode > *root)
Insert an exit node at the end of the failure path of the root.
static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate)
Returns true if the given matcher refers to the same predicate as the given ordered predicate.
static void foldSwitchToBool(std::unique_ptr< MatcherNode > &node)
Fold any switch nodes nested under node to boolean nodes when possible.
static void buildCostGraph(ArrayRef< Value > roots, RootOrderingGraph &graph, ParentMaps &parentMaps)
Given a list of candidate roots, builds the cost graph for connecting them.
static void getAttributePredicates(pdl::AttributeOp op, std::vector< PositionalPredicate > &predList, PredicateBuilder &builder, DenseMap< Value, Position * > &inputs)
static void propagatePattern(std::unique_ptr< MatcherNode > &node, OrderedPredicateList &list, std::vector< OrderedPredicate * >::iterator current, std::vector< OrderedPredicate * >::iterator end)
Build the matcher CFG by "pushing" patterns through by sorted predicate order.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getType() const
Definition: ValueRange.cpp:30
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_range getResults()
Definition: Operation.h:410
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:242
type_range getTypes() const
Definition: ValueRange.cpp:35
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This class represents the base of a predicate matcher node.
Definition: PredicateTree.h:50
Position * getPosition() const
Returns the position on which the question predicate should be checked.
Definition: PredicateTree.h:63
Qualifier * getQuestion() const
Returns the predicate checked on this node.
Definition: PredicateTree.h:66
The optimal branching algorithm solver.
Definition: RootOrdering.h:90
unsigned solve()
Runs the Edmonds' algorithm for the current graph, returning the total cost of the minimum-weight spa...
std::vector< std::pair< Value, Value > > EdgeList
A list of edges (child, parent).
Definition: RootOrdering.h:93
EdgeList preOrderTraversal(ArrayRef< Value > nodes) const
Returns the computed edges as visited in the preorder traversal.
A position describes a value on the input IR on which a predicate may be applied, such as an operatio...
Definition: Predicate.h:144
unsigned getOperationDepth() const
Returns the depth of the first ancestor operation position.
Definition: Predicate.cpp:21
Predicates::Kind getKind() const
Returns the kind of this position.
Definition: Predicate.h:156
const KeyTy & getValue() const
Return the key value of this predicate.
Definition: Predicate.h:111
This class provides utilities for constructing predicates.
Definition: Predicate.h:596
ConstraintPosition * getConstraintPosition(ConstraintQuestion *q, unsigned index)
Definition: Predicate.h:622
Position * getTypeLiteral(Attribute attr)
Returns a type position for the given type value.
Definition: Predicate.h:674
Predicate getOperandCount(unsigned count)
Create a predicate comparing the number of operands of an operation to a known value.
Definition: Predicate.h:726
OperationPosition * getPassthroughOp(Position *p)
Returns the operation position equivalent to the given position.
Definition: Predicate.h:616
Predicate getIsNotNull()
Create a predicate comparing a value with null.
Definition: Predicate.h:720
Predicate getOperandCountAtLeast(unsigned count)
Definition: Predicate.h:730
Predicate getResultCountAtLeast(unsigned count)
Definition: Predicate.h:747
Position * getType(Position *p)
Returns a type position for the given entity.
Definition: Predicate.h:670
Position * getAttribute(OperationPosition *p, StringRef name)
Returns an attribute position for an attribute of the given operation.
Definition: Predicate.h:628
Position * getOperandGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of operands of the given operation.
Definition: Predicate.h:647
Position * getForEach(Position *p, unsigned id)
Definition: Predicate.h:637
Position * getOperand(OperationPosition *p, unsigned operand)
Returns an operand position for an operand of the given operation.
Definition: Predicate.h:642
Position * getResult(OperationPosition *p, unsigned result)
Returns a result position for a result of the given operation.
Definition: Predicate.h:656
Position * getRoot()
Returns the root operation position.
Definition: Predicate.h:606
Predicate getAttributeConstraint(Attribute attr)
Create a predicate comparing an attribute to a known value.
Definition: Predicate.h:696
Position * getResultGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of results of the given operation.
Definition: Predicate.h:661
Position * getAllResults(OperationPosition *p)
Definition: Predicate.h:665
UsersPosition * getUsers(Position *p, bool useRepresentative)
Returns the users of a position using the value at the given operand.
Definition: Predicate.h:679
Predicate getTypeConstraint(Attribute type)
Create a predicate comparing the type of an attribute or value to a known type.
Definition: Predicate.h:755
OperationPosition * getOperandDefiningOp(Position *p)
Returns the parent position defining the value held by the given operand.
Definition: Predicate.h:609
Predicate getResultCount(unsigned count)
Create a predicate comparing the number of results of an operation to a known value.
Definition: Predicate.h:743
std::pair< Qualifier *, Qualifier * > Predicate
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition: Predicate.h:693
Predicate getEqualTo(Position *pos)
Create a predicate checking if two values are equal.
Definition: Predicate.h:702
Position * getAllOperands(OperationPosition *p)
Definition: Predicate.h:651
Position * getAttributeLiteral(Attribute attr)
Returns an attribute position for the given attribute.
Definition: Predicate.h:633
Predicate getConstraint(StringRef name, ArrayRef< Position * > args, ArrayRef< Type > resultTypes, bool isNegated)
Create a predicate that applies a generic constraint.
Definition: Predicate.h:712
Predicate getOperationName(StringRef name)
Create a predicate comparing the name of an operation to a known value.
Definition: Predicate.h:736
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition: Predicate.h:410
Predicates::Kind getKind() const
Returns the kind of this qualifier.
Definition: Predicate.h:415
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
bool operator<(const Fraction &x, const Fraction &y)
Definition: Fraction.h:83
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A position describing an attribute of an operation.
Definition: Predicate.h:174
A BoolNode denotes a question with a boolean-like result.
BoolNode(Position *position, Qualifier *question, Qualifier *answer, std::unique_ptr< MatcherNode > successNode, std::unique_ptr< MatcherNode > failureNode=nullptr)
A position describing the result of a native constraint.
Definition: Predicate.h:294
Apply a parameterized constraint to multiple position values and possibly produce results.
Definition: Predicate.h:479
An operation position describes an operation node in the IR.
Definition: Predicate.h:253
bool isRoot() const
Returns if this operation position corresponds to the root.
Definition: Predicate.h:277
bool isOperandDefiningOp() const
Returns if this operation represents an operand defining op.
Definition: Predicate.cpp:51
A PositionalPredicate is a predicate that is associated with a specific positional value.
Definition: PredicateTree.h:30
The information associated with an edge in the cost graph.
Definition: RootOrdering.h:59
Value connector
The connector value in the intersection of the two subtrees rooted at the source and target root that...
Definition: RootOrdering.h:70
std::pair< unsigned, unsigned > cost
The depth of the connector Value w.r.t.
Definition: RootOrdering.h:65
A SuccessNode denotes that a given high level pattern has successfully been matched.
SuccessNode(pdl::PatternOp pattern, Value root, std::unique_ptr< MatcherNode > failureNode)
A SwitchNode denotes a question with multiple potential results.
llvm::MapVector< Qualifier *, std::unique_ptr< MatcherNode > > ChildMapT
Returns the children of this switch node.
SwitchNode(Position *position, Qualifier *question)
A position describing the result type of an entity, i.e.
Definition: Predicate.h:354