MLIR 22.0.0git
TopologicalSortUtils.cpp
Go to the documentation of this file.
1//===- TopologicalSortUtils.cpp - Topological sort utilities --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/IR/Block.h"
13
14#include "llvm/ADT/PostOrderIterator.h"
15#include "llvm/ADT/SetVector.h"
16
17using namespace mlir;
18
19/// Return `true` if the given operation is ready to be scheduled.
20static bool isOpReady(Operation *op, DenseSet<Operation *> &unscheduledOps,
21 function_ref<bool(Value, Operation *)> isOperandReady) {
22 // An operation is ready to be scheduled if all its operands are ready. An
23 // operation is ready if:
24 const auto isReady = [&](Value value) {
25 // - the user-provided callback marks it as ready,
26 if (isOperandReady && isOperandReady(value, op))
27 return true;
28 Operation *parent = value.getDefiningOp();
29 // - it is a block argument,
30 if (!parent)
31 return true;
32 // - or it is not defined by an unscheduled op (and also not nested within
33 // an unscheduled op).
34 do {
35 // Stop traversal when op under examination is reached.
36 if (parent == op)
37 return true;
38 if (unscheduledOps.contains(parent))
39 return false;
40 } while ((parent = parent->getParentOp()));
41 // No unscheduled op found.
42 return true;
43 };
44
45 // An operation is recursively ready to be scheduled of it and its nested
46 // operations are ready.
47 WalkResult readyToSchedule = op->walk([&](Operation *nestedOp) {
48 return llvm::all_of(nestedOp->getOperands(),
49 [&](Value operand) { return isReady(operand); })
52 });
53 return !readyToSchedule.wasInterrupted();
54}
55
58 function_ref<bool(Value, Operation *)> isOperandReady) {
59 if (ops.empty())
60 return true;
61
62 // The set of operations that have not yet been scheduled.
63 DenseSet<Operation *> unscheduledOps;
64 // Mark all operations as unscheduled.
65 for (Operation &op : ops)
66 unscheduledOps.insert(&op);
67
68 Block::iterator nextScheduledOp = ops.begin();
69 Block::iterator end = ops.end();
70
71 bool allOpsScheduled = true;
72 while (!unscheduledOps.empty()) {
73 bool scheduledAtLeastOnce = false;
74
75 // Loop over the ops that are not sorted yet, try to find the ones "ready",
76 // i.e. the ones for which there aren't any operand produced by an op in the
77 // set, and "schedule" it (move it before the `nextScheduledOp`).
78 for (Operation &op :
79 llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
80 if (!isOpReady(&op, unscheduledOps, isOperandReady))
81 continue;
82
83 // Schedule the operation by moving it to the start.
84 unscheduledOps.erase(&op);
85 op.moveBefore(block, nextScheduledOp);
86 scheduledAtLeastOnce = true;
87 // Move the iterator forward if we schedule the operation at the front.
88 if (&op == &*nextScheduledOp)
89 ++nextScheduledOp;
90 }
91 // If no operations were scheduled, give up and advance the iterator.
92 if (!scheduledAtLeastOnce) {
93 allOpsScheduled = false;
94 unscheduledOps.erase(&*nextScheduledOp);
95 ++nextScheduledOp;
96 }
97 }
98
99 return allOpsScheduled;
100}
101
103 Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
104 return sortTopologically(block, block->without_terminator(), isOperandReady);
105}
106
109 function_ref<bool(Value, Operation *)> isOperandReady) {
110 if (ops.empty())
111 return true;
112
113 // The set of operations that have not yet been scheduled.
114 // Mark all operations as unscheduled.
115 DenseSet<Operation *> unscheduledOps(llvm::from_range, ops);
116
117 unsigned nextScheduledOp = 0;
118
119 bool allOpsScheduled = true;
120 while (!unscheduledOps.empty()) {
121 bool scheduledAtLeastOnce = false;
122
123 // Loop over the ops that are not sorted yet, try to find the ones "ready",
124 // i.e. the ones for which there aren't any operand produced by an op in the
125 // set, and "schedule" it (swap it with the op at `nextScheduledOp`).
126 for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {
127 if (!isOpReady(ops[i], unscheduledOps, isOperandReady))
128 continue;
129
130 // Schedule the operation by moving it to the start.
131 unscheduledOps.erase(ops[i]);
132 std::swap(ops[i], ops[nextScheduledOp]);
133 scheduledAtLeastOnce = true;
134 ++nextScheduledOp;
135 }
136
137 // If no operations were scheduled, just schedule the first op and continue.
138 if (!scheduledAtLeastOnce) {
139 allOpsScheduled = false;
140 unscheduledOps.erase(ops[nextScheduledOp++]);
141 }
142 }
143
144 return allOpsScheduled;
145}
146
148 // For each block that has not been visited yet (i.e. that has no
149 // predecessors), add it to the list as well as its successors.
150 SetVector<Block *> blocks;
151 for (Block &b : region) {
152 if (blocks.count(&b) == 0) {
153 llvm::ReversePostOrderTraversal<Block *> traversal(&b);
154 blocks.insert_range(traversal);
155 }
156 }
157 assert(blocks.size() == region.getBlocks().size() &&
158 "some blocks are not sorted");
159
160 return blocks;
161}
162
163namespace {
164class TopoSortHelper {
165public:
166 explicit TopoSortHelper(const SetVector<Operation *> &toSort)
167 : toSort(toSort) {}
168
169 /// Executes the topological sort of the operations this instance was
170 /// constructed with. This function will destroy the internal state of the
171 /// instance.
173 if (toSort.size() <= 1) {
174 // Note: Creates a copy on purpose.
175 return toSort;
176 }
177
178 // First, find the root region to start the traversal through the IR. This
179 // additionally enriches the internal caches with all relevant ancestor
180 // regions and blocks.
181 Region *rootRegion = findCommonAncestorRegion();
182 assert(rootRegion && "expected all ops to have a common ancestor");
183
184 // Sort all elements in `toSort` by traversing the IR in the appropriate
185 // order.
186 SetVector<Operation *> result = topoSortRegion(*rootRegion);
187 assert(result.size() == toSort.size() &&
188 "expected all operations to be present in the result");
189 return result;
190 }
191
192private:
193 /// Computes the closest common ancestor region of all operations in `toSort`.
194 Region *findCommonAncestorRegion() {
195 // Map to count the number of times a region was encountered.
196 DenseMap<Region *, size_t> regionCounts;
197 size_t expectedCount = toSort.size();
198
199 // Walk the region tree for each operation towards the root and add to the
200 // region count.
201 Region *res = nullptr;
202 for (Operation *op : toSort) {
203 Region *current = op->getParentRegion();
204 // Store the block as an ancestor block.
205 ancestorBlocks.insert(op->getBlock());
206 while (current) {
207 // Insert or update the count and compare it.
208 if (++regionCounts[current] == expectedCount) {
209 res = current;
210 break;
211 }
212 ancestorBlocks.insert(current->getParentOp()->getBlock());
213 current = current->getParentRegion();
214 }
215 }
216 auto firstRange = llvm::make_first_range(regionCounts);
217 ancestorRegions.insert_range(firstRange);
218 return res;
219 }
220
221 /// Performs the dominance respecting IR walk to collect the topological order
222 /// of the operation to sort.
223 SetVector<Operation *> topoSortRegion(Region &rootRegion) {
224 using StackT = PointerUnion<Region *, Block *, Operation *>;
225
227 // Stack that stores the different IR constructs to traverse.
228 SmallVector<StackT> stack;
229 stack.push_back(&rootRegion);
230
231 // Traverse the IR in a dominance respecting pre-order walk.
232 while (!stack.empty()) {
233 StackT current = stack.pop_back_val();
234 if (auto *region = dyn_cast<Region *>(current)) {
235 // A region's blocks need to be traversed in dominance order.
236 SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(*region);
237 for (Block *block : llvm::reverse(sortedBlocks)) {
238 // Only add blocks to the stack that are ancestors of the operations
239 // to sort.
240 if (ancestorBlocks.contains(block))
241 stack.push_back(block);
242 }
243 continue;
244 }
245
246 if (auto *block = dyn_cast<Block *>(current)) {
247 // Add all of the blocks operations to the stack.
248 for (Operation &op : llvm::reverse(*block))
249 stack.push_back(&op);
250 continue;
251 }
252
253 auto *op = cast<Operation *>(current);
254 if (toSort.contains(op))
255 result.insert(op);
256
257 // Add all the subregions that are ancestors of the operations to sort.
258 for (Region &subRegion : op->getRegions())
259 if (ancestorRegions.contains(&subRegion))
260 stack.push_back(&subRegion);
261 }
262 return result;
263 }
264
265 /// Operations to sort.
266 const SetVector<Operation *> &toSort;
267 /// Set containing all the ancestor regions of the operations to sort.
268 DenseSet<Region *> ancestorRegions;
269 /// Set containing all the ancestor blocks of the operations to sort.
270 DenseSet<Block *> ancestorBlocks;
271};
272} // namespace
273
276 return TopoSortHelper(toSort).sort();
277}
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static bool isOpReady(Operation *op, DenseSet< Operation * > &unscheduledOps, function_ref< bool(Value, Operation *)> isOperandReady)
Return true if the given operation is ready to be scheduled.
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:140
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:212
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition Operation.h:234
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operation.
Definition Region.cpp:45
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
bool computeTopologicalSorting(MutableArrayRef< Operation * > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Compute a topological ordering of the given ops.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
bool sortTopologically(Block *block, iterator_range< Block::iterator > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Given a block, sort a range operations in said block in topological order.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.