MLIR  20.0.0git
TopologicalSortUtils.cpp
Go to the documentation of this file.
1 //===- TopologicalSortUtils.cpp - Topological sort utilities --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/RegionGraphTraits.h"

#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
16 
17 using namespace mlir;
18 
19 /// Return `true` if the given operation is ready to be scheduled.
20 static bool isOpReady(Operation *op, DenseSet<Operation *> &unscheduledOps,
21  function_ref<bool(Value, Operation *)> isOperandReady) {
22  // An operation is ready to be scheduled if all its operands are ready. An
23  // operation is ready if:
24  const auto isReady = [&](Value value) {
25  // - the user-provided callback marks it as ready,
26  if (isOperandReady && isOperandReady(value, op))
27  return true;
28  Operation *parent = value.getDefiningOp();
29  // - it is a block argument,
30  if (!parent)
31  return true;
32  // - or it is not defined by an unscheduled op (and also not nested within
33  // an unscheduled op).
34  do {
35  // Stop traversal when op under examination is reached.
36  if (parent == op)
37  return true;
38  if (unscheduledOps.contains(parent))
39  return false;
40  } while ((parent = parent->getParentOp()));
41  // No unscheduled op found.
42  return true;
43  };
44 
45  // An operation is recursively ready to be scheduled of it and its nested
46  // operations are ready.
47  WalkResult readyToSchedule = op->walk([&](Operation *nestedOp) {
48  return llvm::all_of(nestedOp->getOperands(),
49  [&](Value operand) { return isReady(operand); })
52  });
53  return !readyToSchedule.wasInterrupted();
54 }
55 
58  function_ref<bool(Value, Operation *)> isOperandReady) {
59  if (ops.empty())
60  return true;
61 
62  // The set of operations that have not yet been scheduled.
63  DenseSet<Operation *> unscheduledOps;
64  // Mark all operations as unscheduled.
65  for (Operation &op : ops)
66  unscheduledOps.insert(&op);
67 
68  Block::iterator nextScheduledOp = ops.begin();
69  Block::iterator end = ops.end();
70 
71  bool allOpsScheduled = true;
72  while (!unscheduledOps.empty()) {
73  bool scheduledAtLeastOnce = false;
74 
75  // Loop over the ops that are not sorted yet, try to find the ones "ready",
76  // i.e. the ones for which there aren't any operand produced by an op in the
77  // set, and "schedule" it (move it before the `nextScheduledOp`).
78  for (Operation &op :
79  llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
80  if (!isOpReady(&op, unscheduledOps, isOperandReady))
81  continue;
82 
83  // Schedule the operation by moving it to the start.
84  unscheduledOps.erase(&op);
85  op.moveBefore(block, nextScheduledOp);
86  scheduledAtLeastOnce = true;
87  // Move the iterator forward if we schedule the operation at the front.
88  if (&op == &*nextScheduledOp)
89  ++nextScheduledOp;
90  }
91  // If no operations were scheduled, give up and advance the iterator.
92  if (!scheduledAtLeastOnce) {
93  allOpsScheduled = false;
94  unscheduledOps.erase(&*nextScheduledOp);
95  ++nextScheduledOp;
96  }
97  }
98 
99  return allOpsScheduled;
100 }
101 
103  Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
104  if (block->empty())
105  return true;
106  if (block->back().hasTrait<OpTrait::IsTerminator>())
107  return sortTopologically(block, block->without_terminator(),
108  isOperandReady);
109  return sortTopologically(block, *block, isOperandReady);
110 }
111 
114  function_ref<bool(Value, Operation *)> isOperandReady) {
115  if (ops.empty())
116  return true;
117 
118  // The set of operations that have not yet been scheduled.
119  DenseSet<Operation *> unscheduledOps;
120 
121  // Mark all operations as unscheduled.
122  for (Operation *op : ops)
123  unscheduledOps.insert(op);
124 
125  unsigned nextScheduledOp = 0;
126 
127  bool allOpsScheduled = true;
128  while (!unscheduledOps.empty()) {
129  bool scheduledAtLeastOnce = false;
130 
131  // Loop over the ops that are not sorted yet, try to find the ones "ready",
132  // i.e. the ones for which there aren't any operand produced by an op in the
133  // set, and "schedule" it (swap it with the op at `nextScheduledOp`).
134  for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {
135  if (!isOpReady(ops[i], unscheduledOps, isOperandReady))
136  continue;
137 
138  // Schedule the operation by moving it to the start.
139  unscheduledOps.erase(ops[i]);
140  std::swap(ops[i], ops[nextScheduledOp]);
141  scheduledAtLeastOnce = true;
142  ++nextScheduledOp;
143  }
144 
145  // If no operations were scheduled, just schedule the first op and continue.
146  if (!scheduledAtLeastOnce) {
147  allOpsScheduled = false;
148  unscheduledOps.erase(ops[nextScheduledOp++]);
149  }
150  }
151 
152  return allOpsScheduled;
153 }
154 
156  // For each block that has not been visited yet (i.e. that has no
157  // predecessors), add it to the list as well as its successors.
158  SetVector<Block *> blocks;
159  for (Block &b : region) {
160  if (blocks.count(&b) == 0) {
161  llvm::ReversePostOrderTraversal<Block *> traversal(&b);
162  blocks.insert(traversal.begin(), traversal.end());
163  }
164  }
165  assert(blocks.size() == region.getBlocks().size() &&
166  "some blocks are not sorted");
167 
168  return blocks;
169 }
170 
171 namespace {
172 class TopoSortHelper {
173 public:
174  explicit TopoSortHelper(const SetVector<Operation *> &toSort)
175  : toSort(toSort) {}
176 
177  /// Executes the topological sort of the operations this instance was
178  /// constructed with. This function will destroy the internal state of the
179  /// instance.
180  SetVector<Operation *> sort() {
181  if (toSort.size() <= 1) {
182  // Note: Creates a copy on purpose.
183  return toSort;
184  }
185 
186  // First, find the root region to start the traversal through the IR. This
187  // additionally enriches the internal caches with all relevant ancestor
188  // regions and blocks.
189  Region *rootRegion = findCommonAncestorRegion();
190  assert(rootRegion && "expected all ops to have a common ancestor");
191 
192  // Sort all elements in `toSort` by traversing the IR in the appropriate
193  // order.
194  SetVector<Operation *> result = topoSortRegion(*rootRegion);
195  assert(result.size() == toSort.size() &&
196  "expected all operations to be present in the result");
197  return result;
198  }
199 
200 private:
201  /// Computes the closest common ancestor region of all operations in `toSort`.
202  Region *findCommonAncestorRegion() {
203  // Map to count the number of times a region was encountered.
204  DenseMap<Region *, size_t> regionCounts;
205  size_t expectedCount = toSort.size();
206 
207  // Walk the region tree for each operation towards the root and add to the
208  // region count.
209  Region *res = nullptr;
210  for (Operation *op : toSort) {
211  Region *current = op->getParentRegion();
212  // Store the block as an ancestor block.
213  ancestorBlocks.insert(op->getBlock());
214  while (current) {
215  // Insert or update the count and compare it.
216  if (++regionCounts[current] == expectedCount) {
217  res = current;
218  break;
219  }
220  ancestorBlocks.insert(current->getParentOp()->getBlock());
221  current = current->getParentRegion();
222  }
223  }
224  auto firstRange = llvm::make_first_range(regionCounts);
225  ancestorRegions.insert(firstRange.begin(), firstRange.end());
226  return res;
227  }
228 
229  /// Performs the dominance respecting IR walk to collect the topological order
230  /// of the operation to sort.
231  SetVector<Operation *> topoSortRegion(Region &rootRegion) {
233 
234  SetVector<Operation *> result;
235  // Stack that stores the different IR constructs to traverse.
236  SmallVector<StackT> stack;
237  stack.push_back(&rootRegion);
238 
239  // Traverse the IR in a dominance respecting pre-order walk.
240  while (!stack.empty()) {
241  StackT current = stack.pop_back_val();
242  if (auto *region = dyn_cast<Region *>(current)) {
243  // A region's blocks need to be traversed in dominance order.
244  SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(*region);
245  for (Block *block : llvm::reverse(sortedBlocks)) {
246  // Only add blocks to the stack that are ancestors of the operations
247  // to sort.
248  if (ancestorBlocks.contains(block))
249  stack.push_back(block);
250  }
251  continue;
252  }
253 
254  if (auto *block = dyn_cast<Block *>(current)) {
255  // Add all of the blocks operations to the stack.
256  for (Operation &op : llvm::reverse(*block))
257  stack.push_back(&op);
258  continue;
259  }
260 
261  auto *op = cast<Operation *>(current);
262  if (toSort.contains(op))
263  result.insert(op);
264 
265  // Add all the subregions that are ancestors of the operations to sort.
266  for (Region &subRegion : op->getRegions())
267  if (ancestorRegions.contains(&subRegion))
268  stack.push_back(&subRegion);
269  }
270  return result;
271  }
272 
273  /// Operations to sort.
274  const SetVector<Operation *> &toSort;
275  /// Set containing all the ancestor regions of the operations to sort.
276  DenseSet<Region *> ancestorRegions;
277  /// Set containing all the ancestor blocks of the operations to sort.
278  DenseSet<Block *> ancestorBlocks;
279 };
280 } // namespace
281 
284  return TopoSortHelper(toSort).sort();
285 }
static bool isOpReady(Operation *op, DenseSet< Operation * > &unscheduledOps, function_ref< bool(Value, Operation *)> isOperandReady)
Return true if the given operation is ready to be scheduled.
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
bool empty()
Definition: Block.h:148
Operation & back()
Definition: Block.h:152
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block, excluding the terminator operation at the end.
Definition: Block.h:209
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp, which may be in the same or another block.
Definition: Operation.cpp:555
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Region * getParentRegion()
Return the region containing this region, or nullptr if the region is attached to a top-level operation.
Definition: Region.cpp:45
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
BlockListType & getBlocks()
Definition: Region.h:45
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult advance()
Definition: Visitors.h:51
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
static WalkResult interrupt()
Definition: Visitors.h:50
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
bool computeTopologicalSorting(MutableArrayRef< Operation * > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Compute a topological ordering of the given ops.
bool sortTopologically(Block *block, iterator_range< Block::iterator > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Given a block, sort a range operations in said block in topological order.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.