14 #include "llvm/ADT/PostOrderIterator.h"
15 #include "llvm/ADT/SetVector.h"
24 const auto isReady = [&](
Value value) {
26 if (isOperandReady && isOperandReady(value, op))
28 Operation *parent = value.getDefiningOp();
38 if (unscheduledOps.contains(parent))
49 [&](
Value operand) { return isReady(operand); })
66 unscheduledOps.insert(&op);
71 bool allOpsScheduled =
true;
72 while (!unscheduledOps.empty()) {
73 bool scheduledAtLeastOnce =
false;
79 llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
80 if (!
isOpReady(&op, unscheduledOps, isOperandReady))
84 unscheduledOps.erase(&op);
86 scheduledAtLeastOnce =
true;
88 if (&op == &*nextScheduledOp)
92 if (!scheduledAtLeastOnce) {
93 allOpsScheduled =
false;
94 unscheduledOps.erase(&*nextScheduledOp);
99 return allOpsScheduled;
123 unscheduledOps.insert(op);
125 unsigned nextScheduledOp = 0;
127 bool allOpsScheduled =
true;
128 while (!unscheduledOps.empty()) {
129 bool scheduledAtLeastOnce =
false;
134 for (
unsigned i = nextScheduledOp; i < ops.size(); ++i) {
135 if (!
isOpReady(ops[i], unscheduledOps, isOperandReady))
139 unscheduledOps.erase(ops[i]);
140 std::swap(ops[i], ops[nextScheduledOp]);
141 scheduledAtLeastOnce =
true;
146 if (!scheduledAtLeastOnce) {
147 allOpsScheduled =
false;
148 unscheduledOps.erase(ops[nextScheduledOp++]);
152 return allOpsScheduled;
159 for (
Block &b : region) {
160 if (blocks.count(&b) == 0) {
161 llvm::ReversePostOrderTraversal<Block *> traversal(&b);
162 blocks.insert(traversal.begin(), traversal.end());
165 assert(blocks.size() == region.
getBlocks().size() &&
166 "some blocks are not sorted");
172 class TopoSortHelper {
181 if (toSort.size() <= 1) {
189 Region *rootRegion = findCommonAncestorRegion();
190 assert(rootRegion &&
"expected all ops to have a common ancestor");
195 assert(result.size() == toSort.size() &&
196 "expected all operations to be present in the result");
202 Region *findCommonAncestorRegion() {
205 size_t expectedCount = toSort.size();
213 ancestorBlocks.insert(op->
getBlock());
216 if (++regionCounts[current] == expectedCount) {
224 auto firstRange = llvm::make_first_range(regionCounts);
225 ancestorRegions.insert(firstRange.begin(), firstRange.end());
237 stack.push_back(&rootRegion);
240 while (!stack.empty()) {
241 StackT current = stack.pop_back_val();
242 if (
auto *region = dyn_cast<Region *>(current)) {
245 for (
Block *block : llvm::reverse(sortedBlocks)) {
248 if (ancestorBlocks.contains(block))
249 stack.push_back(block);
254 if (
auto *block = dyn_cast<Block *>(current)) {
256 for (
Operation &op : llvm::reverse(*block))
257 stack.push_back(&op);
261 auto *op = cast<Operation *>(current);
262 if (toSort.contains(op))
267 if (ancestorRegions.contains(&subRegion))
268 stack.push_back(&subRegion);
284 return TopoSortHelper(toSort).sort();
static bool isOpReady(Operation *op, DenseSet< Operation * > &unscheduledOps, function_ref< bool(Value, Operation *)> isOperandReady)
Return true if the given operation is ready to be scheduled.
Block represents an ordered list of Operations.
OpListType::iterator iterator
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block, excluding the terminator operation at the end.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp, which may be in the same or another block in the same function.
Region * getParentRegion()
Returns the region to which the instruction belongs.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region, or nullptr if the region is attached to a top-level operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Include the generated interface declarations.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
bool computeTopologicalSorting(MutableArrayRef< Operation * > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Compute a topological ordering of the given ops.
bool sortTopologically(Block *block, iterator_range< Block::iterator > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Given a block, sort a range operations in said block in topological order.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.