MLIR  21.0.0git
TopologicalSortUtils.cpp
Go to the documentation of this file.
1 //===- TopologicalSortUtils.cpp - Topological sort utilities --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Block.h"
11 #include "mlir/IR/OpDefinition.h"
13 
14 #include "llvm/ADT/PostOrderIterator.h"
15 #include "llvm/ADT/SetVector.h"
16 
17 using namespace mlir;
18 
19 /// Return `true` if the given operation is ready to be scheduled.
20 static bool isOpReady(Operation *op, DenseSet<Operation *> &unscheduledOps,
21  function_ref<bool(Value, Operation *)> isOperandReady) {
22  // An operation is ready to be scheduled if all its operands are ready. An
23  // operation is ready if:
24  const auto isReady = [&](Value value) {
25  // - the user-provided callback marks it as ready,
26  if (isOperandReady && isOperandReady(value, op))
27  return true;
28  Operation *parent = value.getDefiningOp();
29  // - it is a block argument,
30  if (!parent)
31  return true;
32  // - or it is not defined by an unscheduled op (and also not nested within
33  // an unscheduled op).
34  do {
35  // Stop traversal when op under examination is reached.
36  if (parent == op)
37  return true;
38  if (unscheduledOps.contains(parent))
39  return false;
40  } while ((parent = parent->getParentOp()));
41  // No unscheduled op found.
42  return true;
43  };
44 
45  // An operation is recursively ready to be scheduled of it and its nested
46  // operations are ready.
47  WalkResult readyToSchedule = op->walk([&](Operation *nestedOp) {
48  return llvm::all_of(nestedOp->getOperands(),
49  [&](Value operand) { return isReady(operand); })
52  });
53  return !readyToSchedule.wasInterrupted();
54 }
55 
58  function_ref<bool(Value, Operation *)> isOperandReady) {
59  if (ops.empty())
60  return true;
61 
62  // The set of operations that have not yet been scheduled.
63  DenseSet<Operation *> unscheduledOps;
64  // Mark all operations as unscheduled.
65  for (Operation &op : ops)
66  unscheduledOps.insert(&op);
67 
68  Block::iterator nextScheduledOp = ops.begin();
69  Block::iterator end = ops.end();
70 
71  bool allOpsScheduled = true;
72  while (!unscheduledOps.empty()) {
73  bool scheduledAtLeastOnce = false;
74 
75  // Loop over the ops that are not sorted yet, try to find the ones "ready",
76  // i.e. the ones for which there aren't any operand produced by an op in the
77  // set, and "schedule" it (move it before the `nextScheduledOp`).
78  for (Operation &op :
79  llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
80  if (!isOpReady(&op, unscheduledOps, isOperandReady))
81  continue;
82 
83  // Schedule the operation by moving it to the start.
84  unscheduledOps.erase(&op);
85  op.moveBefore(block, nextScheduledOp);
86  scheduledAtLeastOnce = true;
87  // Move the iterator forward if we schedule the operation at the front.
88  if (&op == &*nextScheduledOp)
89  ++nextScheduledOp;
90  }
91  // If no operations were scheduled, give up and advance the iterator.
92  if (!scheduledAtLeastOnce) {
93  allOpsScheduled = false;
94  unscheduledOps.erase(&*nextScheduledOp);
95  ++nextScheduledOp;
96  }
97  }
98 
99  return allOpsScheduled;
100 }
101 
103  Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
104  if (block->empty())
105  return true;
106  if (block->back().hasTrait<OpTrait::IsTerminator>())
107  return sortTopologically(block, block->without_terminator(),
108  isOperandReady);
109  return sortTopologically(block, *block, isOperandReady);
110 }
111 
114  function_ref<bool(Value, Operation *)> isOperandReady) {
115  if (ops.empty())
116  return true;
117 
118  // The set of operations that have not yet been scheduled.
119  // Mark all operations as unscheduled.
120  DenseSet<Operation *> unscheduledOps(llvm::from_range, ops);
121 
122  unsigned nextScheduledOp = 0;
123 
124  bool allOpsScheduled = true;
125  while (!unscheduledOps.empty()) {
126  bool scheduledAtLeastOnce = false;
127 
128  // Loop over the ops that are not sorted yet, try to find the ones "ready",
129  // i.e. the ones for which there aren't any operand produced by an op in the
130  // set, and "schedule" it (swap it with the op at `nextScheduledOp`).
131  for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {
132  if (!isOpReady(ops[i], unscheduledOps, isOperandReady))
133  continue;
134 
135  // Schedule the operation by moving it to the start.
136  unscheduledOps.erase(ops[i]);
137  std::swap(ops[i], ops[nextScheduledOp]);
138  scheduledAtLeastOnce = true;
139  ++nextScheduledOp;
140  }
141 
142  // If no operations were scheduled, just schedule the first op and continue.
143  if (!scheduledAtLeastOnce) {
144  allOpsScheduled = false;
145  unscheduledOps.erase(ops[nextScheduledOp++]);
146  }
147  }
148 
149  return allOpsScheduled;
150 }
151 
153  // For each block that has not been visited yet (i.e. that has no
154  // predecessors), add it to the list as well as its successors.
155  SetVector<Block *> blocks;
156  for (Block &b : region) {
157  if (blocks.count(&b) == 0) {
158  llvm::ReversePostOrderTraversal<Block *> traversal(&b);
159  blocks.insert_range(traversal);
160  }
161  }
162  assert(blocks.size() == region.getBlocks().size() &&
163  "some blocks are not sorted");
164 
165  return blocks;
166 }
167 
168 namespace {
169 class TopoSortHelper {
170 public:
171  explicit TopoSortHelper(const SetVector<Operation *> &toSort)
172  : toSort(toSort) {}
173 
174  /// Executes the topological sort of the operations this instance was
175  /// constructed with. This function will destroy the internal state of the
176  /// instance.
177  SetVector<Operation *> sort() {
178  if (toSort.size() <= 1) {
179  // Note: Creates a copy on purpose.
180  return toSort;
181  }
182 
183  // First, find the root region to start the traversal through the IR. This
184  // additionally enriches the internal caches with all relevant ancestor
185  // regions and blocks.
186  Region *rootRegion = findCommonAncestorRegion();
187  assert(rootRegion && "expected all ops to have a common ancestor");
188 
189  // Sort all elements in `toSort` by traversing the IR in the appropriate
190  // order.
191  SetVector<Operation *> result = topoSortRegion(*rootRegion);
192  assert(result.size() == toSort.size() &&
193  "expected all operations to be present in the result");
194  return result;
195  }
196 
197 private:
198  /// Computes the closest common ancestor region of all operations in `toSort`.
199  Region *findCommonAncestorRegion() {
200  // Map to count the number of times a region was encountered.
201  DenseMap<Region *, size_t> regionCounts;
202  size_t expectedCount = toSort.size();
203 
204  // Walk the region tree for each operation towards the root and add to the
205  // region count.
206  Region *res = nullptr;
207  for (Operation *op : toSort) {
208  Region *current = op->getParentRegion();
209  // Store the block as an ancestor block.
210  ancestorBlocks.insert(op->getBlock());
211  while (current) {
212  // Insert or update the count and compare it.
213  if (++regionCounts[current] == expectedCount) {
214  res = current;
215  break;
216  }
217  ancestorBlocks.insert(current->getParentOp()->getBlock());
218  current = current->getParentRegion();
219  }
220  }
221  auto firstRange = llvm::make_first_range(regionCounts);
222  ancestorRegions.insert_range(firstRange);
223  return res;
224  }
225 
226  /// Performs the dominance respecting IR walk to collect the topological order
227  /// of the operation to sort.
228  SetVector<Operation *> topoSortRegion(Region &rootRegion) {
230 
231  SetVector<Operation *> result;
232  // Stack that stores the different IR constructs to traverse.
233  SmallVector<StackT> stack;
234  stack.push_back(&rootRegion);
235 
236  // Traverse the IR in a dominance respecting pre-order walk.
237  while (!stack.empty()) {
238  StackT current = stack.pop_back_val();
239  if (auto *region = dyn_cast<Region *>(current)) {
240  // A region's blocks need to be traversed in dominance order.
241  SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(*region);
242  for (Block *block : llvm::reverse(sortedBlocks)) {
243  // Only add blocks to the stack that are ancestors of the operations
244  // to sort.
245  if (ancestorBlocks.contains(block))
246  stack.push_back(block);
247  }
248  continue;
249  }
250 
251  if (auto *block = dyn_cast<Block *>(current)) {
252  // Add all of the blocks operations to the stack.
253  for (Operation &op : llvm::reverse(*block))
254  stack.push_back(&op);
255  continue;
256  }
257 
258  auto *op = cast<Operation *>(current);
259  if (toSort.contains(op))
260  result.insert(op);
261 
262  // Add all the subregions that are ancestors of the operations to sort.
263  for (Region &subRegion : op->getRegions())
264  if (ancestorRegions.contains(&subRegion))
265  stack.push_back(&subRegion);
266  }
267  return result;
268  }
269 
270  /// Operations to sort.
271  const SetVector<Operation *> &toSort;
272  /// Set containing all the ancestor regions of the operations to sort.
273  DenseSet<Region *> ancestorRegions;
274  /// Set containing all the ancestor blocks of the operations to sort.
275  DenseSet<Block *> ancestorBlocks;
276 };
277 } // namespace
278 
281  return TopoSortHelper(toSort).sort();
282 }
static bool isOpReady(Operation *op, DenseSet< Operation * > &unscheduledOps, function_ref< bool(Value, Operation *)> isOperandReady)
Return true if the given operation is ready to be scheduled.
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
bool empty()
Definition: Block.h:148
Operation & back()
Definition: Block.h:152
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block excluding the terminator operation at the end.
Definition: Block.h:209
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:768
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp, which may be in the same or another block in the same function.
Definition: Operation.cpp:555
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operation.
Definition: Region.cpp:45
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
BlockListType & getBlocks()
Definition: Region.h:45
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult advance()
Definition: Visitors.h:51
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
static WalkResult interrupt()
Definition: Visitors.h:50
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
bool computeTopologicalSorting(MutableArrayRef< Operation * > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Compute a topological ordering of the given ops.
bool sortTopologically(Block *block, iterator_range< Block::iterator > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Given a block, sort a range operations in said block in topological order.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.