MLIR  22.0.0git
TopologicalSortUtils.cpp
Go to the documentation of this file.
1 //===- TopologicalSortUtils.cpp - Topological sort utilities --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/RegionKindInterface.h"

#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
16 
17 using namespace mlir;
18 
19 /// Return `true` if the given operation is ready to be scheduled.
20 static bool isOpReady(Operation *op, DenseSet<Operation *> &unscheduledOps,
21  function_ref<bool(Value, Operation *)> isOperandReady) {
22  // An operation is ready to be scheduled if all its operands are ready. An
23  // operation is ready if:
24  const auto isReady = [&](Value value) {
25  // - the user-provided callback marks it as ready,
26  if (isOperandReady && isOperandReady(value, op))
27  return true;
28  Operation *parent = value.getDefiningOp();
29  // - it is a block argument,
30  if (!parent)
31  return true;
32  // - or it is not defined by an unscheduled op (and also not nested within
33  // an unscheduled op).
34  do {
35  // Stop traversal when op under examination is reached.
36  if (parent == op)
37  return true;
38  if (unscheduledOps.contains(parent))
39  return false;
40  } while ((parent = parent->getParentOp()));
41  // No unscheduled op found.
42  return true;
43  };
44 
45  // An operation is recursively ready to be scheduled of it and its nested
46  // operations are ready.
47  WalkResult readyToSchedule = op->walk([&](Operation *nestedOp) {
48  return llvm::all_of(nestedOp->getOperands(),
49  [&](Value operand) { return isReady(operand); })
52  });
53  return !readyToSchedule.wasInterrupted();
54 }
55 
58  function_ref<bool(Value, Operation *)> isOperandReady) {
59  if (ops.empty())
60  return true;
61 
62  // The set of operations that have not yet been scheduled.
63  DenseSet<Operation *> unscheduledOps;
64  // Mark all operations as unscheduled.
65  for (Operation &op : ops)
66  unscheduledOps.insert(&op);
67 
68  Block::iterator nextScheduledOp = ops.begin();
69  Block::iterator end = ops.end();
70 
71  bool allOpsScheduled = true;
72  while (!unscheduledOps.empty()) {
73  bool scheduledAtLeastOnce = false;
74 
75  // Loop over the ops that are not sorted yet, try to find the ones "ready",
76  // i.e. the ones for which there aren't any operand produced by an op in the
77  // set, and "schedule" it (move it before the `nextScheduledOp`).
78  for (Operation &op :
79  llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
80  if (!isOpReady(&op, unscheduledOps, isOperandReady))
81  continue;
82 
83  // Schedule the operation by moving it to the start.
84  unscheduledOps.erase(&op);
85  op.moveBefore(block, nextScheduledOp);
86  scheduledAtLeastOnce = true;
87  // Move the iterator forward if we schedule the operation at the front.
88  if (&op == &*nextScheduledOp)
89  ++nextScheduledOp;
90  }
91  // If no operations were scheduled, give up and advance the iterator.
92  if (!scheduledAtLeastOnce) {
93  allOpsScheduled = false;
94  unscheduledOps.erase(&*nextScheduledOp);
95  ++nextScheduledOp;
96  }
97  }
98 
99  return allOpsScheduled;
100 }
101 
103  Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
104  return sortTopologically(block, block->without_terminator(), isOperandReady);
105 }
106 
109  function_ref<bool(Value, Operation *)> isOperandReady) {
110  if (ops.empty())
111  return true;
112 
113  // The set of operations that have not yet been scheduled.
114  // Mark all operations as unscheduled.
115  DenseSet<Operation *> unscheduledOps(llvm::from_range, ops);
116 
117  unsigned nextScheduledOp = 0;
118 
119  bool allOpsScheduled = true;
120  while (!unscheduledOps.empty()) {
121  bool scheduledAtLeastOnce = false;
122 
123  // Loop over the ops that are not sorted yet, try to find the ones "ready",
124  // i.e. the ones for which there aren't any operand produced by an op in the
125  // set, and "schedule" it (swap it with the op at `nextScheduledOp`).
126  for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {
127  if (!isOpReady(ops[i], unscheduledOps, isOperandReady))
128  continue;
129 
130  // Schedule the operation by moving it to the start.
131  unscheduledOps.erase(ops[i]);
132  std::swap(ops[i], ops[nextScheduledOp]);
133  scheduledAtLeastOnce = true;
134  ++nextScheduledOp;
135  }
136 
137  // If no operations were scheduled, just schedule the first op and continue.
138  if (!scheduledAtLeastOnce) {
139  allOpsScheduled = false;
140  unscheduledOps.erase(ops[nextScheduledOp++]);
141  }
142  }
143 
144  return allOpsScheduled;
145 }
146 
148  // For each block that has not been visited yet (i.e. that has no
149  // predecessors), add it to the list as well as its successors.
150  SetVector<Block *> blocks;
151  for (Block &b : region) {
152  if (blocks.count(&b) == 0) {
153  llvm::ReversePostOrderTraversal<Block *> traversal(&b);
154  blocks.insert_range(traversal);
155  }
156  }
157  assert(blocks.size() == region.getBlocks().size() &&
158  "some blocks are not sorted");
159 
160  return blocks;
161 }
162 
163 namespace {
164 class TopoSortHelper {
165 public:
166  explicit TopoSortHelper(const SetVector<Operation *> &toSort)
167  : toSort(toSort) {}
168 
169  /// Executes the topological sort of the operations this instance was
170  /// constructed with. This function will destroy the internal state of the
171  /// instance.
172  SetVector<Operation *> sort() {
173  if (toSort.size() <= 1) {
174  // Note: Creates a copy on purpose.
175  return toSort;
176  }
177 
178  // First, find the root region to start the traversal through the IR. This
179  // additionally enriches the internal caches with all relevant ancestor
180  // regions and blocks.
181  Region *rootRegion = findCommonAncestorRegion();
182  assert(rootRegion && "expected all ops to have a common ancestor");
183 
184  // Sort all elements in `toSort` by traversing the IR in the appropriate
185  // order.
186  SetVector<Operation *> result = topoSortRegion(*rootRegion);
187  assert(result.size() == toSort.size() &&
188  "expected all operations to be present in the result");
189  return result;
190  }
191 
192 private:
193  /// Computes the closest common ancestor region of all operations in `toSort`.
194  Region *findCommonAncestorRegion() {
195  // Map to count the number of times a region was encountered.
196  DenseMap<Region *, size_t> regionCounts;
197  size_t expectedCount = toSort.size();
198 
199  // Walk the region tree for each operation towards the root and add to the
200  // region count.
201  Region *res = nullptr;
202  for (Operation *op : toSort) {
203  Region *current = op->getParentRegion();
204  // Store the block as an ancestor block.
205  ancestorBlocks.insert(op->getBlock());
206  while (current) {
207  // Insert or update the count and compare it.
208  if (++regionCounts[current] == expectedCount) {
209  res = current;
210  break;
211  }
212  ancestorBlocks.insert(current->getParentOp()->getBlock());
213  current = current->getParentRegion();
214  }
215  }
216  auto firstRange = llvm::make_first_range(regionCounts);
217  ancestorRegions.insert_range(firstRange);
218  return res;
219  }
220 
221  /// Performs the dominance respecting IR walk to collect the topological order
222  /// of the operation to sort.
223  SetVector<Operation *> topoSortRegion(Region &rootRegion) {
225 
226  SetVector<Operation *> result;
227  // Stack that stores the different IR constructs to traverse.
228  SmallVector<StackT> stack;
229  stack.push_back(&rootRegion);
230 
231  // Traverse the IR in a dominance respecting pre-order walk.
232  while (!stack.empty()) {
233  StackT current = stack.pop_back_val();
234  if (auto *region = dyn_cast<Region *>(current)) {
235  // A region's blocks need to be traversed in dominance order.
236  SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(*region);
237  for (Block *block : llvm::reverse(sortedBlocks)) {
238  // Only add blocks to the stack that are ancestors of the operations
239  // to sort.
240  if (ancestorBlocks.contains(block))
241  stack.push_back(block);
242  }
243  continue;
244  }
245 
246  if (auto *block = dyn_cast<Block *>(current)) {
247  // Add all of the blocks operations to the stack.
248  for (Operation &op : llvm::reverse(*block))
249  stack.push_back(&op);
250  continue;
251  }
252 
253  auto *op = cast<Operation *>(current);
254  if (toSort.contains(op))
255  result.insert(op);
256 
257  // Add all the subregions that are ancestors of the operations to sort.
258  for (Region &subRegion : op->getRegions())
259  if (ancestorRegions.contains(&subRegion))
260  stack.push_back(&subRegion);
261  }
262  return result;
263  }
264 
265  /// Operations to sort.
266  const SetVector<Operation *> &toSort;
267  /// Set containing all the ancestor regions of the operations to sort.
268  DenseSet<Region *> ancestorRegions;
269  /// Set containing all the ancestor blocks of the operations to sort.
270  DenseSet<Block *> ancestorBlocks;
271 };
272 } // namespace
273 
276  return TopoSortHelper(toSort).sort();
277 }
static bool isOpReady(Operation *op, DenseSet< Operation * > &unscheduledOps, function_ref< bool(Value, Operation *)> isOperandReady)
Return true if the given operation is ready to be scheduled.
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block, excluding the terminator operation at the end.
Definition: Block.h:212
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp, which may be in the same or another block in the same function.
Definition: Operation.cpp:554
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Region * getParentRegion()
Return the region containing this region, or nullptr if the region is attached to a top-level operation.
Definition: Region.cpp:45
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
BlockListType & getBlocks()
Definition: Region.h:45
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: WalkResult.h:51
static WalkResult interrupt()
Definition: WalkResult.h:46
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
bool computeTopologicalSorting(MutableArrayRef< Operation * > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Compute a topological ordering of the given ops.
bool sortTopologically(Block *block, iterator_range< Block::iterator > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Given a block, sort a range of operations in said block in topological order.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.