MLIR  19.0.0git
RegionUtils.cpp
Go to the documentation of this file.
1 //===- RegionUtils.cpp - Region-related transformation utilities ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Block.h"
11 #include "mlir/IR/IRMapping.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/Value.h"
19 
20 #include "llvm/ADT/DepthFirstIterator.h"
21 #include "llvm/ADT/PostOrderIterator.h"
22 #include "llvm/ADT/SmallSet.h"
23 
24 #include <deque>
25 
26 using namespace mlir;
27 
29  Region &region) {
30  for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
31  if (region.isAncestor(use.getOwner()->getParentRegion()))
32  use.set(replacement);
33  }
34 }
35 
37  Region &region, Region &limit, function_ref<void(OpOperand *)> callback) {
38  assert(limit.isAncestor(&region) &&
39  "expected isolation limit to be an ancestor of the given region");
40 
41  // Collect proper ancestors of `limit` upfront to avoid traversing the region
42  // tree for every value.
43  SmallPtrSet<Region *, 4> properAncestors;
44  for (auto *reg = limit.getParentRegion(); reg != nullptr;
45  reg = reg->getParentRegion()) {
46  properAncestors.insert(reg);
47  }
48 
49  region.walk([callback, &properAncestors](Operation *op) {
50  for (OpOperand &operand : op->getOpOperands())
51  // Callback on values defined in a proper ancestor of region.
52  if (properAncestors.count(operand.get().getParentRegion()))
53  callback(&operand);
54  });
55 }
56 
58  MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) {
59  for (Region &region : regions)
60  visitUsedValuesDefinedAbove(region, region, callback);
61 }
62 
64  SetVector<Value> &values) {
65  visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
66  values.insert(operand->get());
67  });
68 }
69 
71  SetVector<Value> &values) {
72  for (Region &region : regions)
73  getUsedValuesDefinedAbove(region, region, values);
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // Make block isolated from above.
78 //===----------------------------------------------------------------------===//
79 
81  RewriterBase &rewriter, Region &region,
82  llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion) {
83 
84  // Get initial list of values used within region but defined above.
85  llvm::SetVector<Value> initialCapturedValues;
86  mlir::getUsedValuesDefinedAbove(region, initialCapturedValues);
87 
88  std::deque<Value> worklist(initialCapturedValues.begin(),
89  initialCapturedValues.end());
90  llvm::DenseSet<Value> visited;
91  llvm::DenseSet<Operation *> visitedOps;
92 
93  llvm::SetVector<Value> finalCapturedValues;
94  SmallVector<Operation *> clonedOperations;
95  while (!worklist.empty()) {
96  Value currValue = worklist.front();
97  worklist.pop_front();
98  if (visited.count(currValue))
99  continue;
100  visited.insert(currValue);
101 
102  Operation *definingOp = currValue.getDefiningOp();
103  if (!definingOp || visitedOps.count(definingOp)) {
104  finalCapturedValues.insert(currValue);
105  continue;
106  }
107  visitedOps.insert(definingOp);
108 
109  if (!cloneOperationIntoRegion(definingOp)) {
110  // Defining operation isnt cloned, so add the current value to final
111  // captured values list.
112  finalCapturedValues.insert(currValue);
113  continue;
114  }
115 
116  // Add all operands of the operation to the worklist and mark the op as to
117  // be cloned.
118  for (Value operand : definingOp->getOperands()) {
119  if (visited.count(operand))
120  continue;
121  worklist.push_back(operand);
122  }
123  clonedOperations.push_back(definingOp);
124  }
125 
126  // The operations to be cloned need to be ordered in topological order
127  // so that they can be cloned into the region without violating use-def
128  // chains.
129  mlir::computeTopologicalSorting(clonedOperations);
130 
131  OpBuilder::InsertionGuard g(rewriter);
132  // Collect types of existing block
133  Block *entryBlock = &region.front();
134  SmallVector<Type> newArgTypes =
135  llvm::to_vector(entryBlock->getArgumentTypes());
136  SmallVector<Location> newArgLocs = llvm::to_vector(llvm::map_range(
137  entryBlock->getArguments(), [](BlockArgument b) { return b.getLoc(); }));
138 
139  // Append the types of the captured values.
140  for (auto value : finalCapturedValues) {
141  newArgTypes.push_back(value.getType());
142  newArgLocs.push_back(value.getLoc());
143  }
144 
145  // Create a new entry block.
146  Block *newEntryBlock =
147  rewriter.createBlock(&region, region.begin(), newArgTypes, newArgLocs);
148  auto newEntryBlockArgs = newEntryBlock->getArguments();
149 
150  // Create a mapping between the captured values and the new arguments added.
151  IRMapping map;
152  auto replaceIfFn = [&](OpOperand &use) {
153  return use.getOwner()->getBlock()->getParent() == &region;
154  };
155  for (auto [arg, capturedVal] :
156  llvm::zip(newEntryBlockArgs.take_back(finalCapturedValues.size()),
157  finalCapturedValues)) {
158  map.map(capturedVal, arg);
159  rewriter.replaceUsesWithIf(capturedVal, arg, replaceIfFn);
160  }
161  rewriter.setInsertionPointToStart(newEntryBlock);
162  for (auto *clonedOp : clonedOperations) {
163  Operation *newOp = rewriter.clone(*clonedOp, map);
164  rewriter.replaceOpUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn);
165  }
166  rewriter.mergeBlocks(
167  entryBlock, newEntryBlock,
168  newEntryBlock->getArguments().take_front(entryBlock->getNumArguments()));
169  return llvm::to_vector(finalCapturedValues);
170 }
171 
172 //===----------------------------------------------------------------------===//
173 // Unreachable Block Elimination
174 //===----------------------------------------------------------------------===//
175 
176 /// Erase the unreachable blocks within the provided regions. Returns success
177 /// if any blocks were erased, failure otherwise.
178 // TODO: We could likely merge this with the DCE algorithm below.
180  MutableArrayRef<Region> regions) {
181  // Set of blocks found to be reachable within a given region.
182  llvm::df_iterator_default_set<Block *, 16> reachable;
183  // If any blocks were found to be dead.
184  bool erasedDeadBlocks = false;
185 
186  SmallVector<Region *, 1> worklist;
187  worklist.reserve(regions.size());
188  for (Region &region : regions)
189  worklist.push_back(&region);
190  while (!worklist.empty()) {
191  Region *region = worklist.pop_back_val();
192  if (region->empty())
193  continue;
194 
195  // If this is a single block region, just collect the nested regions.
196  if (std::next(region->begin()) == region->end()) {
197  for (Operation &op : region->front())
198  for (Region &region : op.getRegions())
199  worklist.push_back(&region);
200  continue;
201  }
202 
203  // Mark all reachable blocks.
204  reachable.clear();
205  for (Block *block : depth_first_ext(&region->front(), reachable))
206  (void)block /* Mark all reachable blocks */;
207 
208  // Collect all of the dead blocks and push the live regions onto the
209  // worklist.
210  for (Block &block : llvm::make_early_inc_range(*region)) {
211  if (!reachable.count(&block)) {
212  block.dropAllDefinedValueUses();
213  rewriter.eraseBlock(&block);
214  erasedDeadBlocks = true;
215  continue;
216  }
217 
218  // Walk any regions within this block.
219  for (Operation &op : block)
220  for (Region &region : op.getRegions())
221  worklist.push_back(&region);
222  }
223  }
224 
225  return success(erasedDeadBlocks);
226 }
227 
228 //===----------------------------------------------------------------------===//
229 // Dead Code Elimination
230 //===----------------------------------------------------------------------===//
231 
232 namespace {
233 /// Data structure used to track which values have already been proved live.
234 ///
235 /// Because Operation's can have multiple results, this data structure tracks
236 /// liveness for both Value's and Operation's to avoid having to look through
237 /// all Operation results when analyzing a use.
238 ///
239 /// This data structure essentially tracks the dataflow lattice.
240 /// The set of values/ops proved live increases monotonically to a fixed-point.
241 class LiveMap {
242 public:
243  /// Value methods.
244  bool wasProvenLive(Value value) {
245  // TODO: For results that are removable, e.g. for region based control flow,
246  // we could allow for these values to be tracked independently.
247  if (OpResult result = dyn_cast<OpResult>(value))
248  return wasProvenLive(result.getOwner());
249  return wasProvenLive(cast<BlockArgument>(value));
250  }
251  bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
252  void setProvedLive(Value value) {
253  // TODO: For results that are removable, e.g. for region based control flow,
254  // we could allow for these values to be tracked independently.
255  if (OpResult result = dyn_cast<OpResult>(value))
256  return setProvedLive(result.getOwner());
257  setProvedLive(cast<BlockArgument>(value));
258  }
259  void setProvedLive(BlockArgument arg) {
260  changed |= liveValues.insert(arg).second;
261  }
262 
263  /// Operation methods.
264  bool wasProvenLive(Operation *op) { return liveOps.count(op); }
265  void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }
266 
267  /// Methods for tracking if we have reached a fixed-point.
268  void resetChanged() { changed = false; }
269  bool hasChanged() { return changed; }
270 
271 private:
272  bool changed = false;
273  DenseSet<Value> liveValues;
274  DenseSet<Operation *> liveOps;
275 };
276 } // namespace
277 
278 static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
279  Operation *owner = use.getOwner();
280  unsigned operandIndex = use.getOperandNumber();
281  // This pass generally treats all uses of an op as live if the op itself is
282  // considered live. However, for successor operands to terminators we need a
283  // finer-grained notion where we deduce liveness for operands individually.
284  // The reason for this is easiest to think about in terms of a classical phi
285  // node based SSA IR, where each successor operand is really an operand to a
286  // *separate* phi node, rather than all operands to the branch itself as with
287  // the block argument representation that MLIR uses.
288  //
289  // And similarly, because each successor operand is really an operand to a phi
290  // node, rather than to the terminator op itself, a terminator op can't e.g.
291  // "print" the value of a successor operand.
292  if (owner->hasTrait<OpTrait::IsTerminator>()) {
293  if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
294  if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
295  return !liveMap.wasProvenLive(*arg);
296  return false;
297  }
298  return false;
299 }
300 
301 static void processValue(Value value, LiveMap &liveMap) {
302  bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) {
303  if (isUseSpeciallyKnownDead(use, liveMap))
304  return false;
305  return liveMap.wasProvenLive(use.getOwner());
306  });
307  if (provedLive)
308  liveMap.setProvedLive(value);
309 }
310 
311 static void propagateLiveness(Region &region, LiveMap &liveMap);
312 
313 static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
314  // Terminators are always live.
315  liveMap.setProvedLive(op);
316 
317  // Check to see if we can reason about the successor operands and mutate them.
318  BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
319  if (!branchInterface) {
320  for (Block *successor : op->getSuccessors())
321  for (BlockArgument arg : successor->getArguments())
322  liveMap.setProvedLive(arg);
323  return;
324  }
325 
326  // If we can't reason about the operand to a successor, conservatively mark
327  // it as live.
328  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
329  SuccessorOperands successorOperands =
330  branchInterface.getSuccessorOperands(i);
331  for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount();
332  opI != opE; ++opI)
333  liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI));
334  }
335 }
336 
337 static void propagateLiveness(Operation *op, LiveMap &liveMap) {
338  // Recurse on any regions the op has.
339  for (Region &region : op->getRegions())
340  propagateLiveness(region, liveMap);
341 
342  // Process terminator operations.
343  if (op->hasTrait<OpTrait::IsTerminator>())
344  return propagateTerminatorLiveness(op, liveMap);
345 
346  // Don't reprocess live operations.
347  if (liveMap.wasProvenLive(op))
348  return;
349 
350  // Process the op itself.
351  if (!wouldOpBeTriviallyDead(op))
352  return liveMap.setProvedLive(op);
353 
354  // If the op isn't intrinsically alive, check it's results.
355  for (Value value : op->getResults())
356  processValue(value, liveMap);
357 }
358 
359 static void propagateLiveness(Region &region, LiveMap &liveMap) {
360  if (region.empty())
361  return;
362 
363  for (Block *block : llvm::post_order(&region.front())) {
364  // We process block arguments after the ops in the block, to promote
365  // faster convergence to a fixed point (we try to visit uses before defs).
366  for (Operation &op : llvm::reverse(block->getOperations()))
367  propagateLiveness(&op, liveMap);
368 
369  // We currently do not remove entry block arguments, so there is no need to
370  // track their liveness.
371  // TODO: We could track these and enable removing dead operands/arguments
372  // from region control flow operations.
373  if (block->isEntryBlock())
374  continue;
375 
376  for (Value value : block->getArguments()) {
377  if (!liveMap.wasProvenLive(value))
378  processValue(value, liveMap);
379  }
380  }
381 }
382 
384  LiveMap &liveMap) {
385  BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
386  if (!branchOp)
387  return;
388 
389  for (unsigned succI = 0, succE = terminator->getNumSuccessors();
390  succI < succE; succI++) {
391  // Iterating successors in reverse is not strictly needed, since we
392  // aren't erasing any successors. But it is slightly more efficient
393  // since it will promote later operands of the terminator being erased
394  // first, reducing the quadratic-ness.
395  unsigned succ = succE - succI - 1;
396  SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ);
397  Block *successor = terminator->getSuccessor(succ);
398 
399  for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) {
400  // Iterating args in reverse is needed for correctness, to avoid
401  // shifting later args when earlier args are erased.
402  unsigned arg = argE - argI - 1;
403  if (!liveMap.wasProvenLive(successor->getArgument(arg)))
404  succOperands.erase(arg);
405  }
406  }
407 }
408 
410  MutableArrayRef<Region> regions,
411  LiveMap &liveMap) {
412  bool erasedAnything = false;
413  for (Region &region : regions) {
414  if (region.empty())
415  continue;
416  bool hasSingleBlock = llvm::hasSingleElement(region);
417 
418  // Delete every operation that is not live. Graph regions may have cycles
419  // in the use-def graph, so we must explicitly dropAllUses() from each
420  // operation as we erase it. Visiting the operations in post-order
421  // guarantees that in SSA CFG regions value uses are removed before defs,
422  // which makes dropAllUses() a no-op.
423  for (Block *block : llvm::post_order(&region.front())) {
424  if (!hasSingleBlock)
425  eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
426  for (Operation &childOp :
427  llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
428  if (!liveMap.wasProvenLive(&childOp)) {
429  erasedAnything = true;
430  childOp.dropAllUses();
431  rewriter.eraseOp(&childOp);
432  } else {
433  erasedAnything |= succeeded(
434  deleteDeadness(rewriter, childOp.getRegions(), liveMap));
435  }
436  }
437  }
438  // Delete block arguments.
439  // The entry block has an unknown contract with their enclosing block, so
440  // skip it.
441  for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
442  block.eraseArguments(
443  [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
444  }
445  }
446  return success(erasedAnything);
447 }
448 
449 // This function performs a simple dead code elimination algorithm over the
450 // given regions.
451 //
452 // The overall goal is to prove that Values are dead, which allows deleting ops
453 // and block arguments.
454 //
455 // This uses an optimistic algorithm that assumes everything is dead until
456 // proved otherwise, allowing it to delete recursively dead cycles.
457 //
458 // This is a simple fixed-point dataflow analysis algorithm on a lattice
459 // {Dead,Alive}. Because liveness flows backward, we generally try to
460 // iterate everything backward to speed up convergence to the fixed-point. This
461 // allows for being able to delete recursively dead cycles of the use-def graph,
462 // including block arguments.
463 //
464 // This function returns success if any operations or arguments were deleted,
465 // failure otherwise.
467  MutableArrayRef<Region> regions) {
468  LiveMap liveMap;
469  do {
470  liveMap.resetChanged();
471 
472  for (Region &region : regions)
473  propagateLiveness(region, liveMap);
474  } while (liveMap.hasChanged());
475 
476  return deleteDeadness(rewriter, regions, liveMap);
477 }
478 
479 //===----------------------------------------------------------------------===//
480 // Block Merging
481 //===----------------------------------------------------------------------===//
482 
483 //===----------------------------------------------------------------------===//
484 // BlockEquivalenceData
485 
486 namespace {
487 /// This class contains the information for comparing the equivalencies of two
488 /// blocks. Blocks are considered equivalent if they contain the same operations
489 /// in the same order. The only allowed divergence is for operands that come
490 /// from sources outside of the parent block, i.e. the uses of values produced
491 /// within the block must be equivalent.
492 /// e.g.,
493 /// Equivalent:
494 /// ^bb1(%arg0: i32)
495 /// return %arg0, %foo : i32, i32
496 /// ^bb2(%arg1: i32)
497 /// return %arg1, %bar : i32, i32
498 /// Not Equivalent:
499 /// ^bb1(%arg0: i32)
500 /// return %foo, %arg0 : i32, i32
501 /// ^bb2(%arg1: i32)
502 /// return %arg1, %bar : i32, i32
503 struct BlockEquivalenceData {
504  BlockEquivalenceData(Block *block);
505 
506  /// Return the order index for the given value that is within the block of
507  /// this data.
508  unsigned getOrderOf(Value value) const;
509 
510  /// The block this data refers to.
511  Block *block;
512  /// A hash value for this block.
513  llvm::hash_code hash;
514  /// A map of result producing operations to their relative orders within this
515  /// block. The order of an operation is the number of defined values that are
516  /// produced within the block before this operation.
517  DenseMap<Operation *, unsigned> opOrderIndex;
518 };
519 } // namespace
520 
521 BlockEquivalenceData::BlockEquivalenceData(Block *block)
522  : block(block), hash(0) {
523  unsigned orderIt = block->getNumArguments();
524  for (Operation &op : *block) {
525  if (unsigned numResults = op.getNumResults()) {
526  opOrderIndex.try_emplace(&op, orderIt);
527  orderIt += numResults;
528  }
529  auto opHash = OperationEquivalence::computeHash(
533  hash = llvm::hash_combine(hash, opHash);
534  }
535 }
536 
537 unsigned BlockEquivalenceData::getOrderOf(Value value) const {
538  assert(value.getParentBlock() == block && "expected value of this block");
539 
540  // Arguments use the argument number as the order index.
541  if (BlockArgument arg = dyn_cast<BlockArgument>(value))
542  return arg.getArgNumber();
543 
544  // Otherwise, the result order is offset from the parent op's order.
545  OpResult result = cast<OpResult>(value);
546  auto opOrderIt = opOrderIndex.find(result.getDefiningOp());
547  assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
548  return opOrderIt->second + result.getResultNumber();
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // BlockMergeCluster
553 
554 namespace {
555 /// This class represents a cluster of blocks to be merged together.
556 class BlockMergeCluster {
557 public:
558  BlockMergeCluster(BlockEquivalenceData &&leaderData)
559  : leaderData(std::move(leaderData)) {}
560 
561  /// Attempt to add the given block to this cluster. Returns success if the
562  /// block was merged, failure otherwise.
563  LogicalResult addToCluster(BlockEquivalenceData &blockData);
564 
565  /// Try to merge all of the blocks within this cluster into the leader block.
566  LogicalResult merge(RewriterBase &rewriter);
567 
568 private:
569  /// The equivalence data for the leader of the cluster.
570  BlockEquivalenceData leaderData;
571 
572  /// The set of blocks that can be merged into the leader.
573  llvm::SmallSetVector<Block *, 1> blocksToMerge;
574 
575  /// A set of operand+index pairs that correspond to operands that need to be
576  /// replaced by arguments when the cluster gets merged.
577  std::set<std::pair<int, int>> operandsToMerge;
578 };
579 } // namespace
580 
581 LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
582  if (leaderData.hash != blockData.hash)
583  return failure();
584  Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
585  if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes())
586  return failure();
587 
588  // A set of operands that mismatch between the leader and the new block.
589  SmallVector<std::pair<int, int>, 8> mismatchedOperands;
590  auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
591  auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
592  for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
593  // Check that the operations are equivalent.
594  if (!OperationEquivalence::isEquivalentTo(
595  &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
596  /*markEquivalent=*/nullptr,
597  OperationEquivalence::Flags::IgnoreLocations))
598  return failure();
599 
600  // Compare the operands of the two operations. If the operand is within
601  // the block, it must refer to the same operation.
602  auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
603  for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) {
604  Value lhsOperand = lhsOperands[operand];
605  Value rhsOperand = rhsOperands[operand];
606  if (lhsOperand == rhsOperand)
607  continue;
608  // Check that the types of the operands match.
609  if (lhsOperand.getType() != rhsOperand.getType())
610  return failure();
611 
612  // Check that these uses are both external, or both internal.
613  bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock;
614  bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock;
615  if (lhsIsInBlock != rhsIsInBlock)
616  return failure();
617  // Let the operands differ if they are defined in a different block. These
618  // will become new arguments if the blocks get merged.
619  if (!lhsIsInBlock) {
620 
621  // Check whether the operands aren't the result of an immediate
622  // predecessors terminator. In that case we are not able to use it as a
623  // successor operand when branching to the merged block as it does not
624  // dominate its producing operation.
625  auto isValidSuccessorArg = [](Block *block, Value operand) {
626  if (operand.getDefiningOp() !=
627  operand.getParentBlock()->getTerminator())
628  return true;
629  return !llvm::is_contained(block->getPredecessors(),
630  operand.getParentBlock());
631  };
632 
633  if (!isValidSuccessorArg(leaderBlock, lhsOperand) ||
634  !isValidSuccessorArg(mergeBlock, rhsOperand))
635  return failure();
636 
637  mismatchedOperands.emplace_back(opI, operand);
638  continue;
639  }
640 
641  // Otherwise, these operands must have the same logical order within the
642  // parent block.
643  if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand))
644  return failure();
645  }
646 
647  // If the lhs or rhs has external uses, the blocks cannot be merged as the
648  // merged version of this operation will not be either the lhs or rhs
649  // alone (thus semantically incorrect), but some mix dependending on which
650  // block preceeded this.
651  // TODO allow merging of operations when one block does not dominate the
652  // other
653  if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
654  lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
655  return failure();
656  }
657  }
658  // Make sure that the block sizes are equivalent.
659  if (lhsIt != lhsE || rhsIt != rhsE)
660  return failure();
661 
662  // If we get here, the blocks are equivalent and can be merged.
663  operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
664  blocksToMerge.insert(blockData.block);
665  return success();
666 }
667 
668 /// Returns true if the predecessor terminators of the given block can not have
669 /// their operands updated.
670 static bool ableToUpdatePredOperands(Block *block) {
671  for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
672  if (!isa<BranchOpInterface>((*it)->getTerminator()))
673  return false;
674  }
675  return true;
676 }
677 
678 LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
679  // Don't consider clusters that don't have blocks to merge.
680  if (blocksToMerge.empty())
681  return failure();
682 
683  Block *leaderBlock = leaderData.block;
684  if (!operandsToMerge.empty()) {
685  // If the cluster has operands to merge, verify that the predecessor
686  // terminators of each of the blocks can have their successor operands
687  // updated.
688  // TODO: We could try and sub-partition this cluster if only some blocks
689  // cause the mismatch.
690  if (!ableToUpdatePredOperands(leaderBlock) ||
691  !llvm::all_of(blocksToMerge, ableToUpdatePredOperands))
692  return failure();
693 
694  // Collect the iterators for each of the blocks to merge. We will walk all
695  // of the iterators at once to avoid operand index invalidation.
696  SmallVector<Block::iterator, 2> blockIterators;
697  blockIterators.reserve(blocksToMerge.size() + 1);
698  blockIterators.push_back(leaderBlock->begin());
699  for (Block *mergeBlock : blocksToMerge)
700  blockIterators.push_back(mergeBlock->begin());
701 
702  // Update each of the predecessor terminators with the new arguments.
703  SmallVector<SmallVector<Value, 8>, 2> newArguments(
704  1 + blocksToMerge.size(),
705  SmallVector<Value, 8>(operandsToMerge.size()));
706  unsigned curOpIndex = 0;
707  for (const auto &it : llvm::enumerate(operandsToMerge)) {
708  unsigned nextOpOffset = it.value().first - curOpIndex;
709  curOpIndex = it.value().first;
710 
711  // Process the operand for each of the block iterators.
712  for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
713  Block::iterator &blockIter = blockIterators[i];
714  std::advance(blockIter, nextOpOffset);
715  auto &operand = blockIter->getOpOperand(it.value().second);
716  newArguments[i][it.index()] = operand.get();
717 
718  // Update the operand and insert an argument if this is the leader.
719  if (i == 0) {
720  Value operandVal = operand.get();
721  operand.set(leaderBlock->addArgument(operandVal.getType(),
722  operandVal.getLoc()));
723  }
724  }
725  }
726  // Update the predecessors for each of the blocks.
727  auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
728  for (auto predIt = block->pred_begin(), predE = block->pred_end();
729  predIt != predE; ++predIt) {
730  auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
731  unsigned succIndex = predIt.getSuccessorIndex();
732  branch.getSuccessorOperands(succIndex).append(
733  newArguments[clusterIndex]);
734  }
735  };
736  updatePredecessors(leaderBlock, /*clusterIndex=*/0);
737  for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
738  updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1);
739  }
740 
741  // Replace all uses of the merged blocks with the leader and erase them.
742  for (Block *block : blocksToMerge) {
743  block->replaceAllUsesWith(leaderBlock);
744  rewriter.eraseBlock(block);
745  }
746  return success();
747 }
748 
749 /// Identify identical blocks within the given region and merge them, inserting
750 /// new block arguments as necessary. Returns success if any blocks were merged,
751 /// failure otherwise.
753  Region &region) {
754  if (region.empty() || llvm::hasSingleElement(region))
755  return failure();
756 
757  // Identify sets of blocks, other than the entry block, that branch to the
758  // same successors. We will use these groups to create clusters of equivalent
759  // blocks.
761  for (Block &block : llvm::drop_begin(region, 1))
762  matchingSuccessors[block.getSuccessors()].push_back(&block);
763 
764  bool mergedAnyBlocks = false;
765  for (ArrayRef<Block *> blocks : llvm::make_second_range(matchingSuccessors)) {
766  if (blocks.size() == 1)
767  continue;
768 
770  for (Block *block : blocks) {
771  BlockEquivalenceData data(block);
772 
773  // Don't allow merging if this block has any regions.
774  // TODO: Add support for regions if necessary.
775  bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) {
776  return llvm::any_of(op.getRegions(),
777  [](Region &region) { return !region.empty(); });
778  });
779  if (hasNonEmptyRegion)
780  continue;
781 
782  // Try to add this block to an existing cluster.
783  bool addedToCluster = false;
784  for (auto &cluster : clusters)
785  if ((addedToCluster = succeeded(cluster.addToCluster(data))))
786  break;
787  if (!addedToCluster)
788  clusters.emplace_back(std::move(data));
789  }
790  for (auto &cluster : clusters)
791  mergedAnyBlocks |= succeeded(cluster.merge(rewriter));
792  }
793 
794  return success(mergedAnyBlocks);
795 }
796 
797 /// Identify identical blocks within the given regions and merge them, inserting
798 /// new block arguments as necessary.
800  MutableArrayRef<Region> regions) {
801  llvm::SmallSetVector<Region *, 1> worklist;
802  for (auto &region : regions)
803  worklist.insert(&region);
804  bool anyChanged = false;
805  while (!worklist.empty()) {
806  Region *region = worklist.pop_back_val();
807  if (succeeded(mergeIdenticalBlocks(rewriter, *region))) {
808  worklist.insert(region);
809  anyChanged = true;
810  }
811 
812  // Add any nested regions to the worklist.
813  for (Block &block : *region)
814  for (auto &op : block)
815  for (auto &nestedRegion : op.getRegions())
816  worklist.insert(&nestedRegion);
817  }
818 
819  return success(anyChanged);
820 }
821 
822 //===----------------------------------------------------------------------===//
823 // Region Simplification
824 //===----------------------------------------------------------------------===//
825 
826 /// Run a set of structural simplifications over the given regions. This
827 /// includes transformations like unreachable block elimination, dead argument
828 /// elimination, as well as some other DCE. This function returns success if any
829 /// of the regions were simplified, failure otherwise.
831  MutableArrayRef<Region> regions) {
832  bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
833  bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
834  bool mergedIdenticalBlocks =
835  succeeded(mergeIdenticalBlocks(rewriter, regions));
836  return success(eliminatedBlocks || eliminatedOpsOrArgs ||
837  mergedIdenticalBlocks);
838 }
839 
841  // For each block that has not been visited yet (i.e. that has no
842  // predecessors), add it to the list as well as its successors.
843  SetVector<Block *> blocks;
844  for (Block &b : region) {
845  if (blocks.count(&b) == 0) {
846  llvm::ReversePostOrderTraversal<Block *> traversal(&b);
847  blocks.insert(traversal.begin(), traversal.end());
848  }
849  }
850  assert(blocks.size() == region.getBlocks().size() &&
851  "some blocks are not sorted");
852 
853  return blocks;
854 }
static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, Region &region)
Identify identical blocks within the given region and merge them, inserting new block arguments as ne...
static void propagateLiveness(Region &region, LiveMap &liveMap)
static void processValue(Value value, LiveMap &liveMap)
static bool ableToUpdatePredOperands(Block *block)
Returns true if the predecessor terminators of the given block can have their operands updated.
static void eraseTerminatorSuccessorOperands(Operation *terminator, LiveMap &liveMap)
static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap)
static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap)
static LogicalResult deleteDeadness(RewriterBase &rewriter, MutableArrayRef< Region > regions, LiveMap &liveMap)
This class represents an argument of a Block.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:137
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
unsigned getNumArguments()
Definition: Block.h:125
pred_iterator pred_begin()
Definition: Block.h:230
SuccessorRange getSuccessors()
Definition: Block.h:264
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:234
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
BlockArgListType getArguments()
Definition: Block.h:84
iterator end()
Definition: Block.h:141
iterator begin()
Definition: Block.h:140
pred_iterator pred_end()
Definition: Block.h:233
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:555
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Block * getSuccessor(unsigned index)
Definition: Operation.h:704
unsigned getNumSuccessors()
Definition: Operation.h:702
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
SuccessorRange getSuccessors()
Definition: Operation.h:699
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition: Region.cpp:45
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition: Region.h:222
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
iterator begin()
Definition: Region.h:55
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition: Region.h:285
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
void replaceOpUsesWithIf(Operation *from, ValueRange to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Definition: PatternMatch.h:679
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
This class models how operands are forwarded to block arguments in control flow.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
unsigned getProducedOperandCount() const
Returns the number of operands that are produced internally by the operation.
unsigned size() const
Returns the number of operands passed to the successor.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
Definition: RegionUtils.cpp:28
SetVector< Block * > getTopologicallySortedBlocks(Region &region)
Get a topologically sorted list of blocks of the given region.
bool computeTopologicalSorting(MutableArrayRef< Operation * > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Compute a topological ordering of the given ops.
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions)
Run a set of structural simplifications over the given regions.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter, MutableArrayRef< Region > regions)
Erase the unreachable blocks within the provided regions.
SmallVector< Value > makeRegionIsolatedFromAbove(RewriterBase &rewriter, Region &region, llvm::function_ref< bool(Operation *)> cloneOperationIntoRegion=[](Operation *) { return false;})
Make a region isolated from above.
Definition: RegionUtils.cpp:80
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:63
LogicalResult runRegionDCE(RewriterBase &rewriter, MutableArrayRef< Region > regions)
This function returns success if any operations or arguments were deleted, failure otherwise.
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
Definition: RegionUtils.cpp:36
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.