MLIR 23.0.0git
RegionUtils.cpp
Go to the documentation of this file.
1//===- RegionUtils.cpp - Region-related transformation utilities ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
13#include "mlir/IR/Block.h"
14#include "mlir/IR/Dominance.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Operation.h"
18#include "mlir/IR/Value.h"
22#include "llvm/ADT/DenseSet.h"
23#include "llvm/ADT/DepthFirstIterator.h"
24#include "llvm/ADT/PostOrderIterator.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallVectorExtras.h"
27#include "llvm/Support/DebugLog.h"
28
29#include <deque>
30#include <iterator>
31
32using namespace mlir;
33
34#define DEBUG_TYPE "region-utils"
35
// NOTE(review): the opening of this signature (return type, function name,
// and the `orig`/`replacement` parameters) was lost in extraction — judging
// from the body this replaces all uses of a value that lie within `region`;
// confirm against the upstream RegionUtils.cpp.
37 Region &region) {
// Early-inc iteration keeps the use-list walk valid while individual uses
// are being rewritten below.
38 for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
// Only rewrite uses whose owner is nested somewhere inside `region`.
39 if (region.isAncestor(use.getOwner()->getParentRegion()))
40 use.set(replacement);
41 }
42}
43
// NOTE(review): the first line of this signature (return type and function
// name) was lost in extraction. The body invokes `callback` on every operand
// inside `region` whose value is defined in a proper ancestor of `limit`.
45 Region &region, Region &limit, function_ref<void(OpOperand *)> callback) {
46 assert(limit.isAncestor(&region) &&
47 "expected isolation limit to be an ancestor of the given region");
48
49 // Collect proper ancestors of `limit` upfront to avoid traversing the region
50 // tree for every value.
51 SmallPtrSet<Region *, 4> properAncestors;
52 for (auto *reg = limit.getParentRegion(); reg != nullptr;
53 reg = reg->getParentRegion()) {
54 properAncestors.insert(reg);
55 }
56
// Walk every operation nested in `region`; a set lookup per operand decides
// whether its value comes from above the isolation limit.
57 region.walk([callback, &properAncestors](Operation *op) {
58 for (OpOperand &operand : op->getOpOperands())
59 // Callback on values defined in a proper ancestor of region.
60 if (properAncestors.count(operand.get().getParentRegion()))
61 callback(&operand);
62 });
63}
64
// NOTE(review): the first line of this signature (return type and function
// name) was lost in extraction. Convenience overload: visits, for each
// region, the values used inside it but defined outside of it (each region
// acts as its own isolation limit).
66 MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) {
67 for (Region &region : regions)
68 visitUsedValuesDefinedAbove(region, region, callback);
69}
70
// NOTE(review): the opening of this signature (function name and the
// `region`/`limit` parameters) was lost in extraction. Collects into
// `values` every value used in `region` but defined above `limit`.
72 SetVector<Value> &values) {
73 visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
74 values.insert(operand->get());
75 });
76}
77
// NOTE(review): the opening of this signature (function name and the
// `regions` parameter) was lost in extraction. Convenience overload that
// accumulates, per region, the values defined above that region.
79 SetVector<Value> &values) {
80 for (Region &region : regions)
81 getUsedValuesDefinedAbove(region, region, values);
82}
83
84//===----------------------------------------------------------------------===//
85// Make block isolated from above.
86//===----------------------------------------------------------------------===//
87
// NOTE(review): the first line of this signature (return type and function
// name) was lost in extraction; per the trailing `return
// llvm::to_vector(finalCapturedValues)` it returns the list of values that
// remain captured — confirm against the upstream RegionUtils.cpp.
//
// Makes `region` isolated from above: values captured from enclosing scopes
// either have their defining ops cloned into the region (when
// `cloneOperationIntoRegion` approves) or are threaded in as new entry-block
// arguments.
89 RewriterBase &rewriter, Region &region,
90 llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion) {
91
92 // Get initial list of values used within region but defined above.
93 llvm::SetVector<Value> initialCapturedValues;
94 mlir::getUsedValuesDefinedAbove(region, initialCapturedValues);
95
// BFS worklist over captured values; cloning an op may expose its own
// operands as newly captured values.
96 std::deque<Value> worklist(initialCapturedValues.begin(),
97 initialCapturedValues.end());
// NOTE(review): the declarations of the `visited` (Value) and `visitedOps`
// (Operation*) sets — original lines 98-99 — were lost in extraction; both
// names are used below.
100
101 llvm::SetVector<Value> finalCapturedValues;
102 SmallVector<Operation *> clonedOperations;
103 while (!worklist.empty()) {
104 Value currValue = worklist.front();
105 worklist.pop_front();
106 if (visited.count(currValue))
107 continue;
108 visited.insert(currValue);
109
// Block arguments (no defining op) and already-visited ops stay captured.
110 Operation *definingOp = currValue.getDefiningOp();
111 if (!definingOp || visitedOps.count(definingOp)) {
112 finalCapturedValues.insert(currValue);
113 continue;
114 }
115 visitedOps.insert(definingOp);
116
117 if (!cloneOperationIntoRegion(definingOp)) {
118 // Defining operation isnt cloned, so add the current value to final
119 // captured values list.
120 finalCapturedValues.insert(currValue);
121 continue;
122 }
123
124 // Add all operands of the operation to the worklist and mark the op as to
125 // be cloned.
126 for (Value operand : definingOp->getOperands()) {
127 if (visited.count(operand))
128 continue;
129 worklist.push_back(operand);
130 }
131 clonedOperations.push_back(definingOp);
132 }
133
134 // The operations to be cloned need to be ordered in topological order
135 // so that they can be cloned into the region without violating use-def
136 // chains.
137 mlir::computeTopologicalSorting(clonedOperations);
138
139 OpBuilder::InsertionGuard g(rewriter);
140 // Collect types of existing block
141 Block *entryBlock = &region.front();
142 SmallVector<Type> newArgTypes =
143 llvm::to_vector(entryBlock->getArgumentTypes());
144 SmallVector<Location> newArgLocs = llvm::map_to_vector(
145 entryBlock->getArguments(), [](BlockArgument b) { return b.getLoc(); });
146
147 // Append the types of the captured values.
148 for (auto value : finalCapturedValues) {
149 newArgTypes.push_back(value.getType());
150 newArgLocs.push_back(value.getLoc());
151 }
152
153 // Create a new entry block.
154 Block *newEntryBlock =
155 rewriter.createBlock(&region, region.begin(), newArgTypes, newArgLocs);
156 auto newEntryBlockArgs = newEntryBlock->getArguments();
157
158 // Create a mapping between the captured values and the new arguments added.
159 IRMapping map;
// Only uses inside `region` are redirected; uses above the region keep
// referring to the original values.
160 auto replaceIfFn = [&](OpOperand &use) {
161 return region.isAncestor(use.getOwner()->getParentRegion());
162 };
163
164 for (auto [arg, capturedVal] :
165 llvm::zip(newEntryBlockArgs.take_back(finalCapturedValues.size()),
166 finalCapturedValues)) {
167 map.map(capturedVal, arg);
168 rewriter.replaceUsesWithIf(capturedVal, arg, replaceIfFn);
169 }
// Clone approved ops at the top of the new entry block, remapping their
// captured operands through `map`, then retarget in-region uses.
170 rewriter.setInsertionPointToStart(newEntryBlock);
171 for (auto *clonedOp : clonedOperations) {
172 Operation *newOp = rewriter.clone(*clonedOp, map);
173 rewriter.replaceOpUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn);
174 }
// Fold the old entry block into the new one; its old arguments map onto the
// leading arguments of the new block.
175 rewriter.mergeBlocks(
176 entryBlock, newEntryBlock,
177 newEntryBlock->getArguments().take_front(entryBlock->getNumArguments()));
178 return llvm::to_vector(finalCapturedValues);
179}
180
181//===----------------------------------------------------------------------===//
182// Unreachable Block Elimination
183//===----------------------------------------------------------------------===//
184
185/// Erase the unreachable blocks within the provided regions. Returns success
186/// if any blocks were erased, failure otherwise.
187// TODO: We could likely merge this with the DCE algorithm below.
// NOTE(review): the first line(s) of this signature were lost in extraction;
// the body uses `rewriter` and `regions` parameters in addition to `recurse`
// — confirm against the upstream declaration in RegionUtils.h.
190 bool recurse) {
191 LDBG() << "Starting eraseUnreachableBlocks with " << regions.size()
192 << " regions";
193
194 // Set of blocks found to be reachable within a given region.
195 llvm::df_iterator_default_set<Block *, 16> reachable;
196 // If any blocks were found to be dead.
197 int erasedDeadBlocks = 0;
198
// NOTE(review): the declaration of `worklist` (original line 199, presumably
// a SmallVector<Region *>) was lost in extraction.
200 worklist.reserve(regions.size());
201 for (Region &region : regions)
202 worklist.push_back(&region);
203
204 LDBG(2) << "Initial worklist size: " << worklist.size();
205
206 while (!worklist.empty()) {
207 Region *region = worklist.pop_back_val();
208 if (region->empty()) {
209 LDBG(2) << "Skipping empty region";
210 continue;
211 }
212
213 LDBG(2) << "Processing region with " << region->getBlocks().size()
214 << " blocks";
215 if (region->getParentOp())
216 LDBG(2) << " -> for operation: "
217 << OpWithFlags(region->getParentOp(),
218 OpPrintingFlags().skipRegions());
219
220 // If this is a single block region, just collect the nested regions.
221 if (region->hasOneBlock()) {
222 if (recurse)
223 for (Operation &op : region->front())
224 for (Region &region : op.getRegions())
225 worklist.push_back(&region);
226 continue;
227 }
228
229 // Mark all reachable blocks.
230 reachable.clear();
// depth_first_ext populates `reachable` as a side effect of the traversal;
// the loop body is intentionally empty.
231 for (Block *block : depth_first_ext(&region->front(), reachable))
232 (void)block /* Mark all reachable blocks */;
233
234 LDBG(2) << "Found " << reachable.size() << " reachable blocks out of "
235 << region->getBlocks().size() << " total blocks";
236
237 // Collect all of the dead blocks and push the live regions onto the
238 // worklist.
239 for (Block &block : llvm::make_early_inc_range(*region)) {
240 if (!reachable.count(&block)) {
241 LDBG() << "Erasing unreachable block: " << &block;
// Drop uses of the block's values first: other unreachable blocks may still
// reference them, and erase order among dead blocks is arbitrary.
242 block.dropAllDefinedValueUses();
243 rewriter.eraseBlock(&block);
244 ++erasedDeadBlocks;
245 continue;
246 }
247
248 // Walk any regions within this block.
249 if (recurse)
250 for (Operation &op : block)
251 for (Region &region : op.getRegions())
252 worklist.push_back(&region);
253 }
254 }
255
256 LDBG() << "Finished eraseUnreachableBlocks, erased " << erasedDeadBlocks
257 << " dead blocks";
258
// success() exactly when at least one block was erased.
259 return success(erasedDeadBlocks > 0);
260}
261
262//===----------------------------------------------------------------------===//
263// Dead Code Elimination
264//===----------------------------------------------------------------------===//
265
266namespace {
267/// Data structure used to track which values have already been proved live.
268///
269/// Because Operation's can have multiple results, this data structure tracks
270/// liveness for both Value's and Operation's to avoid having to look through
271/// all Operation results when analyzing a use.
272///
273/// This data structure essentially tracks the dataflow lattice.
274/// The set of values/ops proved live increases monotonically to a fixed-point.
275class LiveMap {
276public:
277 /// Value methods.
278 bool wasProvenLive(Value value) {
279 // TODO: For results that are removable, e.g. for region based control flow,
280 // we could allow for these values to be tracked independently.
281 if (OpResult result = dyn_cast<OpResult>(value))
282 return wasProvenLive(result.getOwner());
283 return wasProvenLive(cast<BlockArgument>(value));
284 }
285 bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
286 void setProvedLive(Value value) {
287 // TODO: For results that are removable, e.g. for region based control flow,
288 // we could allow for these values to be tracked independently.
289 if (OpResult result = dyn_cast<OpResult>(value))
290 return setProvedLive(result.getOwner());
291 setProvedLive(cast<BlockArgument>(value));
292 }
293 void setProvedLive(BlockArgument arg) {
294 changed |= liveValues.insert(arg).second;
295 }
296
297 /// Operation methods.
298 bool wasProvenLive(Operation *op) { return liveOps.count(op); }
299 void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }
300
301 /// Methods for tracking if we have reached a fixed-point.
302 void resetChanged() { changed = false; }
303 bool hasChanged() { return changed; }
304
305private:
306 bool changed = false;
307 DenseSet<Value> liveValues;
308 DenseSet<Operation *> liveOps;
309};
310} // namespace
311
312static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
313 Operation *owner = use.getOwner();
314 unsigned operandIndex = use.getOperandNumber();
315 // This pass generally treats all uses of an op as live if the op itself is
316 // considered live. However, for successor operands to terminators we need a
317 // finer-grained notion where we deduce liveness for operands individually.
318 // The reason for this is easiest to think about in terms of a classical phi
319 // node based SSA IR, where each successor operand is really an operand to a
320 // *separate* phi node, rather than all operands to the branch itself as with
321 // the block argument representation that MLIR uses.
322 //
323 // And similarly, because each successor operand is really an operand to a phi
324 // node, rather than to the terminator op itself, a terminator op can't e.g.
325 // "print" the value of a successor operand.
326 if (owner->hasTrait<OpTrait::IsTerminator>()) {
327 if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
328 if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
329 return !liveMap.wasProvenLive(*arg);
330 return false;
331 }
332 return false;
333}
334
335static void processValue(Value value, LiveMap &liveMap) {
336 bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) {
337 if (isUseSpeciallyKnownDead(use, liveMap))
338 return false;
339 return liveMap.wasProvenLive(use.getOwner());
340 });
341 if (provedLive)
342 liveMap.setProvedLive(value);
343}
344
// Forward declaration: the region-level propagation below is mutually
// recursive with the per-operation propagation.
345static void propagateLiveness(Region &region, LiveMap &liveMap);
346
347static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
348 // Terminators are always live.
349 liveMap.setProvedLive(op);
350
351 // Check to see if we can reason about the successor operands and mutate them.
352 BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
353 if (!branchInterface) {
354 for (Block *successor : op->getSuccessors())
355 for (BlockArgument arg : successor->getArguments())
356 liveMap.setProvedLive(arg);
357 return;
358 }
359
360 // If we can't reason about the operand to a successor, conservatively mark
361 // it as live.
362 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
363 SuccessorOperands successorOperands =
364 branchInterface.getSuccessorOperands(i);
365 for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount();
366 opI != opE; ++opI)
367 liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI));
368 }
369}
370
// Propagates liveness through a single operation: recurses into nested
// regions, dispatches terminators, and marks non-trivially-dead ops live.
371static void propagateLiveness(Operation *op, LiveMap &liveMap) {
372 // Recurse on any regions the op has.
373 for (Region &region : op->getRegions())
374 propagateLiveness(region, liveMap);
375
376 // Process terminator operations.
// NOTE(review): the guard condition (original line 377, presumably
// `if (op->hasTrait<OpTrait::IsTerminator>())`) was lost in extraction —
// as rendered the `return` below would be unconditional, contradicting the
// code that follows. Confirm against the upstream source.
378 return propagateTerminatorLiveness(op, liveMap);
379
380 // Don't reprocess live operations.
381 if (liveMap.wasProvenLive(op))
382 return;
383
384 // Process the op itself.
385 if (!wouldOpBeTriviallyDead(op))
386 return liveMap.setProvedLive(op);
387
388 // If the op isn't intrinsically alive, check it's results.
389 for (Value value : op->getResults())
390 processValue(value, liveMap);
391}
392
393static void propagateLiveness(Region &region, LiveMap &liveMap) {
394 if (region.empty())
395 return;
396
397 for (Block *block : llvm::post_order(&region.front())) {
398 // We process block arguments after the ops in the block, to promote
399 // faster convergence to a fixed point (we try to visit uses before defs).
400 for (Operation &op : llvm::reverse(block->getOperations()))
401 propagateLiveness(&op, liveMap);
402
403 // We currently do not remove entry block arguments, so there is no need to
404 // track their liveness.
405 // TODO: We could track these and enable removing dead operands/arguments
406 // from region control flow operations.
407 if (block->isEntryBlock())
408 continue;
409
410 for (Value value : block->getArguments()) {
411 if (!liveMap.wasProvenLive(value))
412 processValue(value, liveMap);
413 }
414 }
415}
416
// NOTE(review): the opening of this signature (function name and the
// `terminator` parameter) was lost in extraction — the body reads a
// `terminator` Operation*; confirm against the upstream source.
//
// Erases, for each successor of `terminator`, the forwarded operands whose
// corresponding block arguments were never proven live.
418 LiveMap &liveMap) {
419 BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
420 if (!branchOp)
421 return;
422
423 for (unsigned succI = 0, succE = terminator->getNumSuccessors();
424 succI < succE; succI++) {
425 // Iterating successors in reverse is not strictly needed, since we
426 // aren't erasing any successors. But it is slightly more efficient
427 // since it will promote later operands of the terminator being erased
428 // first, reducing the quadratic-ness.
429 unsigned succ = succE - succI - 1;
430 SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ);
431 Block *successor = terminator->getSuccessor(succ);
432
433 for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) {
434 // Iterating args in reverse is needed for correctness, to avoid
435 // shifting later args when earlier args are erased.
436 unsigned arg = argE - argI - 1;
437 if (!liveMap.wasProvenLive(successor->getArgument(arg)))
438 succOperands.erase(arg);
439 }
440 }
441}
442
// Erases everything `liveMap` did not prove live: dead ops, dead successor
// operands, and dead (non-entry) block arguments. Returns success if
// anything was erased.
443static LogicalResult deleteDeadness(RewriterBase &rewriter,
// NOTE(review): a parameter line (original line 444, presumably
// `MutableArrayRef<Region> regions,`) was lost in extraction — call sites
// pass (rewriter, regions, liveMap).
445 LiveMap &liveMap) {
446 bool erasedAnything = false;
447 for (Region &region : regions) {
448 if (region.empty())
449 continue;
450 bool hasSingleBlock = region.hasOneBlock();
451
452 // Delete every operation that is not live. Graph regions may have cycles
453 // in the use-def graph, so we must explicitly dropAllUses() from each
454 // operation as we erase it. Visiting the operations in post-order
455 // guarantees that in SSA CFG regions value uses are removed before defs,
456 // which makes dropAllUses() a no-op.
457 for (Block *block : llvm::post_order(&region.front())) {
// Single-block regions have no branch terminators feeding block arguments,
// so successor-operand pruning only applies to multi-block regions.
458 if (!hasSingleBlock)
459 eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
460 for (Operation &childOp :
461 llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
462 if (!liveMap.wasProvenLive(&childOp)) {
463 erasedAnything = true;
464 childOp.dropAllUses();
465 rewriter.eraseOp(&childOp);
466 } else {
// Live ops may still contain dead IR in their nested regions.
467 erasedAnything |= succeeded(
468 deleteDeadness(rewriter, childOp.getRegions(), liveMap));
469 }
470 }
471 }
472 // Delete block arguments.
473 // The entry block has an unknown contract with their enclosing block, so
474 // skip it.
475 for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
476 block.eraseArguments(
477 [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
478 }
479 }
480 return success(erasedAnything);
481}
482
483// This function performs a simple dead code elimination algorithm over the
484// given regions.
485//
486// The overall goal is to prove that Values are dead, which allows deleting ops
487// and block arguments.
488//
489// This uses an optimistic algorithm that assumes everything is dead until
490// proved otherwise, allowing it to delete recursively dead cycles.
491//
492// This is a simple fixed-point dataflow analysis algorithm on a lattice
493// {Dead,Alive}. Because liveness flows backward, we generally try to
494// iterate everything backward to speed up convergence to the fixed-point. This
495// allows for being able to delete recursively dead cycles of the use-def graph,
496// including block arguments.
497//
498// This function returns success if any operations or arguments were deleted,
499// failure otherwise.
500LogicalResult mlir::runRegionDCE(RewriterBase &rewriter,
501 MutableArrayRef<Region> regions) {
502 LiveMap liveMap;
503 do {
504 liveMap.resetChanged();
505
506 for (Region &region : regions)
507 propagateLiveness(region, liveMap);
508 } while (liveMap.hasChanged());
509
510 return deleteDeadness(rewriter, regions, liveMap);
511}
512
// NOTE(review): the opening of this signature was lost in extraction — per
// the body it takes a RewriterBase, a Region, and `includeNestedRegions`,
// and returns a bool (`return changed`); confirm against the upstream
// declaration.
514 bool includeNestedRegions) {
515 LDBG() << "Starting eliminateTriviallyDeadOps with "
516 << region.getBlocks().size()
517 << " blocks, includeNestedRegions=" << includeNestedRegions;
518 if (Operation *parentOp = region.getParentOp())
519 LDBG(2) << " -> parent operation: "
520 << OpWithFlags(parentOp, OpPrintingFlags().skipRegions());
521
522 bool changed = false;
523 unsigned erasedOps = 0;
524 unsigned seededOps = 0;
525 unsigned enqueuedDefs = 0;
526
527 // Step 1: walk each op in reverse program order. If the op is already
528 // trivially dead, erase it outright — there's no point recursing into
529 // regions that will be destroyed with it. Otherwise, if
530 // `includeNestedRegions` is set, recurse into its nested regions so values
531 // defined in `region` may lose their last user and show up as dead in
532 // step 2's seed. Reverse iteration lets dead chains propagate within this
533 // single pass.
534 for (Block &block : llvm::reverse(region)) {
535 LDBG(2) << "Scanning block " << &block << " with "
536 << block.getOperations().size() << " operations";
537 for (Operation &op :
538 llvm::make_early_inc_range(llvm::reverse(block.getOperations()))) {
539 LDBG(3) << "Visiting operation: "
540 << OpWithFlags(&op, OpPrintingFlags().skipRegions());
541 if (isOpTriviallyDead(&op)) {
542 LDBG() << "Erasing trivially dead operation: "
543 << OpWithFlags(&op, OpPrintingFlags().skipRegions());
544 rewriter.eraseOp(&op);
545 changed = true;
546 ++erasedOps;
547 continue;
548 }
549 if (includeNestedRegions) {
550 unsigned regionIdx = 0;
551 for (Region &nested : op.getRegions()) {
552 LDBG(2) << "Recursing into nested region #" << regionIdx
553 << " of operation " << op.getName();
554 bool nestedChanged =
555 eliminateTriviallyDeadOps(rewriter, nested, includeNestedRegions);
556 LDBG(2) << "Finished nested region #" << regionIdx << " of operation "
557 << op.getName() << ", changed=" << nestedChanged;
558 changed |= nestedChanged;
559 ++regionIdx;
560 }
561 }
562 }
563 }
564
565 // Step 2: worklist over ops in this region only.
566 //
567 // Worklist invariant: an op is pushed only once we have verified it is
568 // trivially dead. No speculative enqueues: every op on the worklist will
569 // be erased when popped. Two things enforce this:
570 // - the initial seed below calls isOpTriviallyDead before enqueueing,
571 // - the propagation inside the loop drops the erasing op's use of
572 // `defOp` *before* re-checking isOpTriviallyDead(defOp), so the check
573 // sees the post-erase use count and only enqueues when actually dead.
574 // Deadness is monotonic within this pass (we never add users, only remove
575 // them), so an op that was dead at enqueue time is still dead at pop time.
// NOTE(review): the declaration of `worklist` (original line 576, presumably
// a SmallVector<Operation *>) was lost in extraction.
577
578 LDBG(2) << "Stage 2: Seeding trivially dead operation worklist";
579 for (Operation &op : region.getOps()) {
580 if (isOpTriviallyDead(&op)) {
581 LDBG(2) << "Seeded worklist with operation: "
582 << OpWithFlags(&op, OpPrintingFlags().skipRegions());
583 worklist.push_back(&op);
584 changed = true;
585 ++seededOps;
586 }
587 }
588 LDBG(2) << "Initial worklist size: " << worklist.size();
589
590 while (!worklist.empty()) {
591 Operation *op = worklist.pop_back_val();
592 LDBG(2) << "Popped operation from worklist: "
593 << OpWithFlags(op, OpPrintingFlags().skipRegions());
594 /// Erase each operand to drop its use count before checking its defining
595 /// op: by the time we call isOpTriviallyDead on defOp, the
596 /// about-to-be-erased `op` is no longer counted as a user. Only
597 /// actually-dead ops enter the worklist.
598 ///
599 /// Walk nested operations as well because erasing `op` also implicitly
600 /// erases every operation nested under it and therefore drops their operand
601 /// uses.
602 op->walk([&](Operation *erasedOp) {
603 LDBG(3) << "Processing operands of operation erased: "
604 << OpWithFlags(erasedOp, OpPrintingFlags().skipRegions());
605 for (OpOperand &opOperand : erasedOp->getOpOperands()) {
606 Operation *defOp = opOperand.get().getDefiningOp();
607 if (!defOp) {
608 LDBG(4) << "Skipping operand #" << opOperand.getOperandNumber()
609 << ": value has no defining operation";
610 continue;
611 }
// Only ops defined directly in `region` participate in this worklist;
// defs in other regions are out of scope for this invocation.
612 if (defOp->getParentRegion() != &region) {
613 LDBG(4) << "Skipping operand #" << opOperand.getOperandNumber()
614 << ": defining operation is outside the current region";
615 continue;
616 }
617 LDBG(4) << "Dropping operand #" << opOperand.getOperandNumber()
618 << " from defining operation: "
619 << OpWithFlags(defOp, OpPrintingFlags().skipRegions());
620 opOperand.drop();
621 if (isOpTriviallyDead(defOp)) {
622 LDBG(2) << "Enqueued newly trivially dead defining operation: "
623 << OpWithFlags(defOp, OpPrintingFlags().skipRegions());
624 worklist.push_back(defOp);
625 ++enqueuedDefs;
626 } else {
627 LDBG(4) << "Defining operation is still not trivially dead: "
628 << OpWithFlags(defOp, OpPrintingFlags().skipRegions());
629 }
630 }
631 });
632 LDBG() << "Erasing trivially dead worklist operation: "
633 << OpWithFlags(op, OpPrintingFlags().skipRegions());
634 rewriter.eraseOp(op);
635 ++erasedOps;
636 }
637 LDBG() << "Finished eliminateTriviallyDeadOps, erased " << erasedOps
638 << " operations, seeded " << seededOps << " operations, enqueued "
639 << enqueuedDefs << " defining operations, changed=" << changed;
640 return changed;
641}
642
643//===----------------------------------------------------------------------===//
644// Block Merging
645//===----------------------------------------------------------------------===//
646
647//===----------------------------------------------------------------------===//
648// BlockEquivalenceData
649//===----------------------------------------------------------------------===//
650
651namespace {
652/// This class contains the information for comparing the equivalencies of two
653/// blocks. Blocks are considered equivalent if they contain the same operations
654/// in the same order. The only allowed divergence is for operands that come
655/// from sources outside of the parent block, i.e. the uses of values produced
656/// within the block must be equivalent.
657/// e.g.,
658/// Equivalent:
659/// ^bb1(%arg0: i32)
660/// return %arg0, %foo : i32, i32
661/// ^bb2(%arg1: i32)
662/// return %arg1, %bar : i32, i32
663/// Not Equivalent:
664/// ^bb1(%arg0: i32)
665/// return %foo, %arg0 : i32, i32
666/// ^bb2(%arg1: i32)
667/// return %arg1, %bar : i32, i32
668struct BlockEquivalenceData {
// Computes the hash and per-op order indices for `block` (see the
// out-of-line constructor below).
669 BlockEquivalenceData(Block *block);
670
671 /// Return the order index for the given value that is within the block of
672 /// this data.
673 unsigned getOrderOf(Value value) const;
674
675 /// The block this data refers to.
676 Block *block;
677 /// A hash value for this block.
678 llvm::hash_code hash;
679 /// A map of result producing operations to their relative orders within this
680 /// block. The order of an operation is the number of defined values that are
681 /// produced within the block before this operation.
// NOTE(review): the member declaration itself (original line 682, the
// `opOrderIndex` DenseMap used by the constructor and getOrderOf below) was
// lost in extraction.
683};
684} // namespace
685
// Builds the equivalence data: records, for each result-producing op, the
// number of values defined in the block before it (arguments count first),
// and folds a per-op hash into the block hash.
686BlockEquivalenceData::BlockEquivalenceData(Block *block)
687 : block(block), hash(0) {
688 unsigned orderIt = block->getNumArguments();
689 for (Operation &op : *block) {
690 if (unsigned numResults = op.getNumResults()) {
691 opOrderIndex.try_emplace(&op, orderIt);
692 orderIt += numResults;
693 }
// NOTE(review): the computation of `opHash` (original lines 694-697,
// presumably via OperationEquivalence::computeHash) was lost in extraction.
698 hash = llvm::hash_combine(hash, opHash);
699 }
700}
701
702unsigned BlockEquivalenceData::getOrderOf(Value value) const {
703 assert(value.getParentBlock() == block && "expected value of this block");
704
705 // Arguments use the argument number as the order index.
706 if (BlockArgument arg = dyn_cast<BlockArgument>(value))
707 return arg.getArgNumber();
708
709 // Otherwise, the result order is offset from the parent op's order.
710 OpResult result = cast<OpResult>(value);
711 auto opOrderIt = opOrderIndex.find(result.getDefiningOp());
712 assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
713 return opOrderIt->second + result.getResultNumber();
714}
715
716//===----------------------------------------------------------------------===//
717// BlockMergeCluster
718//===----------------------------------------------------------------------===//
719
720namespace {
721/// This class represents a cluster of blocks to be merged together.
722class BlockMergeCluster {
723public:
// Takes ownership of the leader's equivalence data; the leader block is the
// block every other member of the cluster will be merged into.
724 BlockMergeCluster(BlockEquivalenceData &&leaderData)
725 : leaderData(std::move(leaderData)) {}
726
727 /// Attempt to add the given block to this cluster. Returns success if the
728 /// block was merged, failure otherwise.
729 LogicalResult addToCluster(BlockEquivalenceData &blockData);
730
731 /// Try to merge all of the blocks within this cluster into the leader block.
732 LogicalResult merge(RewriterBase &rewriter);
733
734private:
735 /// The equivalence data for the leader of the cluster.
736 BlockEquivalenceData leaderData;
737
738 /// The set of blocks that can be merged into the leader.
739 llvm::SmallSetVector<Block *, 1> blocksToMerge;
740
741 /// A set of operand+index pairs that correspond to operands that need to be
742 /// replaced by arguments when the cluster gets merged.
// Each pair is (operation index within the block, operand index); std::set
// keeps them in deterministic sorted order.
743 std::set<std::pair<int, int>> operandsToMerge;
744};
745} // namespace
746
// Attempts to prove `blockData`'s block equivalent to the leader block;
// on success records the block and any operand positions that must become
// new block arguments when the cluster is merged.
747LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
// Cheap rejections first: hash mismatch, then argument-type mismatch.
748 if (leaderData.hash != blockData.hash)
749 return failure();
750 Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
751 if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes())
752 return failure();
753
754 // A set of operands that mismatch between the leader and the new block.
755 SmallVector<std::pair<int, int>, 8> mismatchedOperands;
// Walk both blocks' operations in lockstep.
756 auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
757 auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
758 for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
759 // Check that the operations are equivalent.
// NOTE(review): the opening of this equivalence check (original lines
// 760-761, presumably `if (!OperationEquivalence::isEquivalentTo(&*lhsIt,
// &*rhsIt, ...)`) was lost in extraction; only its trailing arguments
// remain below.
762 /*markEquivalent=*/nullptr,
763 OperationEquivalence::Flags::IgnoreLocations))
764 return failure();
765
766 // Compare the operands of the two operations. If the operand is within
767 // the block, it must refer to the same operation.
768 auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
769 for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) {
770 Value lhsOperand = lhsOperands[operand];
771 Value rhsOperand = rhsOperands[operand];
772 if (lhsOperand == rhsOperand)
773 continue;
774 // Check that the types of the operands match.
775 if (lhsOperand.getType() != rhsOperand.getType())
776 return failure();
777
778 // Check that these uses are both external, or both internal.
779 bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock;
780 bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock;
781 if (lhsIsInBlock != rhsIsInBlock)
782 return failure();
783 // Let the operands differ if they are defined in a different block. These
784 // will become new arguments if the blocks get merged.
785 if (!lhsIsInBlock) {
786
787 // Check whether the operands aren't the result of an immediate
788 // predecessors terminator. In that case we are not able to use it as a
789 // successor operand when branching to the merged block as it does not
790 // dominate its producing operation.
791 auto isValidSuccessorArg = [](Block *block, Value operand) {
792 if (operand.getDefiningOp() !=
793 operand.getParentBlock()->getTerminator())
794 return true;
795 return !llvm::is_contained(block->getPredecessors(),
796 operand.getParentBlock());
797 };
798
799 if (!isValidSuccessorArg(leaderBlock, lhsOperand) ||
800 !isValidSuccessorArg(mergeBlock, rhsOperand))
801 return failure();
802
803 mismatchedOperands.emplace_back(opI, operand);
804 continue;
805 }
806
807 // Otherwise, these operands must have the same logical order within the
808 // parent block.
809 if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand))
810 return failure();
811 }
812
813 // If the lhs or rhs has external uses, the blocks cannot be merged as the
814 // merged version of this operation will not be either the lhs or rhs
815 // alone (thus semantically incorrect), but some mix dependending on which
816 // block preceeded this.
817 // TODO allow merging of operations when one block does not dominate the
818 // other
819 if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
820 lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
821 return failure();
822 }
823 }
824 // Make sure that the block sizes are equivalent.
825 if (lhsIt != lhsE || rhsIt != rhsE)
826 return failure();
827
828 // If we get here, the blocks are equivalent and can be merged.
829 operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
830 blocksToMerge.insert(blockData.block);
831 return success();
832}
833
834/// Returns true if the predecessor terminators of the given block can not have
835/// their operands updated.
836static bool ableToUpdatePredOperands(Block *block) {
837 for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
838 if (!isa<BranchOpInterface>((*it)->getTerminator()))
839 return false;
840 }
841 return true;
842}
843
/// Prunes the redundant list of new arguments. E.g., if we are passing an
/// argument list like [x, y, z, x] this would return [x, y, z] and it would
/// update the `block` (to which the arguments are passed) accordingly. The new
/// arguments are appended at the back of the block, hence we need to know how
/// many `numOldArguments` were there before, in order to correctly replace the
/// new arguments in the block.
static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
    const SmallVector<SmallVector<Value, 8>, 2> &newArguments,
    RewriterBase &rewriter, unsigned numOldArguments, Block *block) {

  SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
      newArguments.size(), SmallVector<Value, 8>());

  if (newArguments.empty())
    return newArguments;

  // `newArguments` is a 2D array of size `numLists` x `numArgs`. List 0 holds
  // the values forwarded by the leader block's predecessors; every list has
  // the same length.
  unsigned numLists = newArguments.size();
  unsigned numArgs = newArguments[0].size();

  // Map that for each arg index contains the index that we can use in place of
  // the original index. E.g., if we have newArgs = [x, y, z, x], we will have
  // idxToReplacement[3] = 0
  llvm::DenseMap<unsigned, unsigned> idxToReplacement;

  // This is a useful data structure to track the first appearance of a Value
  // on a given list of arguments.
  DenseMap<Value, unsigned> firstValueToIdx;
  for (unsigned j = 0; j < numArgs; ++j) {
    Value newArg = newArguments[0][j];
    // try_emplace keeps the FIRST index a value appears at.
    firstValueToIdx.try_emplace(newArg, j);
  }

  // Go through the first list of arguments (list 0).
  for (unsigned j = 0; j < numArgs; ++j) {
    // Look back to see if there are possible redundancies in list 0. Please
    // note that we are using a map to annotate when an argument was seen first
    // to avoid a O(N^2) algorithm. This has the drawback that if we have two
    // lists like:
    // list0: [%a, %a, %a]
    // list1: [%c, %b, %b]
    // We cannot simplify it, because firstValueToIdx[%a] = 0, but we cannot
    // point list1[1](==%b) or list1[2](==%b) to list1[0](==%c). However, since
    // the number of arguments can be potentially unbounded we cannot afford a
    // O(N^2) algorithm (to search to all the possible pairs) and we need to
    // accept the trade-off.
    unsigned k = firstValueToIdx[newArguments[0][j]];
    // k == j means this is the first occurrence; nothing to fold.
    if (k == j)
      continue;

    bool shouldReplaceJ = true;
    unsigned replacement = k;
    // If a possible redundancy is found, then scan the other lists: we
    // can prune the arguments if and only if they are redundant in every
    // list.
    for (unsigned i = 1; i < numLists; ++i)
      shouldReplaceJ =
          shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
    // Save the replacement.
    if (shouldReplaceJ)
      idxToReplacement[j] = replacement;
  }

  // Populate the pruned argument list, dropping every redundant column.
  for (unsigned i = 0; i < numLists; ++i)
    for (unsigned j = 0; j < numArgs; ++j)
      if (!idxToReplacement.contains(j))
        newArgumentsPruned[i].push_back(newArguments[i][j]);

  // Replace the block's redundant arguments. The new arguments start at
  // offset `numOldArguments`, hence the index shift below.
  SmallVector<unsigned> toErase;
  for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
    if (idxToReplacement.contains(idx)) {
      Value oldArg = block->getArgument(numOldArguments + idx);
      Value newArg =
          block->getArgument(numOldArguments + idxToReplacement[idx]);
      rewriter.replaceAllUsesWith(oldArg, newArg);
      toErase.push_back(numOldArguments + idx);
    }
  }

  // Erase the block's redundant arguments, back to front so earlier indices
  // stay valid while erasing.
  for (unsigned idxToErase : llvm::reverse(toErase))
    block->eraseArgument(idxToErase);
  return newArgumentsPruned;
}
930
/// Merge all of the blocks in this cluster into the leader block, rewriting
/// mismatched operands into new block arguments forwarded by the predecessors.
LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
  // Don't consider clusters that don't have blocks to merge.
  if (blocksToMerge.empty())
    return failure();

  Block *leaderBlock = leaderData.block;
  if (!operandsToMerge.empty()) {
    // If the cluster has operands to merge, verify that the predecessor
    // terminators of each of the blocks can have their successor operands
    // updated.
    // TODO: We could try and sub-partition this cluster if only some blocks
    // cause the mismatch.
    if (!ableToUpdatePredOperands(leaderBlock) ||
        !llvm::all_of(blocksToMerge, ableToUpdatePredOperands))
      return failure();

    // Collect the iterators for each of the blocks to merge. We will walk all
    // of the iterators at once to avoid operand index invalidation.
    SmallVector<Block::iterator, 2> blockIterators;
    blockIterators.reserve(blocksToMerge.size() + 1);
    blockIterators.push_back(leaderBlock->begin());
    for (Block *mergeBlock : blocksToMerge)
      blockIterators.push_back(mergeBlock->begin());

    // Update each of the predecessor terminators with the new arguments.
    // `newArguments` has one row per block (leader first) and one column per
    // mismatched operand.
    SmallVector<SmallVector<Value, 8>, 2> newArguments(
        1 + blocksToMerge.size(),
        SmallVector<Value, 8>(operandsToMerge.size()));
    unsigned curOpIndex = 0;
    unsigned numOldArguments = leaderBlock->getNumArguments();
    // `operandsToMerge` entries are (operation index, operand index) pairs;
    // advance all iterators in lock-step by the delta between op indices.
    for (const auto &it : llvm::enumerate(operandsToMerge)) {
      unsigned nextOpOffset = it.value().first - curOpIndex;
      curOpIndex = it.value().first;

      // Process the operand for each of the block iterators.
      for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
        Block::iterator &blockIter = blockIterators[i];
        std::advance(blockIter, nextOpOffset);
        auto &operand = blockIter->getOpOperand(it.value().second);
        newArguments[i][it.index()] = operand.get();

        // Update the operand and insert an argument if this is the leader.
        if (i == 0) {
          Value operandVal = operand.get();
          operand.set(leaderBlock->addArgument(operandVal.getType(),
                                               operandVal.getLoc()));
        }
      }
    }

    // Prune redundant arguments and update the leader block argument list.
    newArguments = pruneRedundantArguments(newArguments, rewriter,
                                           numOldArguments, leaderBlock);

    // Update the predecessors for each of the blocks: every predecessor now
    // forwards the values this block used to compute via its own operands.
    auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
      for (auto predIt = block->pred_begin(), predE = block->pred_end();
           predIt != predE; ++predIt) {
        auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
        unsigned succIndex = predIt.getSuccessorIndex();
        branch.getSuccessorOperands(succIndex).append(
            newArguments[clusterIndex]);
      }
    };
    updatePredecessors(leaderBlock, /*clusterIndex=*/0);
    for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
      updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1);
  }

  // Replace all uses of the merged blocks with the leader and erase them.
  for (Block *block : blocksToMerge) {
    block->replaceAllUsesWith(leaderBlock);
    rewriter.eraseBlock(block);
  }
  return success();
}
1007
/// Identify identical blocks within the given region and merge them, inserting
/// new block arguments as necessary. Returns success if any blocks were merged,
/// failure otherwise.
static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
                                          Region &region) {
  // Nothing to merge with zero or one block.
  if (region.empty() || region.hasOneBlock())
    return failure();

  // Identify sets of blocks, other than the entry block, that branch to the
  // same successors. We will use these groups to create clusters of equivalent
  // blocks.
  llvm::SmallDenseMap<SuccessorRange, SmallVector<Block *>, 1>
      matchingSuccessors;
  for (Block &block : llvm::drop_begin(region, 1))
    matchingSuccessors[block.getSuccessors()].push_back(&block);

  bool mergedAnyBlocks = false;
  for (ArrayRef<Block *> blocks : llvm::make_second_range(matchingSuccessors)) {
    if (blocks.size() == 1)
      continue;

    SmallVector<BlockMergeCluster, 1> clusters;
    for (Block *block : blocks) {
      BlockEquivalenceData data(block);

      // Don't allow merging if this block has any regions.
      // TODO: Add support for regions if necessary.
      bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) {
        return llvm::any_of(op.getRegions(),
                            [](Region &region) { return !region.empty(); });
      });
      if (hasNonEmptyRegion)
        continue;

      // Don't allow merging if this block's arguments are used outside of the
      // original block.
      bool argHasExternalUsers = llvm::any_of(
          block->getArguments(), [block](mlir::BlockArgument &arg) {
            return arg.isUsedOutsideOfBlock(block);
          });
      if (argHasExternalUsers)
        continue;

      // Try to add this block to an existing cluster; if no cluster accepts
      // it, the block seeds a new cluster.
      bool addedToCluster = false;
      for (auto &cluster : clusters)
        if ((addedToCluster = succeeded(cluster.addToCluster(data))))
          break;
      if (!addedToCluster)
        clusters.emplace_back(std::move(data));
    }
    for (auto &cluster : clusters)
      mergedAnyBlocks |= succeeded(cluster.merge(rewriter));
  }

  return success(mergedAnyBlocks);
}
1064
1065/// Identify identical blocks within the given regions and merge them, inserting
1066/// new block arguments as necessary.
1067static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
1068 MutableArrayRef<Region> regions) {
1069 llvm::SmallSetVector<Region *, 1> worklist;
1070 for (auto &region : regions)
1071 worklist.insert(&region);
1072 bool anyChanged = false;
1073 while (!worklist.empty()) {
1074 Region *region = worklist.pop_back_val();
1075 if (succeeded(mergeIdenticalBlocks(rewriter, *region))) {
1076 worklist.insert(region);
1077 anyChanged = true;
1078 }
1079
1080 // Add any nested regions to the worklist.
1081 for (Block &block : *region)
1082 for (auto &op : block)
1083 for (auto &nestedRegion : op.getRegions())
1084 worklist.insert(&nestedRegion);
1085 }
1086
1087 return success(anyChanged);
1088}
1089
/// If a block's argument is always the same across different invocations, then
/// drop the argument and use the value directly inside the block.
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
                                            Block &block) {
  // Indices of arguments proven redundant; erased in a second pass so that
  // indices stay stable while scanning.
  SmallVector<size_t> argsToErase;

  // Go through the arguments of the block.
  for (auto [argIdx, blockOperand] : llvm::enumerate(block.getArguments())) {
    bool sameArg = true;
    Value commonValue;

    // Go through the block predecessor and flag if they pass to the block
    // different values for the same argument.
    for (Block::pred_iterator predIt = block.pred_begin(),
                              predE = block.pred_end();
         predIt != predE; ++predIt) {
      // Only predecessors implementing BranchOpInterface expose forwarded
      // operands; bail out for any other terminator.
      auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
      if (!branch) {
        sameArg = false;
        break;
      }
      unsigned succIndex = predIt.getSuccessorIndex();
      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);

      // Produced operands are generated by the terminator operation itself
      // (e.g., results of an async call) and cannot be forwarded or dropped.
      if (succOperands.isOperandProduced(argIdx)) {
        sameArg = false;
        break;
      }

      // Get the forwarded operand value using operator[] which correctly
      // adjusts for the produced operand offset.
      Value operandValue = succOperands[argIdx];
      if (!commonValue) {
        commonValue = operandValue;
        continue;
      }
      if (operandValue != commonValue) {
        sameArg = false;
        break;
      }
    }

    // If they are passing the same value, drop the argument.
    // Note: commonValue is null when the block has no (branch) predecessors,
    // in which case nothing is dropped.
    if (commonValue && sameArg) {
      argsToErase.push_back(argIdx);

      // Remove the argument from the block.
      rewriter.replaceAllUsesWith(blockOperand, commonValue);
    }
  }

  // Remove the arguments, back to front so lower indices remain valid.
  for (size_t argIdx : llvm::reverse(argsToErase)) {
    block.eraseArgument(argIdx);

    // Remove the corresponding forwarded operand from each predecessor branch.
    for (auto predIt = block.pred_begin(), predE = block.pred_end();
         predIt != predE; ++predIt) {
      auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
      unsigned succIndex = predIt.getSuccessorIndex();
      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
      succOperands.erase(argIdx);
    }
  }
  return success(!argsToErase.empty());
}
1158
1159/// This optimization drops redundant argument to blocks. I.e., if a given
1160/// argument to a block receives the same value from each of the block
1161/// predecessors, we can remove the argument from the block and use directly the
1162/// original value. This is a simple example:
1163///
1164/// %cond = llvm.call @rand() : () -> i1
1165/// %val0 = llvm.mlir.constant(1 : i64) : i64
1166/// %val1 = llvm.mlir.constant(2 : i64) : i64
1167/// %val2 = llvm.mlir.constant(3 : i64) : i64
1168/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
1169/// : i64)
1170///
1171/// ^bb1(%arg0 : i64, %arg1 : i64):
1172/// llvm.call @foo(%arg0, %arg1)
1173///
1174/// The previous IR can be rewritten as:
1175/// %cond = llvm.call @rand() : () -> i1
1176/// %val0 = llvm.mlir.constant(1 : i64) : i64
1177/// %val1 = llvm.mlir.constant(2 : i64) : i64
1178/// %val2 = llvm.mlir.constant(3 : i64) : i64
1179/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
1180///
1181/// ^bb1(%arg0 : i64):
1182/// llvm.call @foo(%val0, %arg0)
1183///
1184static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
1185 MutableArrayRef<Region> regions) {
1186 llvm::SmallSetVector<Region *, 1> worklist;
1187 for (Region &region : regions)
1188 worklist.insert(&region);
1189 bool anyChanged = false;
1190 while (!worklist.empty()) {
1191 Region *region = worklist.pop_back_val();
1192
1193 // Add any nested regions to the worklist.
1194 for (Block &block : *region) {
1195 anyChanged =
1196 succeeded(dropRedundantArguments(rewriter, block)) || anyChanged;
1197
1198 for (Operation &op : block)
1199 for (Region &nestedRegion : op.getRegions())
1200 worklist.insert(&nestedRegion);
1201 }
1202 }
1203 return success(anyChanged);
1204}
1205
1206//===----------------------------------------------------------------------===//
1207// Region Simplification
1208//===----------------------------------------------------------------------===//
1209
/// Run a set of structural simplifications over the given regions. This
/// includes transformations like unreachable block elimination, dead argument
/// elimination, as well as some other DCE. This function returns success if any
/// of the regions were simplified, failure otherwise.
LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
                                    MutableArrayRef<Region> regions,
                                    bool mergeBlocks) {
  // Remove blocks with no predecessors first so later passes see less IR.
  bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
  // Delete dead operations and unused block arguments.
  bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
  bool mergedIdenticalBlocks = false;
  bool droppedRedundantArguments = false;
  // Block merging is optional since it can affect debuggability / CFG shape.
  if (mergeBlocks) {
    mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
    droppedRedundantArguments =
        succeeded(dropRedundantArguments(rewriter, regions));
  }
  return success(eliminatedBlocks || eliminatedOpsOrArgs ||
                 mergedIdenticalBlocks || droppedRedundantArguments);
}
1229
1230//===---------------------------------------------------------------------===//
1231// Move operation dependencies
1232//===---------------------------------------------------------------------===//
1233
/// Check if moving operations in the slice before `insertionPoint` would break
/// dominance due to block argument operands. Returns true if all block args
/// dominate the insertion point (no issue), false otherwise. If `failingOp` is
/// provided, it will be set to the first problematic op.
///
/// For operands defined by ops: either the defining op is in the slice (so
/// dominance preserved), or it already dominates insertionPoint (otherwise it
/// would be in the slice). So we only need to check block argument operands,
/// both as direct operands and as values captured inside regions.
static bool blockArgsDominateInsertionPoint(
    const llvm::SetVector<Operation *> &slice, Operation *insertionPoint,
    DominanceInfo &dominance, Operation **failingOp = nullptr) {
  Block *insertionBlock = insertionPoint->getBlock();

  // Returns true if the block arg dominates, false otherwise. Sets failingOp
  // on failure.
  auto argDominates = [&](BlockArgument arg, Operation *op) {
    Block *argBlock = arg.getOwner();
    // An argument of the insertion block itself trivially dominates any point
    // within that block.
    bool dominates = argBlock == insertionBlock ||
                     dominance.dominates(argBlock, insertionBlock);
    if (!dominates && failingOp)
      *failingOp = op;
    return dominates;
  };

  for (Operation *op : slice) {
    // Check direct operands.
    for (Value operand : op->getOperands()) {
      auto arg = dyn_cast<BlockArgument>(operand);
      if (!arg)
        continue;
      if (!argDominates(arg, op))
        return false;
    }

    // Check block arguments captured inside regions. Process one region at a
    // time to enable early exit without collecting values from all regions.
    for (Region &region : op->getRegions()) {
      SetVector<Value> capturedValues;
      getUsedValuesDefinedAbove(region, region, capturedValues);
      for (Value val : capturedValues) {
        auto arg = dyn_cast<BlockArgument>(val);
        if (!arg)
          continue;
        if (!argDominates(arg, op))
          return false;
      }
    }
  }
  return true;
}
1285
1286/// Check if any region between an operation and an ancestor block is
1287/// isolated from above. If so, moving the operation out would break
1288/// the isolation semantics.
1289static bool hasIsolatedRegionBetween(Operation *op, Block *ancestorBlock) {
1290 Region *ancestorRegion = ancestorBlock->getParent();
1291
1292 // Walk up from the op's region to find if there's an isolated region
1293 // between the op and the ancestor.
1294 Region *region = op->getParentRegion();
1295 while (region && region != ancestorRegion) {
1296 Operation *parentOp = region->getParentOp();
1297 if (!parentOp)
1298 break;
1299
1300 if (parentOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
1301 return true;
1302
1303 region = parentOp->getParentRegion();
1304 }
1305 return false;
1306}
1307
/// Move all operations that `op` transitively depends on (its backward slice)
/// to just before `insertionPoint`, so that `op` itself could subsequently be
/// moved there without breaking dominance. Fails (via match-failure) when the
/// move would be illegal or would require moving side-effecting ops.
LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
                                              Operation *op,
                                              Operation *insertionPoint,
                                              DominanceInfo &dominance) {
  Block *insertionBlock = insertionPoint->getBlock();

  // If `insertionPoint` does not dominate `op`, do nothing.
  if (!dominance.properlyDominates(insertionPoint, op)) {
    return rewriter.notifyMatchFailure(op,
                                       "insertion point does not dominate op");
  }

  // Verify we're not crossing an isolated region.
  if (hasIsolatedRegionBetween(op, insertionBlock)) {
    return rewriter.notifyMatchFailure(
        op, "cannot move operation across isolated-from-above region");
  }

  // Find the backward slice of operation for each `Value` the operation
  // depends on. Prune the slice to only include operations not already
  // dominated by the `insertionPoint`.
  BackwardSliceOptions options;
  options.inclusive = false;
  options.omitUsesFromAbove = false;
  // Block arguments cannot be moved; dominance check handles this case.
  options.omitBlockArguments = true;
  bool dependsOnSideEffectingOp = false;
  options.filter = [&](Operation *sliceBoundaryOp) {
    // Skip the root op - we're moving its dependencies, not the op itself.
    // The root op is filtered out by options.inclusive = false anyway.
    if (sliceBoundaryOp == op)
      return true;
    bool dominated =
        dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
    // Op is already before insertion point, no need to include in slice.
    if (dominated)
      return false;
    // Op needs to move but is side-effecting - stop traversal early.
    if (!isPure(sliceBoundaryOp)) {
      dependsOnSideEffectingOp = true;
      return false;
    }
    return true;
  };
  llvm::SetVector<Operation *> slice;
  LogicalResult result = getBackwardSlice(op, &slice, options);
  assert(result.succeeded() && "expected a backward slice");
  (void)result;

  // Check if any operation in the slice is side-effecting.
  if (dependsOnSideEffectingOp) {
    return rewriter.notifyMatchFailure(
        op, "cannot move operation with side-effecting dependencies");
  }

  // If the slice contains `insertionPoint` cannot move the dependencies.
  if (slice.contains(insertionPoint)) {
    return rewriter.notifyMatchFailure(
        op,
        "cannot move dependencies before operation in backward slice of op");
  }

  // Verify no operation in the slice uses a block argument that wouldn't
  // dominate at the new location.
  Operation *badOp = nullptr;
  if (!blockArgsDominateInsertionPoint(slice, insertionPoint, dominance,
                                       &badOp)) {
    return rewriter.notifyMatchFailure(
        badOp, "moving op would break dominance for block argument operand");
  }

  // We should move the slice in topological order, but `getBackwardSlice`
  // already does that. So no need to sort again.
  for (Operation *op : slice) {
    rewriter.moveOpBefore(op, insertionPoint);
  }
  return success();
}
1386
/// Overload that computes fresh dominance information rooted at `op`.
LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
                                              Operation *op,
                                              Operation *insertionPoint) {
  DominanceInfo dominance(op);
  return moveOperationDependencies(rewriter, op, insertionPoint, dominance);
}
1393
/// Move the definitions of the given `values` (and all the ops in their
/// backward slices) to just before `insertionPoint`, so that every value
/// dominates the insertion point afterwards. Fails (via match-failure) when
/// the move would be illegal or would require moving side-effecting ops.
LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
                                         ValueRange values,
                                         Operation *insertionPoint,
                                         DominanceInfo &dominance) {
  // Remove the values that already dominate the insertion point.
  SmallVector<Value> prunedValues;
  for (auto value : values) {
    if (dominance.properlyDominates(value, insertionPoint))
      continue;
    // Block arguments are not supported.
    if (isa<BlockArgument>(value)) {
      return rewriter.notifyMatchFailure(
          insertionPoint,
          "unsupported case of moving block argument before insertion point");
    }

    Block *insertionBlock = insertionPoint->getBlock();
    Operation *definingOp = value.getDefiningOp();
    Block *definingBlock = definingOp->getBlock();

    // Verify we're not crossing an isolated region.
    if (hasIsolatedRegionBetween(definingOp, insertionBlock)) {
      return rewriter.notifyMatchFailure(
          insertionPoint,
          "cannot move value definition across isolated-from-above region");
    }

    // Verify the insertion point's block dominates the defining block,
    // otherwise we're trying to move "backwards" in the CFG which doesn't
    // make sense.
    if (!dominance.dominates(insertionBlock, definingBlock)) {
      return rewriter.notifyMatchFailure(
          insertionPoint,
          "insertion point block does not dominate the value's defining "
          "block");
    }
    prunedValues.push_back(value);
  }

  // Find the backward slice of operation for each `Value` the operation
  // depends on. Prune the slice to only include operations not already
  // dominated by the `insertionPoint`.
  BackwardSliceOptions options;
  options.inclusive = true;
  options.omitUsesFromAbove = false;
  // Block arguments cannot be moved, so we stop the slice computation there.
  // If an op uses a block argument that wouldn't dominate at the new location,
  // the dominance check will catch it.
  options.omitBlockArguments = true;
  bool dependsOnSideEffectingOp = false;
  options.filter = [&](Operation *sliceBoundaryOp) {
    bool dominated =
        dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
    // Op is already before insertion point, no need to include in slice.
    if (dominated)
      return false;
    // Op needs to move but is side-effecting - stop traversal early.
    if (!isPure(sliceBoundaryOp)) {
      dependsOnSideEffectingOp = true;
      return false;
    }
    return true;
  };
  llvm::SetVector<Operation *> slice;
  for (auto value : prunedValues) {
    LogicalResult result = getBackwardSlice(value, &slice, options);
    assert(result.succeeded() && "expected a backward slice");
    (void)result;
  }

  // Check if any operation in the slice is side-effecting.
  if (dependsOnSideEffectingOp) {
    return rewriter.notifyMatchFailure(
        insertionPoint, "cannot move value definitions with side-effecting "
                        "operations in the slice");
  }

  // If the slice contains `insertionPoint` cannot move the dependencies.
  if (slice.contains(insertionPoint)) {
    return rewriter.notifyMatchFailure(
        insertionPoint,
        "cannot move dependencies before operation in backward slice of op");
  }

  // Sort operations topologically. This is needed because we call
  // getBackwardSlice multiple times (once per value), and the combined slice
  // may not be in topological order when independent subgraphs interleave.
  mlir::topologicalSort(slice);

  // Verify no operation in the slice uses a block argument that wouldn't
  // dominate at the new location.
  Operation *badOp = nullptr;
  if (!blockArgsDominateInsertionPoint(slice, insertionPoint, dominance,
                                       &badOp)) {
    return rewriter.notifyMatchFailure(
        badOp, "moving op would break dominance for block argument operand");
  }

  for (Operation *op : slice)
    rewriter.moveOpBefore(op, insertionPoint);
  return success();
}
1496
/// Overload that computes fresh dominance information rooted at
/// `insertionPoint`.
LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
                                         ValueRange values,
                                         Operation *insertionPoint) {
  DominanceInfo dominance(insertionPoint);
  return moveValueDefinitions(rewriter, values, insertionPoint, dominance);
}
return success()
static size_t hash(const T &value)
Local helper to compute std::hash for a value.
Definition IRCore.cpp:56
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static llvm::ManagedStatic< PassManagerOptions > options
static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, Region &region)
Identify identical blocks within the given region and merge them, inserting new block arguments as ne...
static void propagateLiveness(Region &region, LiveMap &liveMap)
static SmallVector< SmallVector< Value, 8 >, 2 > pruneRedundantArguments(const SmallVector< SmallVector< Value, 8 >, 2 > &newArguments, RewriterBase &rewriter, unsigned numOldArguments, Block *block)
Prunes the redundant list of new arguments.
static void processValue(Value value, LiveMap &liveMap)
static bool ableToUpdatePredOperands(Block *block)
Returns true if the predecessor terminators of the given block can not have their operands updated.
static bool blockArgsDominateInsertionPoint(const llvm::SetVector< Operation * > &slice, Operation *insertionPoint, DominanceInfo &dominance, Operation **failingOp=nullptr)
Check if moving operations in the slice before insertionPoint would break dominance due to block argu...
static void eraseTerminatorSuccessorOperands(Operation *terminator, LiveMap &liveMap)
static LogicalResult dropRedundantArguments(RewriterBase &rewriter, Block &block)
If a block's argument is always the same across different invocations, then drop the argument and use...
static bool hasIsolatedRegionBetween(Operation *op, Block *ancestorBlock)
Check if any region between an operation and an ancestor block is isolated from above.
static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap)
static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap)
static LogicalResult deleteDeadness(RewriterBase &rewriter, MutableArrayRef< Region > regions, LiveMap &liveMap)
This class represents an argument of a Block.
Definition Value.h:306
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:318
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:315
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< pred_iterator > getPredecessors()
Definition Block.h:250
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
pred_iterator pred_begin()
Definition Block.h:246
SuccessorRange getSuccessors()
Definition Block.h:280
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
PredecessorIterator pred_iterator
Definition Block.h:245
iterator end()
Definition Block.h:154
iterator begin()
Definition Block.h:153
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition Block.cpp:198
pred_iterator pred_end()
Definition Block.h:249
A class for computing basic dominance information.
Definition Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
IRValueT get() const
Return the current value being used by this operand.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition Value.h:454
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1143
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
unsigned getNumSuccessors()
Definition Operation.h:732
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:409
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:703
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
Block * getSuccessor(unsigned index)
Definition Operation.h:734
SuccessorRange getSuccessors()
Definition Operation.h:729
result_range getResults()
Definition Operation.h:441
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:248
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition Region.cpp:45
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:233
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
iterator begin()
Definition Region.h:55
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callback provided.
Definition Region.h:296
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
void replaceOpUsesWithIf(Operation *from, ValueRange to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp, which may be in the same or another block in the same function.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class models how operands are forwarded to block arguments in control flow.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
bool isOperandProduced(unsigned index) const
Returns true if the successor operand denoted by index is produced by the operation.
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
unsigned size() const
Returns the amount of operands passed to the successor.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e. all the transitive defs of op).
bool computeTopologicalSorting(MutableArrayRef< Operation * > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Compute a topological ordering of the given ops.
LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter, MutableArrayRef< Region > regions, bool recurse=true)
Erase the unreachable blocks within the provided regions.
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op, Operation *insertionPoint, DominanceInfo &dominance)
Move the operation dependencies (producers) of op before insertionPoint, so that op itself can subsequently be moved before insertionPoint.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that would prevent erasing.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
SmallVector< Value > makeRegionIsolatedFromAbove(RewriterBase &rewriter, Region &region, llvm::function_ref< bool(Operation *)> cloneOperationIntoRegion=[](Operation *) { return false;})
Make a region isolated from above.
bool eliminateTriviallyDeadOps(RewriterBase &rewriter, Region &region, bool includeNestedRegions=true)
Remove trivially dead operations from region.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region or its descendants.
LogicalResult runRegionDCE(RewriterBase &rewriter, MutableArrayRef< Region > regions)
This function returns success if any operations or arguments were deleted, failure otherwise.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions, bool mergeBlocks=true)
Run a set of structural simplifications over the given regions.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance)
Move definitions of values (and their transitive dependencies) before insertionPoint.
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ancestors of the limit region.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static LogicalResult ignoreValueEquivalence(Value lhs, Value rhs)
Helper that can be used with isEquivalentTo above to consider ops equivalent even if their operands are not equivalent.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.