MLIR 23.0.0git
RegionUtils.cpp
Go to the documentation of this file.
1//===- RegionUtils.cpp - Region-related transformation utilities ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
13#include "mlir/IR/Block.h"
14#include "mlir/IR/Dominance.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Operation.h"
18#include "mlir/IR/Value.h"
22#include "llvm/ADT/DepthFirstIterator.h"
23#include "llvm/ADT/PostOrderIterator.h"
24#include "llvm/ADT/STLExtras.h"
25#include "llvm/ADT/SmallVectorExtras.h"
26#include "llvm/Support/DebugLog.h"
27
28#include <deque>
29#include <iterator>
30
31using namespace mlir;
32
33#define DEBUG_TYPE "region-utils"
34
36 Region &region) {
37 for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
38 if (region.isAncestor(use.getOwner()->getParentRegion()))
39 use.set(replacement);
40 }
41}
42
44 Region &region, Region &limit, function_ref<void(OpOperand *)> callback) {
45 assert(limit.isAncestor(&region) &&
46 "expected isolation limit to be an ancestor of the given region");
47
48 // Collect proper ancestors of `limit` upfront to avoid traversing the region
49 // tree for every value.
50 SmallPtrSet<Region *, 4> properAncestors;
51 for (auto *reg = limit.getParentRegion(); reg != nullptr;
52 reg = reg->getParentRegion()) {
53 properAncestors.insert(reg);
54 }
55
56 region.walk([callback, &properAncestors](Operation *op) {
57 for (OpOperand &operand : op->getOpOperands())
58 // Callback on values defined in a proper ancestor of region.
59 if (properAncestors.count(operand.get().getParentRegion()))
60 callback(&operand);
61 });
62}
63
65 MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) {
66 for (Region &region : regions)
67 visitUsedValuesDefinedAbove(region, region, callback);
68}
69
71 SetVector<Value> &values) {
72 visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
73 values.insert(operand->get());
74 });
75}
76
78 SetVector<Value> &values) {
79 for (Region &region : regions)
80 getUsedValuesDefinedAbove(region, region, values);
81}
82
83//===----------------------------------------------------------------------===//
84// Make block isolated from above.
85//===----------------------------------------------------------------------===//
86
88 RewriterBase &rewriter, Region &region,
89 llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion) {
90
91 // Get initial list of values used within region but defined above.
92 llvm::SetVector<Value> initialCapturedValues;
93 mlir::getUsedValuesDefinedAbove(region, initialCapturedValues);
94
95 std::deque<Value> worklist(initialCapturedValues.begin(),
96 initialCapturedValues.end());
99
100 llvm::SetVector<Value> finalCapturedValues;
101 SmallVector<Operation *> clonedOperations;
102 while (!worklist.empty()) {
103 Value currValue = worklist.front();
104 worklist.pop_front();
105 if (visited.count(currValue))
106 continue;
107 visited.insert(currValue);
108
109 Operation *definingOp = currValue.getDefiningOp();
110 if (!definingOp || visitedOps.count(definingOp)) {
111 finalCapturedValues.insert(currValue);
112 continue;
113 }
114 visitedOps.insert(definingOp);
115
116 if (!cloneOperationIntoRegion(definingOp)) {
117 // Defining operation isnt cloned, so add the current value to final
118 // captured values list.
119 finalCapturedValues.insert(currValue);
120 continue;
121 }
122
123 // Add all operands of the operation to the worklist and mark the op as to
124 // be cloned.
125 for (Value operand : definingOp->getOperands()) {
126 if (visited.count(operand))
127 continue;
128 worklist.push_back(operand);
129 }
130 clonedOperations.push_back(definingOp);
131 }
132
133 // The operations to be cloned need to be ordered in topological order
134 // so that they can be cloned into the region without violating use-def
135 // chains.
136 mlir::computeTopologicalSorting(clonedOperations);
137
138 OpBuilder::InsertionGuard g(rewriter);
139 // Collect types of existing block
140 Block *entryBlock = &region.front();
141 SmallVector<Type> newArgTypes =
142 llvm::to_vector(entryBlock->getArgumentTypes());
143 SmallVector<Location> newArgLocs = llvm::map_to_vector(
144 entryBlock->getArguments(), [](BlockArgument b) { return b.getLoc(); });
145
146 // Append the types of the captured values.
147 for (auto value : finalCapturedValues) {
148 newArgTypes.push_back(value.getType());
149 newArgLocs.push_back(value.getLoc());
150 }
151
152 // Create a new entry block.
153 Block *newEntryBlock =
154 rewriter.createBlock(&region, region.begin(), newArgTypes, newArgLocs);
155 auto newEntryBlockArgs = newEntryBlock->getArguments();
156
157 // Create a mapping between the captured values and the new arguments added.
158 IRMapping map;
159 auto replaceIfFn = [&](OpOperand &use) {
160 return region.isAncestor(use.getOwner()->getParentRegion());
161 };
162
163 for (auto [arg, capturedVal] :
164 llvm::zip(newEntryBlockArgs.take_back(finalCapturedValues.size()),
165 finalCapturedValues)) {
166 map.map(capturedVal, arg);
167 rewriter.replaceUsesWithIf(capturedVal, arg, replaceIfFn);
168 }
169 rewriter.setInsertionPointToStart(newEntryBlock);
170 for (auto *clonedOp : clonedOperations) {
171 Operation *newOp = rewriter.clone(*clonedOp, map);
172 rewriter.replaceOpUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn);
173 }
174 rewriter.mergeBlocks(
175 entryBlock, newEntryBlock,
176 newEntryBlock->getArguments().take_front(entryBlock->getNumArguments()));
177 return llvm::to_vector(finalCapturedValues);
178}
179
180//===----------------------------------------------------------------------===//
181// Unreachable Block Elimination
182//===----------------------------------------------------------------------===//
183
184/// Erase the unreachable blocks within the provided regions. Returns success
185/// if any blocks were erased, failure otherwise.
186// TODO: We could likely merge this with the DCE algorithm below.
188 MutableArrayRef<Region> regions) {
189 LDBG() << "Starting eraseUnreachableBlocks with " << regions.size()
190 << " regions";
191
192 // Set of blocks found to be reachable within a given region.
193 llvm::df_iterator_default_set<Block *, 16> reachable;
194 // If any blocks were found to be dead.
195 int erasedDeadBlocks = 0;
196
198 worklist.reserve(regions.size());
199 for (Region &region : regions)
200 worklist.push_back(&region);
201
202 LDBG(2) << "Initial worklist size: " << worklist.size();
203
204 while (!worklist.empty()) {
205 Region *region = worklist.pop_back_val();
206 if (region->empty()) {
207 LDBG(2) << "Skipping empty region";
208 continue;
209 }
210
211 LDBG(2) << "Processing region with " << region->getBlocks().size()
212 << " blocks";
213 if (region->getParentOp())
214 LDBG(2) << " -> for operation: "
215 << OpWithFlags(region->getParentOp(),
216 OpPrintingFlags().skipRegions());
217
218 // If this is a single block region, just collect the nested regions.
219 if (region->hasOneBlock()) {
220 for (Operation &op : region->front())
221 for (Region &region : op.getRegions())
222 worklist.push_back(&region);
223 continue;
224 }
225
226 // Mark all reachable blocks.
227 reachable.clear();
228 for (Block *block : depth_first_ext(&region->front(), reachable))
229 (void)block /* Mark all reachable blocks */;
230
231 LDBG(2) << "Found " << reachable.size() << " reachable blocks out of "
232 << region->getBlocks().size() << " total blocks";
233
234 // Collect all of the dead blocks and push the live regions onto the
235 // worklist.
236 for (Block &block : llvm::make_early_inc_range(*region)) {
237 if (!reachable.count(&block)) {
238 LDBG() << "Erasing unreachable block: " << &block;
239 block.dropAllDefinedValueUses();
240 rewriter.eraseBlock(&block);
241 ++erasedDeadBlocks;
242 continue;
243 }
244
245 // Walk any regions within this block.
246 for (Operation &op : block)
247 for (Region &region : op.getRegions())
248 worklist.push_back(&region);
249 }
250 }
251
252 LDBG() << "Finished eraseUnreachableBlocks, erased " << erasedDeadBlocks
253 << " dead blocks";
254
255 return success(erasedDeadBlocks > 0);
256}
257
258//===----------------------------------------------------------------------===//
259// Dead Code Elimination
260//===----------------------------------------------------------------------===//
261
262namespace {
263/// Data structure used to track which values have already been proved live.
264///
265/// Because Operation's can have multiple results, this data structure tracks
266/// liveness for both Value's and Operation's to avoid having to look through
267/// all Operation results when analyzing a use.
268///
269/// This data structure essentially tracks the dataflow lattice.
270/// The set of values/ops proved live increases monotonically to a fixed-point.
271class LiveMap {
272public:
273 /// Value methods.
274 bool wasProvenLive(Value value) {
275 // TODO: For results that are removable, e.g. for region based control flow,
276 // we could allow for these values to be tracked independently.
277 if (OpResult result = dyn_cast<OpResult>(value))
278 return wasProvenLive(result.getOwner());
279 return wasProvenLive(cast<BlockArgument>(value));
280 }
281 bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
282 void setProvedLive(Value value) {
283 // TODO: For results that are removable, e.g. for region based control flow,
284 // we could allow for these values to be tracked independently.
285 if (OpResult result = dyn_cast<OpResult>(value))
286 return setProvedLive(result.getOwner());
287 setProvedLive(cast<BlockArgument>(value));
288 }
289 void setProvedLive(BlockArgument arg) {
290 changed |= liveValues.insert(arg).second;
291 }
292
293 /// Operation methods.
294 bool wasProvenLive(Operation *op) { return liveOps.count(op); }
295 void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }
296
297 /// Methods for tracking if we have reached a fixed-point.
298 void resetChanged() { changed = false; }
299 bool hasChanged() { return changed; }
300
301private:
302 bool changed = false;
303 DenseSet<Value> liveValues;
304 DenseSet<Operation *> liveOps;
305};
306} // namespace
307
308static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
309 Operation *owner = use.getOwner();
310 unsigned operandIndex = use.getOperandNumber();
311 // This pass generally treats all uses of an op as live if the op itself is
312 // considered live. However, for successor operands to terminators we need a
313 // finer-grained notion where we deduce liveness for operands individually.
314 // The reason for this is easiest to think about in terms of a classical phi
315 // node based SSA IR, where each successor operand is really an operand to a
316 // *separate* phi node, rather than all operands to the branch itself as with
317 // the block argument representation that MLIR uses.
318 //
319 // And similarly, because each successor operand is really an operand to a phi
320 // node, rather than to the terminator op itself, a terminator op can't e.g.
321 // "print" the value of a successor operand.
322 if (owner->hasTrait<OpTrait::IsTerminator>()) {
323 if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
324 if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
325 return !liveMap.wasProvenLive(*arg);
326 return false;
327 }
328 return false;
329}
330
331static void processValue(Value value, LiveMap &liveMap) {
332 bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) {
333 if (isUseSpeciallyKnownDead(use, liveMap))
334 return false;
335 return liveMap.wasProvenLive(use.getOwner());
336 });
337 if (provedLive)
338 liveMap.setProvedLive(value);
339}
340
341static void propagateLiveness(Region &region, LiveMap &liveMap);
342
343static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
344 // Terminators are always live.
345 liveMap.setProvedLive(op);
346
347 // Check to see if we can reason about the successor operands and mutate them.
348 BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
349 if (!branchInterface) {
350 for (Block *successor : op->getSuccessors())
351 for (BlockArgument arg : successor->getArguments())
352 liveMap.setProvedLive(arg);
353 return;
354 }
355
356 // If we can't reason about the operand to a successor, conservatively mark
357 // it as live.
358 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
359 SuccessorOperands successorOperands =
360 branchInterface.getSuccessorOperands(i);
361 for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount();
362 opI != opE; ++opI)
363 liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI));
364 }
365}
366
367static void propagateLiveness(Operation *op, LiveMap &liveMap) {
368 // Recurse on any regions the op has.
369 for (Region &region : op->getRegions())
370 propagateLiveness(region, liveMap);
371
372 // Process terminator operations.
374 return propagateTerminatorLiveness(op, liveMap);
375
376 // Don't reprocess live operations.
377 if (liveMap.wasProvenLive(op))
378 return;
379
380 // Process the op itself.
381 if (!wouldOpBeTriviallyDead(op))
382 return liveMap.setProvedLive(op);
383
384 // If the op isn't intrinsically alive, check it's results.
385 for (Value value : op->getResults())
386 processValue(value, liveMap);
387}
388
389static void propagateLiveness(Region &region, LiveMap &liveMap) {
390 if (region.empty())
391 return;
392
393 for (Block *block : llvm::post_order(&region.front())) {
394 // We process block arguments after the ops in the block, to promote
395 // faster convergence to a fixed point (we try to visit uses before defs).
396 for (Operation &op : llvm::reverse(block->getOperations()))
397 propagateLiveness(&op, liveMap);
398
399 // We currently do not remove entry block arguments, so there is no need to
400 // track their liveness.
401 // TODO: We could track these and enable removing dead operands/arguments
402 // from region control flow operations.
403 if (block->isEntryBlock())
404 continue;
405
406 for (Value value : block->getArguments()) {
407 if (!liveMap.wasProvenLive(value))
408 processValue(value, liveMap);
409 }
410 }
411}
412
414 LiveMap &liveMap) {
415 BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
416 if (!branchOp)
417 return;
418
419 for (unsigned succI = 0, succE = terminator->getNumSuccessors();
420 succI < succE; succI++) {
421 // Iterating successors in reverse is not strictly needed, since we
422 // aren't erasing any successors. But it is slightly more efficient
423 // since it will promote later operands of the terminator being erased
424 // first, reducing the quadratic-ness.
425 unsigned succ = succE - succI - 1;
426 SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ);
427 Block *successor = terminator->getSuccessor(succ);
428
429 for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) {
430 // Iterating args in reverse is needed for correctness, to avoid
431 // shifting later args when earlier args are erased.
432 unsigned arg = argE - argI - 1;
433 if (!liveMap.wasProvenLive(successor->getArgument(arg)))
434 succOperands.erase(arg);
435 }
436 }
437}
438
439static LogicalResult deleteDeadness(RewriterBase &rewriter,
441 LiveMap &liveMap) {
442 bool erasedAnything = false;
443 for (Region &region : regions) {
444 if (region.empty())
445 continue;
446 bool hasSingleBlock = region.hasOneBlock();
447
448 // Delete every operation that is not live. Graph regions may have cycles
449 // in the use-def graph, so we must explicitly dropAllUses() from each
450 // operation as we erase it. Visiting the operations in post-order
451 // guarantees that in SSA CFG regions value uses are removed before defs,
452 // which makes dropAllUses() a no-op.
453 for (Block *block : llvm::post_order(&region.front())) {
454 if (!hasSingleBlock)
455 eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
456 for (Operation &childOp :
457 llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
458 if (!liveMap.wasProvenLive(&childOp)) {
459 erasedAnything = true;
460 childOp.dropAllUses();
461 rewriter.eraseOp(&childOp);
462 } else {
463 erasedAnything |= succeeded(
464 deleteDeadness(rewriter, childOp.getRegions(), liveMap));
465 }
466 }
467 }
468 // Delete block arguments.
469 // The entry block has an unknown contract with their enclosing block, so
470 // skip it.
471 for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
472 block.eraseArguments(
473 [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
474 }
475 }
476 return success(erasedAnything);
477}
478
479// This function performs a simple dead code elimination algorithm over the
480// given regions.
481//
482// The overall goal is to prove that Values are dead, which allows deleting ops
483// and block arguments.
484//
485// This uses an optimistic algorithm that assumes everything is dead until
486// proved otherwise, allowing it to delete recursively dead cycles.
487//
488// This is a simple fixed-point dataflow analysis algorithm on a lattice
489// {Dead,Alive}. Because liveness flows backward, we generally try to
490// iterate everything backward to speed up convergence to the fixed-point. This
491// allows for being able to delete recursively dead cycles of the use-def graph,
492// including block arguments.
493//
494// This function returns success if any operations or arguments were deleted,
495// failure otherwise.
496LogicalResult mlir::runRegionDCE(RewriterBase &rewriter,
497 MutableArrayRef<Region> regions) {
498 LiveMap liveMap;
499 do {
500 liveMap.resetChanged();
501
502 for (Region &region : regions)
503 propagateLiveness(region, liveMap);
504 } while (liveMap.hasChanged());
505
506 return deleteDeadness(rewriter, regions, liveMap);
507}
508
509//===----------------------------------------------------------------------===//
510// Block Merging
511//===----------------------------------------------------------------------===//
512
513//===----------------------------------------------------------------------===//
514// BlockEquivalenceData
515//===----------------------------------------------------------------------===//
516
517namespace {
518/// This class contains the information for comparing the equivalencies of two
519/// blocks. Blocks are considered equivalent if they contain the same operations
520/// in the same order. The only allowed divergence is for operands that come
521/// from sources outside of the parent block, i.e. the uses of values produced
522/// within the block must be equivalent.
523/// e.g.,
524/// Equivalent:
525/// ^bb1(%arg0: i32)
526/// return %arg0, %foo : i32, i32
527/// ^bb2(%arg1: i32)
528/// return %arg1, %bar : i32, i32
529/// Not Equivalent:
530/// ^bb1(%arg0: i32)
531/// return %foo, %arg0 : i32, i32
532/// ^bb2(%arg1: i32)
533/// return %arg1, %bar : i32, i32
534struct BlockEquivalenceData {
535 BlockEquivalenceData(Block *block);
536
537 /// Return the order index for the given value that is within the block of
538 /// this data.
539 unsigned getOrderOf(Value value) const;
540
541 /// The block this data refers to.
542 Block *block;
543 /// A hash value for this block.
544 llvm::hash_code hash;
545 /// A map of result producing operations to their relative orders within this
546 /// block. The order of an operation is the number of defined values that are
547 /// produced within the block before this operation.
549};
550} // namespace
551
552BlockEquivalenceData::BlockEquivalenceData(Block *block)
553 : block(block), hash(0) {
554 unsigned orderIt = block->getNumArguments();
555 for (Operation &op : *block) {
556 if (unsigned numResults = op.getNumResults()) {
557 opOrderIndex.try_emplace(&op, orderIt);
558 orderIt += numResults;
559 }
564 hash = llvm::hash_combine(hash, opHash);
565 }
566}
567
568unsigned BlockEquivalenceData::getOrderOf(Value value) const {
569 assert(value.getParentBlock() == block && "expected value of this block");
570
571 // Arguments use the argument number as the order index.
572 if (BlockArgument arg = dyn_cast<BlockArgument>(value))
573 return arg.getArgNumber();
574
575 // Otherwise, the result order is offset from the parent op's order.
576 OpResult result = cast<OpResult>(value);
577 auto opOrderIt = opOrderIndex.find(result.getDefiningOp());
578 assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
579 return opOrderIt->second + result.getResultNumber();
580}
581
582//===----------------------------------------------------------------------===//
583// BlockMergeCluster
584//===----------------------------------------------------------------------===//
585
namespace {
/// This class represents a cluster of blocks to be merged together.
class BlockMergeCluster {
public:
  BlockMergeCluster(BlockEquivalenceData &&leaderData)
      : leaderData(std::move(leaderData)) {}

  /// Attempt to add the given block to this cluster. Returns success if the
  /// block was merged, failure otherwise.
  LogicalResult addToCluster(BlockEquivalenceData &blockData);

  /// Try to merge all of the blocks within this cluster into the leader block.
  LogicalResult merge(RewriterBase &rewriter);

private:
  /// The equivalence data for the leader of the cluster.
  BlockEquivalenceData leaderData;

  /// The set of blocks that can be merged into the leader.
  llvm::SmallSetVector<Block *, 1> blocksToMerge;

  /// A set of operand+index pairs that correspond to operands that need to be
  /// replaced by arguments when the cluster gets merged. Each pair is
  /// (operation index within the block, operand index on that operation);
  /// std::set keeps them sorted by operation index, which merge() relies on
  /// when advancing the per-block iterators.
  std::set<std::pair<int, int>> operandsToMerge;
};
} // namespace
612
613LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
614 if (leaderData.hash != blockData.hash)
615 return failure();
616 Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
617 if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes())
618 return failure();
619
620 // A set of operands that mismatch between the leader and the new block.
621 SmallVector<std::pair<int, int>, 8> mismatchedOperands;
622 auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
623 auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
624 for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
625 // Check that the operations are equivalent.
628 /*markEquivalent=*/nullptr,
629 OperationEquivalence::Flags::IgnoreLocations))
630 return failure();
631
632 // Compare the operands of the two operations. If the operand is within
633 // the block, it must refer to the same operation.
634 auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
635 for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) {
636 Value lhsOperand = lhsOperands[operand];
637 Value rhsOperand = rhsOperands[operand];
638 if (lhsOperand == rhsOperand)
639 continue;
640 // Check that the types of the operands match.
641 if (lhsOperand.getType() != rhsOperand.getType())
642 return failure();
643
644 // Check that these uses are both external, or both internal.
645 bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock;
646 bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock;
647 if (lhsIsInBlock != rhsIsInBlock)
648 return failure();
649 // Let the operands differ if they are defined in a different block. These
650 // will become new arguments if the blocks get merged.
651 if (!lhsIsInBlock) {
652
653 // Check whether the operands aren't the result of an immediate
654 // predecessors terminator. In that case we are not able to use it as a
655 // successor operand when branching to the merged block as it does not
656 // dominate its producing operation.
657 auto isValidSuccessorArg = [](Block *block, Value operand) {
658 if (operand.getDefiningOp() !=
659 operand.getParentBlock()->getTerminator())
660 return true;
661 return !llvm::is_contained(block->getPredecessors(),
662 operand.getParentBlock());
663 };
664
665 if (!isValidSuccessorArg(leaderBlock, lhsOperand) ||
666 !isValidSuccessorArg(mergeBlock, rhsOperand))
667 return failure();
668
669 mismatchedOperands.emplace_back(opI, operand);
670 continue;
671 }
672
673 // Otherwise, these operands must have the same logical order within the
674 // parent block.
675 if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand))
676 return failure();
677 }
678
679 // If the lhs or rhs has external uses, the blocks cannot be merged as the
680 // merged version of this operation will not be either the lhs or rhs
681 // alone (thus semantically incorrect), but some mix dependending on which
682 // block preceeded this.
683 // TODO allow merging of operations when one block does not dominate the
684 // other
685 if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
686 lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
687 return failure();
688 }
689 }
690 // Make sure that the block sizes are equivalent.
691 if (lhsIt != lhsE || rhsIt != rhsE)
692 return failure();
693
694 // If we get here, the blocks are equivalent and can be merged.
695 operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
696 blocksToMerge.insert(blockData.block);
697 return success();
698}
699
700/// Returns true if the predecessor terminators of the given block can not have
701/// their operands updated.
702static bool ableToUpdatePredOperands(Block *block) {
703 for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
704 if (!isa<BranchOpInterface>((*it)->getTerminator()))
705 return false;
706 }
707 return true;
708}
709
710/// Prunes the redundant list of new arguments. E.g., if we are passing an
711/// argument list like [x, y, z, x] this would return [x, y, z] and it would
712/// update the `block` (to whom the argument are passed to) accordingly. The new
713/// arguments are passed as arguments at the back of the block, hence we need to
714/// know how many `numOldArguments` were before, in order to correctly replace
715/// the new arguments in the block
717 const SmallVector<SmallVector<Value, 8>, 2> &newArguments,
718 RewriterBase &rewriter, unsigned numOldArguments, Block *block) {
719
720 SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
721 newArguments.size(), SmallVector<Value, 8>());
722
723 if (newArguments.empty())
724 return newArguments;
725
726 // `newArguments` is a 2D array of size `numLists` x `numArgs`
727 unsigned numLists = newArguments.size();
728 unsigned numArgs = newArguments[0].size();
729
730 // Map that for each arg index contains the index that we can use in place of
731 // the original index. E.g., if we have newArgs = [x, y, z, x], we will have
732 // idxToReplacement[3] = 0
733 llvm::DenseMap<unsigned, unsigned> idxToReplacement;
734
735 // This is a useful data structure to track the first appearance of a Value
736 // on a given list of arguments
737 DenseMap<Value, unsigned> firstValueToIdx;
738 for (unsigned j = 0; j < numArgs; ++j) {
739 Value newArg = newArguments[0][j];
740 firstValueToIdx.try_emplace(newArg, j);
741 }
742
743 // Go through the first list of arguments (list 0).
744 for (unsigned j = 0; j < numArgs; ++j) {
745 // Look back to see if there are possible redundancies in list 0. Please
746 // note that we are using a map to annotate when an argument was seen first
747 // to avoid a O(N^2) algorithm. This has the drawback that if we have two
748 // lists like:
749 // list0: [%a, %a, %a]
750 // list1: [%c, %b, %b]
751 // We cannot simplify it, because firstValueToIdx[%a] = 0, but we cannot
752 // point list1[1](==%b) or list1[2](==%b) to list1[0](==%c). However, since
753 // the number of arguments can be potentially unbounded we cannot afford a
754 // O(N^2) algorithm (to search to all the possible pairs) and we need to
755 // accept the trade-off.
756 unsigned k = firstValueToIdx[newArguments[0][j]];
757 if (k == j)
758 continue;
759
760 bool shouldReplaceJ = true;
761 unsigned replacement = k;
762 // If a possible redundancy is found, then scan the other lists: we
763 // can prune the arguments if and only if they are redundant in every
764 // list.
765 for (unsigned i = 1; i < numLists; ++i)
766 shouldReplaceJ =
767 shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
768 // Save the replacement.
769 if (shouldReplaceJ)
770 idxToReplacement[j] = replacement;
771 }
772
773 // Populate the pruned argument list.
774 for (unsigned i = 0; i < numLists; ++i)
775 for (unsigned j = 0; j < numArgs; ++j)
776 if (!idxToReplacement.contains(j))
777 newArgumentsPruned[i].push_back(newArguments[i][j]);
778
779 // Replace the block's redundant arguments.
780 SmallVector<unsigned> toErase;
781 for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
782 if (idxToReplacement.contains(idx)) {
783 Value oldArg = block->getArgument(numOldArguments + idx);
784 Value newArg =
785 block->getArgument(numOldArguments + idxToReplacement[idx]);
786 rewriter.replaceAllUsesWith(oldArg, newArg);
787 toErase.push_back(numOldArguments + idx);
788 }
789 }
790
791 // Erase the block's redundant arguments.
792 for (unsigned idxToErase : llvm::reverse(toErase))
793 block->eraseArgument(idxToErase);
794 return newArgumentsPruned;
795}
796
/// Merge every block in `blocksToMerge` into the leader block: mismatched
/// external operands become new leader-block arguments, predecessor branch
/// terminators are extended with the matching forwarded values, and the merged
/// blocks are then replaced by the leader and erased.
LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
  // Don't consider clusters that don't have blocks to merge.
  if (blocksToMerge.empty())
    return failure();

  Block *leaderBlock = leaderData.block;
  if (!operandsToMerge.empty()) {
    // If the cluster has operands to merge, verify that the predecessor
    // terminators of each of the blocks can have their successor operands
    // updated.
    // TODO: We could try and sub-partition this cluster if only some blocks
    // cause the mismatch.
    if (!ableToUpdatePredOperands(leaderBlock) ||
        !llvm::all_of(blocksToMerge, ableToUpdatePredOperands))
      return failure();

    // Collect the iterators for each of the blocks to merge. We will walk all
    // of the iterators at once to avoid operand index invalidation.
    SmallVector<Block::iterator, 2> blockIterators;
    blockIterators.reserve(blocksToMerge.size() + 1);
    blockIterators.push_back(leaderBlock->begin());
    for (Block *mergeBlock : blocksToMerge)
      blockIterators.push_back(mergeBlock->begin());

    // Update each of the predecessor terminators with the new arguments.
    // newArguments[0] is the leader's operand list; newArguments[i+1] belongs
    // to the i-th block in `blocksToMerge`.
    SmallVector<SmallVector<Value, 8>, 2> newArguments(
        1 + blocksToMerge.size(),
        SmallVector<Value, 8>(operandsToMerge.size()));
    unsigned curOpIndex = 0;
    unsigned numOldArguments = leaderBlock->getNumArguments();
    for (const auto &it : llvm::enumerate(operandsToMerge)) {
      // `operandsToMerge` is sorted by op index, so each iterator only needs
      // to advance by the delta from the previous mismatched operation.
      unsigned nextOpOffset = it.value().first - curOpIndex;
      curOpIndex = it.value().first;

      // Process the operand for each of the block iterators.
      for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
        Block::iterator &blockIter = blockIterators[i];
        std::advance(blockIter, nextOpOffset);
        auto &operand = blockIter->getOpOperand(it.value().second);
        newArguments[i][it.index()] = operand.get();

        // Update the operand and insert an argument if this is the leader.
        if (i == 0) {
          Value operandVal = operand.get();
          operand.set(leaderBlock->addArgument(operandVal.getType(),
                                               operandVal.getLoc()));
        }
      }
    }

    // Prune redundant arguments and update the leader block argument list
    newArguments = pruneRedundantArguments(newArguments, rewriter,
                                           numOldArguments, leaderBlock);

    // Update the predecessors for each of the blocks.
    auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
      for (auto predIt = block->pred_begin(), predE = block->pred_end();
           predIt != predE; ++predIt) {
        auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
        unsigned succIndex = predIt.getSuccessorIndex();
        branch.getSuccessorOperands(succIndex).append(
            newArguments[clusterIndex]);
      }
    };
    updatePredecessors(leaderBlock, /*clusterIndex=*/0);
    for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
      updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1);
  }

  // Replace all uses of the merged blocks with the leader and erase them.
  for (Block *block : blocksToMerge) {
    block->replaceAllUsesWith(leaderBlock);
    rewriter.eraseBlock(block);
  }
  return success();
}
873
874/// Identify identical blocks within the given region and merge them, inserting
875/// new block arguments as necessary. Returns success if any blocks were merged,
876/// failure otherwise.
877static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
878 Region &region) {
879 if (region.empty() || region.hasOneBlock())
880 return failure();
881
882 // Identify sets of blocks, other than the entry block, that branch to the
883 // same successors. We will use these groups to create clusters of equivalent
884 // blocks.
886 for (Block &block : llvm::drop_begin(region, 1))
887 matchingSuccessors[block.getSuccessors()].push_back(&block);
888
889 bool mergedAnyBlocks = false;
890 for (ArrayRef<Block *> blocks : llvm::make_second_range(matchingSuccessors)) {
891 if (blocks.size() == 1)
892 continue;
893
895 for (Block *block : blocks) {
896 BlockEquivalenceData data(block);
897
898 // Don't allow merging if this block has any regions.
899 // TODO: Add support for regions if necessary.
900 bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) {
901 return llvm::any_of(op.getRegions(),
902 [](Region &region) { return !region.empty(); });
903 });
904 if (hasNonEmptyRegion)
905 continue;
906
907 // Don't allow merging if this block's arguments are used outside of the
908 // original block.
909 bool argHasExternalUsers = llvm::any_of(
910 block->getArguments(), [block](mlir::BlockArgument &arg) {
911 return arg.isUsedOutsideOfBlock(block);
912 });
913 if (argHasExternalUsers)
914 continue;
915
916 // Try to add this block to an existing cluster.
917 bool addedToCluster = false;
918 for (auto &cluster : clusters)
919 if ((addedToCluster = succeeded(cluster.addToCluster(data))))
920 break;
921 if (!addedToCluster)
922 clusters.emplace_back(std::move(data));
923 }
924 for (auto &cluster : clusters)
925 mergedAnyBlocks |= succeeded(cluster.merge(rewriter));
926 }
927
928 return success(mergedAnyBlocks);
929}
930
931/// Identify identical blocks within the given regions and merge them, inserting
932/// new block arguments as necessary.
933static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
934 MutableArrayRef<Region> regions) {
935 llvm::SmallSetVector<Region *, 1> worklist;
936 for (auto &region : regions)
937 worklist.insert(&region);
938 bool anyChanged = false;
939 while (!worklist.empty()) {
940 Region *region = worklist.pop_back_val();
941 if (succeeded(mergeIdenticalBlocks(rewriter, *region))) {
942 worklist.insert(region);
943 anyChanged = true;
944 }
945
946 // Add any nested regions to the worklist.
947 for (Block &block : *region)
948 for (auto &op : block)
949 for (auto &nestedRegion : op.getRegions())
950 worklist.insert(&nestedRegion);
951 }
952
953 return success(anyChanged);
954}
955
956/// If a block's argument is always the same across different invocations, then
957/// drop the argument and use the value directly inside the block
958static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
959 Block &block) {
960 SmallVector<size_t> argsToErase;
961
962 // Go through the arguments of the block.
963 for (auto [argIdx, blockOperand] : llvm::enumerate(block.getArguments())) {
964 bool sameArg = true;
965 Value commonValue;
966
967 // Go through the block predecessor and flag if they pass to the block
968 // different values for the same argument.
969 for (Block::pred_iterator predIt = block.pred_begin(),
970 predE = block.pred_end();
971 predIt != predE; ++predIt) {
972 auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
973 if (!branch) {
974 sameArg = false;
975 break;
976 }
977 unsigned succIndex = predIt.getSuccessorIndex();
978 SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
979
980 // Produced operands are generated by the terminator operation itself
981 // (e.g., results of an async call) and cannot be forwarded or dropped.
982 if (succOperands.isOperandProduced(argIdx)) {
983 sameArg = false;
984 break;
985 }
986
987 // Get the forwarded operand value using operator[] which correctly
988 // adjusts for the produced operand offset.
989 Value operandValue = succOperands[argIdx];
990 if (!commonValue) {
991 commonValue = operandValue;
992 continue;
993 }
994 if (operandValue != commonValue) {
995 sameArg = false;
996 break;
997 }
998 }
999
1000 // If they are passing the same value, drop the argument.
1001 if (commonValue && sameArg) {
1002 argsToErase.push_back(argIdx);
1003
1004 // Remove the argument from the block.
1005 rewriter.replaceAllUsesWith(blockOperand, commonValue);
1006 }
1007 }
1008
1009 // Remove the arguments.
1010 for (size_t argIdx : llvm::reverse(argsToErase)) {
1011 block.eraseArgument(argIdx);
1012
1013 // Remove the argument from the branch ops.
1014 for (auto predIt = block.pred_begin(), predE = block.pred_end();
1015 predIt != predE; ++predIt) {
1016 auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
1017 unsigned succIndex = predIt.getSuccessorIndex();
1018 SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
1019 succOperands.erase(argIdx);
1020 }
1021 }
1022 return success(!argsToErase.empty());
1023}
1024
1025/// This optimization drops redundant argument to blocks. I.e., if a given
1026/// argument to a block receives the same value from each of the block
1027/// predecessors, we can remove the argument from the block and use directly the
1028/// original value. This is a simple example:
1029///
1030/// %cond = llvm.call @rand() : () -> i1
1031/// %val0 = llvm.mlir.constant(1 : i64) : i64
1032/// %val1 = llvm.mlir.constant(2 : i64) : i64
1033/// %val2 = llvm.mlir.constant(3 : i64) : i64
1034/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
1035/// : i64)
1036///
1037/// ^bb1(%arg0 : i64, %arg1 : i64):
1038/// llvm.call @foo(%arg0, %arg1)
1039///
1040/// The previous IR can be rewritten as:
1041/// %cond = llvm.call @rand() : () -> i1
1042/// %val0 = llvm.mlir.constant(1 : i64) : i64
1043/// %val1 = llvm.mlir.constant(2 : i64) : i64
1044/// %val2 = llvm.mlir.constant(3 : i64) : i64
1045/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
1046///
1047/// ^bb1(%arg0 : i64):
1048/// llvm.call @foo(%val0, %arg0)
1049///
1050static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
1051 MutableArrayRef<Region> regions) {
1052 llvm::SmallSetVector<Region *, 1> worklist;
1053 for (Region &region : regions)
1054 worklist.insert(&region);
1055 bool anyChanged = false;
1056 while (!worklist.empty()) {
1057 Region *region = worklist.pop_back_val();
1058
1059 // Add any nested regions to the worklist.
1060 for (Block &block : *region) {
1061 anyChanged =
1062 succeeded(dropRedundantArguments(rewriter, block)) || anyChanged;
1063
1064 for (Operation &op : block)
1065 for (Region &nestedRegion : op.getRegions())
1066 worklist.insert(&nestedRegion);
1067 }
1068 }
1069 return success(anyChanged);
1070}
1071
1072//===----------------------------------------------------------------------===//
1073// Region Simplification
1074//===----------------------------------------------------------------------===//
1075
1076/// Run a set of structural simplifications over the given regions. This
1077/// includes transformations like unreachable block elimination, dead argument
1078/// elimination, as well as some other DCE. This function returns success if any
1079/// of the regions were simplified, failure otherwise.
1080LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
1082 bool mergeBlocks) {
1083 bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
1084 bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
1085 bool mergedIdenticalBlocks = false;
1086 bool droppedRedundantArguments = false;
1087 if (mergeBlocks) {
1088 mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
1089 droppedRedundantArguments =
1090 succeeded(dropRedundantArguments(rewriter, regions));
1091 }
1092 return success(eliminatedBlocks || eliminatedOpsOrArgs ||
1093 mergedIdenticalBlocks || droppedRedundantArguments);
1094}
1095
1096//===---------------------------------------------------------------------===//
1097// Move operation dependencies
1098//===---------------------------------------------------------------------===//
1099
1100/// Check if moving operations in the slice before `insertionPoint` would break
1101/// dominance due to block argument operands. Returns true if all block args
1102/// dominate the insertion point (no issue), false otherwise. If `failingOp` is
1103/// provided, it will be set to the first problematic op.
1104///
1105/// For operands defined by ops: either the defining op is in the slice (so
1106/// dominance preserved), or it already dominates insertionPoint (otherwise it
1107/// would be in the slice). So we only need to check block argument operands,
1108/// both as direct operands and as values captured inside regions.
1110 const llvm::SetVector<Operation *> &slice, Operation *insertionPoint,
1111 DominanceInfo &dominance, Operation **failingOp = nullptr) {
1112 Block *insertionBlock = insertionPoint->getBlock();
1113
1114 // Returns true if the block arg dominates, false otherwise. Sets failingOp
1115 // on failure.
1116 auto argDominates = [&](BlockArgument arg, Operation *op) {
1117 Block *argBlock = arg.getOwner();
1118 bool dominates = argBlock == insertionBlock ||
1119 dominance.dominates(argBlock, insertionBlock);
1120 if (!dominates && failingOp)
1121 *failingOp = op;
1122 return dominates;
1123 };
1124
1125 for (Operation *op : slice) {
1126 // Check direct operands.
1127 for (Value operand : op->getOperands()) {
1128 auto arg = dyn_cast<BlockArgument>(operand);
1129 if (!arg)
1130 continue;
1131 if (!argDominates(arg, op))
1132 return false;
1133 }
1134
1135 // Check block arguments captured inside regions. Process one region at a
1136 // time to enable early exit without collecting values from all regions.
1137 for (Region &region : op->getRegions()) {
1138 SetVector<Value> capturedValues;
1139 getUsedValuesDefinedAbove(region, region, capturedValues);
1140 for (Value val : capturedValues) {
1141 auto arg = dyn_cast<BlockArgument>(val);
1142 if (!arg)
1143 continue;
1144 if (!argDominates(arg, op))
1145 return false;
1146 }
1147 }
1148 }
1149 return true;
1150}
1151
1152/// Check if any region between an operation and an ancestor block is
1153/// isolated from above. If so, moving the operation out would break
1154/// the isolation semantics.
1155static bool hasIsolatedRegionBetween(Operation *op, Block *ancestorBlock) {
1156 Region *ancestorRegion = ancestorBlock->getParent();
1157
1158 // Walk up from the op's region to find if there's an isolated region
1159 // between the op and the ancestor.
1160 Region *region = op->getParentRegion();
1161 while (region && region != ancestorRegion) {
1162 Operation *parentOp = region->getParentOp();
1163 if (!parentOp)
1164 break;
1165
1166 if (parentOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
1167 return true;
1168
1169 region = parentOp->getParentRegion();
1170 }
1171 return false;
1172}
1173
1175 Operation *op,
1176 Operation *insertionPoint,
1177 DominanceInfo &dominance) {
1178 Block *insertionBlock = insertionPoint->getBlock();
1179
1180 // If `insertionPoint` does not dominate `op`, do nothing.
1181 if (!dominance.properlyDominates(insertionPoint, op)) {
1182 return rewriter.notifyMatchFailure(op,
1183 "insertion point does not dominate op");
1184 }
1185
1186 // Verify we're not crossing an isolated region.
1187 if (hasIsolatedRegionBetween(op, insertionBlock)) {
1188 return rewriter.notifyMatchFailure(
1189 op, "cannot move operation across isolated-from-above region");
1190 }
1191
1192 // Find the backward slice of operation for each `Value` the operation
1193 // depends on. Prune the slice to only include operations not already
1194 // dominated by the `insertionPoint`.
1196 options.inclusive = false;
1197 options.omitUsesFromAbove = false;
1198 // Block arguments cannot be moved; dominance check handles this case.
1199 options.omitBlockArguments = true;
1200 bool dependsOnSideEffectingOp = false;
1201 options.filter = [&](Operation *sliceBoundaryOp) {
1202 // Skip the root op - we're moving its dependencies, not the op itself.
1203 // The root op is filtered out by options.inclusive = false anyway.
1204 if (sliceBoundaryOp == op)
1205 return true;
1206 bool dominated =
1207 dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
1208 // Op is already before insertion point, no need to include in slice.
1209 if (dominated)
1210 return false;
1211 // Op needs to move but is side-effecting - stop traversal early.
1212 if (!isPure(sliceBoundaryOp)) {
1213 dependsOnSideEffectingOp = true;
1214 return false;
1215 }
1216 return true;
1217 };
1219 LogicalResult result = getBackwardSlice(op, &slice, options);
1220 assert(result.succeeded() && "expected a backward slice");
1221 (void)result;
1222
1223 // Check if any operation in the slice is side-effecting.
1224 if (dependsOnSideEffectingOp) {
1225 return rewriter.notifyMatchFailure(
1226 op, "cannot move operation with side-effecting dependencies");
1227 }
1228
1229 // If the slice contains `insertionPoint` cannot move the dependencies.
1230 if (slice.contains(insertionPoint)) {
1231 return rewriter.notifyMatchFailure(
1232 op,
1233 "cannot move dependencies before operation in backward slice of op");
1234 }
1235
1236 // Verify no operation in the slice uses a block argument that wouldn't
1237 // dominate at the new location.
1238 Operation *badOp = nullptr;
1239 if (!blockArgsDominateInsertionPoint(slice, insertionPoint, dominance,
1240 &badOp)) {
1241 return rewriter.notifyMatchFailure(
1242 badOp, "moving op would break dominance for block argument operand");
1243 }
1244
1245 // We should move the slice in topological order, but `getBackwardSlice`
1246 // already does that. So no need to sort again.
1247 for (Operation *op : slice) {
1248 rewriter.moveOpBefore(op, insertionPoint);
1249 }
1250 return success();
1251}
1252
1254 Operation *op,
1255 Operation *insertionPoint) {
1256 DominanceInfo dominance(op);
1257 return moveOperationDependencies(rewriter, op, insertionPoint, dominance);
1258}
1259
1261 ValueRange values,
1262 Operation *insertionPoint,
1263 DominanceInfo &dominance) {
1264 // Remove the values that already dominate the insertion point.
1265 SmallVector<Value> prunedValues;
1266 for (auto value : values) {
1267 if (dominance.properlyDominates(value, insertionPoint))
1268 continue;
1269 // Block arguments are not supported.
1270 if (isa<BlockArgument>(value)) {
1271 return rewriter.notifyMatchFailure(
1272 insertionPoint,
1273 "unsupported case of moving block argument before insertion point");
1274 }
1275
1276 Block *insertionBlock = insertionPoint->getBlock();
1277 Operation *definingOp = value.getDefiningOp();
1278 Block *definingBlock = definingOp->getBlock();
1279
1280 // Verify we're not crossing an isolated region.
1281 if (hasIsolatedRegionBetween(definingOp, insertionBlock)) {
1282 return rewriter.notifyMatchFailure(
1283 insertionPoint,
1284 "cannot move value definition across isolated-from-above region");
1285 }
1286
1287 // Verify the insertion point's block dominates the defining block,
1288 // otherwise we're trying to move "backwards" in the CFG which doesn't
1289 // make sense.
1290 if (!dominance.dominates(insertionBlock, definingBlock)) {
1291 return rewriter.notifyMatchFailure(
1292 insertionPoint,
1293 "insertion point block does not dominate the value's defining "
1294 "block");
1295 }
1296 prunedValues.push_back(value);
1297 }
1298
1299 // Find the backward slice of operation for each `Value` the operation
1300 // depends on. Prune the slice to only include operations not already
1301 // dominated by the `insertionPoint`
1303 options.inclusive = true;
1304 options.omitUsesFromAbove = false;
1305 // Block arguments cannot be moved, so we stop the slice computation there.
1306 // If an op uses a block argument that wouldn't dominate at the new location,
1307 // the dominance check will catch it.
1308 options.omitBlockArguments = true;
1309 bool dependsOnSideEffectingOp = false;
1310 options.filter = [&](Operation *sliceBoundaryOp) {
1311 bool dominated =
1312 dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
1313 // Op is already before insertion point, no need to include in slice.
1314 if (dominated)
1315 return false;
1316 // Op needs to move but is side-effecting - stop traversal early.
1317 if (!isPure(sliceBoundaryOp)) {
1318 dependsOnSideEffectingOp = true;
1319 return false;
1320 }
1321 return true;
1322 };
1324 for (auto value : prunedValues) {
1325 LogicalResult result = getBackwardSlice(value, &slice, options);
1326 assert(result.succeeded() && "expected a backward slice");
1327 (void)result;
1328 }
1329
1330 // Check if any operation in the slice is side-effecting.
1331 if (dependsOnSideEffectingOp) {
1332 return rewriter.notifyMatchFailure(
1333 insertionPoint, "cannot move value definitions with side-effecting "
1334 "operations in the slice");
1335 }
1336
1337 // If the slice contains `insertionPoint` cannot move the dependencies.
1338 if (slice.contains(insertionPoint)) {
1339 return rewriter.notifyMatchFailure(
1340 insertionPoint,
1341 "cannot move dependencies before operation in backward slice of op");
1342 }
1343
1344 // Sort operations topologically. This is needed because we call
1345 // getBackwardSlice multiple times (once per value), and the combined slice
1346 // may not be in topological order when independent subgraphs interleave.
1347 mlir::topologicalSort(slice);
1348
1349 // Verify no operation in the slice uses a block argument that wouldn't
1350 // dominate at the new location.
1351 Operation *badOp = nullptr;
1352 if (!blockArgsDominateInsertionPoint(slice, insertionPoint, dominance,
1353 &badOp)) {
1354 return rewriter.notifyMatchFailure(
1355 badOp, "moving op would break dominance for block argument operand");
1356 }
1357
1358 for (Operation *op : slice)
1359 rewriter.moveOpBefore(op, insertionPoint);
1360 return success();
1361}
1362
1364 ValueRange values,
1365 Operation *insertionPoint) {
1366 DominanceInfo dominance(insertionPoint);
1367 return moveValueDefinitions(rewriter, values, insertionPoint, dominance);
1368}
return success()
static size_t hash(const T &value)
Local helper to compute std::hash for a value.
Definition IRCore.cpp:55
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static llvm::ManagedStatic< PassManagerOptions > options
static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, Region &region)
Identify identical blocks within the given region and merge them, inserting new block arguments as ne...
static void propagateLiveness(Region &region, LiveMap &liveMap)
static SmallVector< SmallVector< Value, 8 >, 2 > pruneRedundantArguments(const SmallVector< SmallVector< Value, 8 >, 2 > &newArguments, RewriterBase &rewriter, unsigned numOldArguments, Block *block)
Prunes the redundant list of new arguments.
static void processValue(Value value, LiveMap &liveMap)
static bool ableToUpdatePredOperands(Block *block)
Returns true if the predecessor terminators of the given block can not have their operands updated.
static bool blockArgsDominateInsertionPoint(const llvm::SetVector< Operation * > &slice, Operation *insertionPoint, DominanceInfo &dominance, Operation **failingOp=nullptr)
Check if moving operations in the slice before insertionPoint would break dominance due to block argu...
static void eraseTerminatorSuccessorOperands(Operation *terminator, LiveMap &liveMap)
static LogicalResult dropRedundantArguments(RewriterBase &rewriter, Block &block)
If a block's argument is always the same across different invocations, then drop the argument and use...
static bool hasIsolatedRegionBetween(Operation *op, Block *ancestorBlock)
Check if any region between an operation and an ancestor block is isolated from above.
static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap)
static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap)
static LogicalResult deleteDeadness(RewriterBase &rewriter, MutableArrayRef< Region > regions, LiveMap &liveMap)
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:318
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< pred_iterator > getPredecessors()
Definition Block.h:250
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
pred_iterator pred_begin()
Definition Block.h:246
SuccessorRange getSuccessors()
Definition Block.h:280
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
PredecessorIterator pred_iterator
Definition Block.h:245
iterator end()
Definition Block.h:154
iterator begin()
Definition Block.h:153
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition Block.cpp:198
pred_iterator pred_end()
Definition Block.h:249
A class for computing basic dominance information.
Definition Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
IRValueT get() const
Return the current value being used by this operand.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition Value.h:457
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
unsigned getNumSuccessors()
Definition Operation.h:706
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
Block * getSuccessor(unsigned index)
Definition Operation.h:708
SuccessorRange getSuccessors()
Definition Operation.h:703
result_range getResults()
Definition Operation.h:415
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition Region.cpp:45
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:222
bool empty()
Definition Region.h:60
iterator begin()
Definition Region.h:55
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition Region.h:285
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void replaceOpUsesWithIf(Operation *from, ValueRange to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class models how operands are forwarded to block arguments in control flow.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
bool isOperandProduced(unsigned index) const
Returns true if the successor operand denoted by index is produced by the operation.
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
unsigned size() const
Returns the amount of operands passed to the successor.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
bool computeTopologicalSorting(MutableArrayRef< Operation * > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Compute a topological ordering of the given ops.
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op, Operation *insertionPoint, DominanceInfo &dominance)
Move the operation dependencies (producers) of op before insertionPoint, so that op itself can subseq...
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter, MutableArrayRef< Region > regions)
Erase the unreachable blocks within the provided regions.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
SmallVector< Value > makeRegionIsolatedFromAbove(RewriterBase &rewriter, Region &region, llvm::function_ref< bool(Operation *)> cloneOperationIntoRegion=[](Operation *) { return false;})
Make a region isolated from above.
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
LogicalResult runRegionDCE(RewriterBase &rewriter, MutableArrayRef< Region > regions)
This function returns success if any operations or arguments were deleted, failure otherwise.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions, bool mergeBlocks=true)
Run a set of structural simplifications over the given regions.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance)
Move definitions of values (and their transitive dependencies) before insertionPoint.
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.
static LogicalResult ignoreValueEquivalence(Value lhs, Value rhs)
Helper that can be used with isEquivalentTo above to consider ops equivalent even if their operands a...
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.