MLIR  20.0.0git
RegionUtils.cpp
Go to the documentation of this file.
1 //===- RegionUtils.cpp - Region-related transformation utilities ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/IRMapping.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/Value.h"
19 
20 #include "llvm/ADT/DepthFirstIterator.h"
21 #include "llvm/ADT/PostOrderIterator.h"
22 
23 #include <deque>
24 
25 using namespace mlir;
26 
28  Region &region) {
29  for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
30  if (region.isAncestor(use.getOwner()->getParentRegion()))
31  use.set(replacement);
32  }
33 }
34 
36  Region &region, Region &limit, function_ref<void(OpOperand *)> callback) {
37  assert(limit.isAncestor(&region) &&
38  "expected isolation limit to be an ancestor of the given region");
39 
40  // Collect proper ancestors of `limit` upfront to avoid traversing the region
41  // tree for every value.
42  SmallPtrSet<Region *, 4> properAncestors;
43  for (auto *reg = limit.getParentRegion(); reg != nullptr;
44  reg = reg->getParentRegion()) {
45  properAncestors.insert(reg);
46  }
47 
48  region.walk([callback, &properAncestors](Operation *op) {
49  for (OpOperand &operand : op->getOpOperands())
50  // Callback on values defined in a proper ancestor of region.
51  if (properAncestors.count(operand.get().getParentRegion()))
52  callback(&operand);
53  });
54 }
55 
57  MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) {
58  for (Region &region : regions)
59  visitUsedValuesDefinedAbove(region, region, callback);
60 }
61 
63  SetVector<Value> &values) {
64  visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
65  values.insert(operand->get());
66  });
67 }
68 
70  SetVector<Value> &values) {
71  for (Region &region : regions)
72  getUsedValuesDefinedAbove(region, region, values);
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // Make block isolated from above.
77 //===----------------------------------------------------------------------===//
78 
80  RewriterBase &rewriter, Region &region,
81  llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion) {
82 
83  // Get initial list of values used within region but defined above.
84  llvm::SetVector<Value> initialCapturedValues;
85  mlir::getUsedValuesDefinedAbove(region, initialCapturedValues);
86 
87  std::deque<Value> worklist(initialCapturedValues.begin(),
88  initialCapturedValues.end());
89  llvm::DenseSet<Value> visited;
90  llvm::DenseSet<Operation *> visitedOps;
91 
92  llvm::SetVector<Value> finalCapturedValues;
93  SmallVector<Operation *> clonedOperations;
94  while (!worklist.empty()) {
95  Value currValue = worklist.front();
96  worklist.pop_front();
97  if (visited.count(currValue))
98  continue;
99  visited.insert(currValue);
100 
101  Operation *definingOp = currValue.getDefiningOp();
102  if (!definingOp || visitedOps.count(definingOp)) {
103  finalCapturedValues.insert(currValue);
104  continue;
105  }
106  visitedOps.insert(definingOp);
107 
108  if (!cloneOperationIntoRegion(definingOp)) {
109  // Defining operation isn't cloned, so add the current value to final
110  // captured values list.
111  finalCapturedValues.insert(currValue);
112  continue;
113  }
114 
115  // Add all operands of the operation to the worklist and mark the op as to
116  // be cloned.
117  for (Value operand : definingOp->getOperands()) {
118  if (visited.count(operand))
119  continue;
120  worklist.push_back(operand);
121  }
122  clonedOperations.push_back(definingOp);
123  }
124 
125  // The operations to be cloned need to be ordered in topological order
126  // so that they can be cloned into the region without violating use-def
127  // chains.
128  mlir::computeTopologicalSorting(clonedOperations);
129 
130  OpBuilder::InsertionGuard g(rewriter);
131  // Collect types of existing block
132  Block *entryBlock = &region.front();
133  SmallVector<Type> newArgTypes =
134  llvm::to_vector(entryBlock->getArgumentTypes());
135  SmallVector<Location> newArgLocs = llvm::to_vector(llvm::map_range(
136  entryBlock->getArguments(), [](BlockArgument b) { return b.getLoc(); }));
137 
138  // Append the types of the captured values.
139  for (auto value : finalCapturedValues) {
140  newArgTypes.push_back(value.getType());
141  newArgLocs.push_back(value.getLoc());
142  }
143 
144  // Create a new entry block.
145  Block *newEntryBlock =
146  rewriter.createBlock(&region, region.begin(), newArgTypes, newArgLocs);
147  auto newEntryBlockArgs = newEntryBlock->getArguments();
148 
149  // Create a mapping between the captured values and the new arguments added.
150  IRMapping map;
151  auto replaceIfFn = [&](OpOperand &use) {
152  return use.getOwner()->getBlock()->getParent() == &region;
153  };
154  for (auto [arg, capturedVal] :
155  llvm::zip(newEntryBlockArgs.take_back(finalCapturedValues.size()),
156  finalCapturedValues)) {
157  map.map(capturedVal, arg);
158  rewriter.replaceUsesWithIf(capturedVal, arg, replaceIfFn);
159  }
160  rewriter.setInsertionPointToStart(newEntryBlock);
161  for (auto *clonedOp : clonedOperations) {
162  Operation *newOp = rewriter.clone(*clonedOp, map);
163  rewriter.replaceOpUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn);
164  }
165  rewriter.mergeBlocks(
166  entryBlock, newEntryBlock,
167  newEntryBlock->getArguments().take_front(entryBlock->getNumArguments()));
168  return llvm::to_vector(finalCapturedValues);
169 }
170 
171 //===----------------------------------------------------------------------===//
172 // Unreachable Block Elimination
173 //===----------------------------------------------------------------------===//
174 
175 /// Erase the unreachable blocks within the provided regions. Returns success
176 /// if any blocks were erased, failure otherwise.
177 // TODO: We could likely merge this with the DCE algorithm below.
179  MutableArrayRef<Region> regions) {
180  // Set of blocks found to be reachable within a given region.
181  llvm::df_iterator_default_set<Block *, 16> reachable;
182  // If any blocks were found to be dead.
183  bool erasedDeadBlocks = false;
184 
185  SmallVector<Region *, 1> worklist;
186  worklist.reserve(regions.size());
187  for (Region &region : regions)
188  worklist.push_back(&region);
189  while (!worklist.empty()) {
190  Region *region = worklist.pop_back_val();
191  if (region->empty())
192  continue;
193 
194  // If this is a single block region, just collect the nested regions.
195  if (std::next(region->begin()) == region->end()) {
196  for (Operation &op : region->front())
197  for (Region &region : op.getRegions())
198  worklist.push_back(&region);
199  continue;
200  }
201 
202  // Mark all reachable blocks.
203  reachable.clear();
204  for (Block *block : depth_first_ext(&region->front(), reachable))
205  (void)block /* Mark all reachable blocks */;
206 
207  // Collect all of the dead blocks and push the live regions onto the
208  // worklist.
209  for (Block &block : llvm::make_early_inc_range(*region)) {
210  if (!reachable.count(&block)) {
211  block.dropAllDefinedValueUses();
212  rewriter.eraseBlock(&block);
213  erasedDeadBlocks = true;
214  continue;
215  }
216 
217  // Walk any regions within this block.
218  for (Operation &op : block)
219  for (Region &region : op.getRegions())
220  worklist.push_back(&region);
221  }
222  }
223 
224  return success(erasedDeadBlocks);
225 }
226 
227 //===----------------------------------------------------------------------===//
228 // Dead Code Elimination
229 //===----------------------------------------------------------------------===//
230 
231 namespace {
232 /// Data structure used to track which values have already been proved live.
233 ///
234 /// Because Operation's can have multiple results, this data structure tracks
235 /// liveness for both Value's and Operation's to avoid having to look through
236 /// all Operation results when analyzing a use.
237 ///
238 /// This data structure essentially tracks the dataflow lattice.
239 /// The set of values/ops proved live increases monotonically to a fixed-point.
240 class LiveMap {
241 public:
242  /// Value methods.
243  bool wasProvenLive(Value value) {
244  // TODO: For results that are removable, e.g. for region based control flow,
245  // we could allow for these values to be tracked independently.
246  if (OpResult result = dyn_cast<OpResult>(value))
247  return wasProvenLive(result.getOwner());
248  return wasProvenLive(cast<BlockArgument>(value));
249  }
250  bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
251  void setProvedLive(Value value) {
252  // TODO: For results that are removable, e.g. for region based control flow,
253  // we could allow for these values to be tracked independently.
254  if (OpResult result = dyn_cast<OpResult>(value))
255  return setProvedLive(result.getOwner());
256  setProvedLive(cast<BlockArgument>(value));
257  }
258  void setProvedLive(BlockArgument arg) {
259  changed |= liveValues.insert(arg).second;
260  }
261 
262  /// Operation methods.
263  bool wasProvenLive(Operation *op) { return liveOps.count(op); }
264  void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }
265 
266  /// Methods for tracking if we have reached a fixed-point.
267  void resetChanged() { changed = false; }
268  bool hasChanged() { return changed; }
269 
270 private:
271  bool changed = false;
272  DenseSet<Value> liveValues;
273  DenseSet<Operation *> liveOps;
274 };
275 } // namespace
276 
277 static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
278  Operation *owner = use.getOwner();
279  unsigned operandIndex = use.getOperandNumber();
280  // This pass generally treats all uses of an op as live if the op itself is
281  // considered live. However, for successor operands to terminators we need a
282  // finer-grained notion where we deduce liveness for operands individually.
283  // The reason for this is easiest to think about in terms of a classical phi
284  // node based SSA IR, where each successor operand is really an operand to a
285  // *separate* phi node, rather than all operands to the branch itself as with
286  // the block argument representation that MLIR uses.
287  //
288  // And similarly, because each successor operand is really an operand to a phi
289  // node, rather than to the terminator op itself, a terminator op can't e.g.
290  // "print" the value of a successor operand.
291  if (owner->hasTrait<OpTrait::IsTerminator>()) {
292  if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
293  if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
294  return !liveMap.wasProvenLive(*arg);
295  return false;
296  }
297  return false;
298 }
299 
300 static void processValue(Value value, LiveMap &liveMap) {
301  bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) {
302  if (isUseSpeciallyKnownDead(use, liveMap))
303  return false;
304  return liveMap.wasProvenLive(use.getOwner());
305  });
306  if (provedLive)
307  liveMap.setProvedLive(value);
308 }
309 
310 static void propagateLiveness(Region &region, LiveMap &liveMap);
311 
312 static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
313  // Terminators are always live.
314  liveMap.setProvedLive(op);
315 
316  // Check to see if we can reason about the successor operands and mutate them.
317  BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
318  if (!branchInterface) {
319  for (Block *successor : op->getSuccessors())
320  for (BlockArgument arg : successor->getArguments())
321  liveMap.setProvedLive(arg);
322  return;
323  }
324 
325  // If we can't reason about the operand to a successor, conservatively mark
326  // it as live.
327  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
328  SuccessorOperands successorOperands =
329  branchInterface.getSuccessorOperands(i);
330  for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount();
331  opI != opE; ++opI)
332  liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI));
333  }
334 }
335 
336 static void propagateLiveness(Operation *op, LiveMap &liveMap) {
337  // Recurse on any regions the op has.
338  for (Region &region : op->getRegions())
339  propagateLiveness(region, liveMap);
340 
341  // Process terminator operations.
342  if (op->hasTrait<OpTrait::IsTerminator>())
343  return propagateTerminatorLiveness(op, liveMap);
344 
345  // Don't reprocess live operations.
346  if (liveMap.wasProvenLive(op))
347  return;
348 
349  // Process the op itself.
350  if (!wouldOpBeTriviallyDead(op))
351  return liveMap.setProvedLive(op);
352 
353  // If the op isn't intrinsically alive, check it's results.
354  for (Value value : op->getResults())
355  processValue(value, liveMap);
356 }
357 
358 static void propagateLiveness(Region &region, LiveMap &liveMap) {
359  if (region.empty())
360  return;
361 
362  for (Block *block : llvm::post_order(&region.front())) {
363  // We process block arguments after the ops in the block, to promote
364  // faster convergence to a fixed point (we try to visit uses before defs).
365  for (Operation &op : llvm::reverse(block->getOperations()))
366  propagateLiveness(&op, liveMap);
367 
368  // We currently do not remove entry block arguments, so there is no need to
369  // track their liveness.
370  // TODO: We could track these and enable removing dead operands/arguments
371  // from region control flow operations.
372  if (block->isEntryBlock())
373  continue;
374 
375  for (Value value : block->getArguments()) {
376  if (!liveMap.wasProvenLive(value))
377  processValue(value, liveMap);
378  }
379  }
380 }
381 
383  LiveMap &liveMap) {
384  BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
385  if (!branchOp)
386  return;
387 
388  for (unsigned succI = 0, succE = terminator->getNumSuccessors();
389  succI < succE; succI++) {
390  // Iterating successors in reverse is not strictly needed, since we
391  // aren't erasing any successors. But it is slightly more efficient
392  // since it will promote later operands of the terminator being erased
393  // first, reducing the quadratic-ness.
394  unsigned succ = succE - succI - 1;
395  SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ);
396  Block *successor = terminator->getSuccessor(succ);
397 
398  for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) {
399  // Iterating args in reverse is needed for correctness, to avoid
400  // shifting later args when earlier args are erased.
401  unsigned arg = argE - argI - 1;
402  if (!liveMap.wasProvenLive(successor->getArgument(arg)))
403  succOperands.erase(arg);
404  }
405  }
406 }
407 
408 static LogicalResult deleteDeadness(RewriterBase &rewriter,
409  MutableArrayRef<Region> regions,
410  LiveMap &liveMap) {
411  bool erasedAnything = false;
412  for (Region &region : regions) {
413  if (region.empty())
414  continue;
415  bool hasSingleBlock = llvm::hasSingleElement(region);
416 
417  // Delete every operation that is not live. Graph regions may have cycles
418  // in the use-def graph, so we must explicitly dropAllUses() from each
419  // operation as we erase it. Visiting the operations in post-order
420  // guarantees that in SSA CFG regions value uses are removed before defs,
421  // which makes dropAllUses() a no-op.
422  for (Block *block : llvm::post_order(&region.front())) {
423  if (!hasSingleBlock)
424  eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
425  for (Operation &childOp :
426  llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
427  if (!liveMap.wasProvenLive(&childOp)) {
428  erasedAnything = true;
429  childOp.dropAllUses();
430  rewriter.eraseOp(&childOp);
431  } else {
432  erasedAnything |= succeeded(
433  deleteDeadness(rewriter, childOp.getRegions(), liveMap));
434  }
435  }
436  }
437  // Delete block arguments.
438  // The entry block has an unknown contract with their enclosing block, so
439  // skip it.
440  for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
441  block.eraseArguments(
442  [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
443  }
444  }
445  return success(erasedAnything);
446 }
447 
448 // This function performs a simple dead code elimination algorithm over the
449 // given regions.
450 //
451 // The overall goal is to prove that Values are dead, which allows deleting ops
452 // and block arguments.
453 //
454 // This uses an optimistic algorithm that assumes everything is dead until
455 // proved otherwise, allowing it to delete recursively dead cycles.
456 //
457 // This is a simple fixed-point dataflow analysis algorithm on a lattice
458 // {Dead,Alive}. Because liveness flows backward, we generally try to
459 // iterate everything backward to speed up convergence to the fixed-point. This
460 // allows for being able to delete recursively dead cycles of the use-def graph,
461 // including block arguments.
462 //
463 // This function returns success if any operations or arguments were deleted,
464 // failure otherwise.
465 LogicalResult mlir::runRegionDCE(RewriterBase &rewriter,
466  MutableArrayRef<Region> regions) {
467  LiveMap liveMap;
468  do {
469  liveMap.resetChanged();
470 
471  for (Region &region : regions)
472  propagateLiveness(region, liveMap);
473  } while (liveMap.hasChanged());
474 
475  return deleteDeadness(rewriter, regions, liveMap);
476 }
477 
478 //===----------------------------------------------------------------------===//
479 // Block Merging
480 //===----------------------------------------------------------------------===//
481 
482 //===----------------------------------------------------------------------===//
483 // BlockEquivalenceData
484 
485 namespace {
486 /// This class contains the information for comparing the equivalencies of two
487 /// blocks. Blocks are considered equivalent if they contain the same operations
488 /// in the same order. The only allowed divergence is for operands that come
489 /// from sources outside of the parent block, i.e. the uses of values produced
490 /// within the block must be equivalent.
491 /// e.g.,
492 /// Equivalent:
493 /// ^bb1(%arg0: i32)
494 /// return %arg0, %foo : i32, i32
495 /// ^bb2(%arg1: i32)
496 /// return %arg1, %bar : i32, i32
497 /// Not Equivalent:
498 /// ^bb1(%arg0: i32)
499 /// return %foo, %arg0 : i32, i32
500 /// ^bb2(%arg1: i32)
501 /// return %arg1, %bar : i32, i32
502 struct BlockEquivalenceData {
503  BlockEquivalenceData(Block *block);
504 
505  /// Return the order index for the given value that is within the block of
506  /// this data.
507  unsigned getOrderOf(Value value) const;
508 
509  /// The block this data refers to.
510  Block *block;
511  /// A hash value for this block.
512  llvm::hash_code hash;
513  /// A map of result producing operations to their relative orders within this
514  /// block. The order of an operation is the number of defined values that are
515  /// produced within the block before this operation.
516  DenseMap<Operation *, unsigned> opOrderIndex;
517 };
518 } // namespace
519 
520 BlockEquivalenceData::BlockEquivalenceData(Block *block)
521  : block(block), hash(0) {
522  unsigned orderIt = block->getNumArguments();
523  for (Operation &op : *block) {
524  if (unsigned numResults = op.getNumResults()) {
525  opOrderIndex.try_emplace(&op, orderIt);
526  orderIt += numResults;
527  }
528  auto opHash = OperationEquivalence::computeHash(
532  hash = llvm::hash_combine(hash, opHash);
533  }
534 }
535 
536 unsigned BlockEquivalenceData::getOrderOf(Value value) const {
537  assert(value.getParentBlock() == block && "expected value of this block");
538 
539  // Arguments use the argument number as the order index.
540  if (BlockArgument arg = dyn_cast<BlockArgument>(value))
541  return arg.getArgNumber();
542 
543  // Otherwise, the result order is offset from the parent op's order.
544  OpResult result = cast<OpResult>(value);
545  auto opOrderIt = opOrderIndex.find(result.getDefiningOp());
546  assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
547  return opOrderIt->second + result.getResultNumber();
548 }
549 
550 //===----------------------------------------------------------------------===//
551 // BlockMergeCluster
552 
553 namespace {
554 /// This class represents a cluster of blocks to be merged together.
555 class BlockMergeCluster {
556 public:
557  BlockMergeCluster(BlockEquivalenceData &&leaderData)
558  : leaderData(std::move(leaderData)) {}
559 
560  /// Attempt to add the given block to this cluster. Returns success if the
561  /// block was merged, failure otherwise.
562  LogicalResult addToCluster(BlockEquivalenceData &blockData);
563 
564  /// Try to merge all of the blocks within this cluster into the leader block.
565  LogicalResult merge(RewriterBase &rewriter);
566 
567 private:
568  /// The equivalence data for the leader of the cluster.
569  BlockEquivalenceData leaderData;
570 
571  /// The set of blocks that can be merged into the leader.
572  llvm::SmallSetVector<Block *, 1> blocksToMerge;
573 
574  /// A set of operand+index pairs that correspond to operands that need to be
575  /// replaced by arguments when the cluster gets merged.
576  std::set<std::pair<int, int>> operandsToMerge;
577 };
578 } // namespace
579 
580 LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
581  if (leaderData.hash != blockData.hash)
582  return failure();
583  Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
584  if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes())
585  return failure();
586 
587  // A set of operands that mismatch between the leader and the new block.
588  SmallVector<std::pair<int, int>, 8> mismatchedOperands;
589  auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
590  auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
591  for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
592  // Check that the operations are equivalent.
593  if (!OperationEquivalence::isEquivalentTo(
594  &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
595  /*markEquivalent=*/nullptr,
596  OperationEquivalence::Flags::IgnoreLocations))
597  return failure();
598 
599  // Compare the operands of the two operations. If the operand is within
600  // the block, it must refer to the same operation.
601  auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
602  for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) {
603  Value lhsOperand = lhsOperands[operand];
604  Value rhsOperand = rhsOperands[operand];
605  if (lhsOperand == rhsOperand)
606  continue;
607  // Check that the types of the operands match.
608  if (lhsOperand.getType() != rhsOperand.getType())
609  return failure();
610 
611  // Check that these uses are both external, or both internal.
612  bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock;
613  bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock;
614  if (lhsIsInBlock != rhsIsInBlock)
615  return failure();
616  // Let the operands differ if they are defined in a different block. These
617  // will become new arguments if the blocks get merged.
618  if (!lhsIsInBlock) {
619 
620  // Check whether the operands aren't the result of an immediate
621  // predecessors terminator. In that case we are not able to use it as a
622  // successor operand when branching to the merged block as it does not
623  // dominate its producing operation.
624  auto isValidSuccessorArg = [](Block *block, Value operand) {
625  if (operand.getDefiningOp() !=
626  operand.getParentBlock()->getTerminator())
627  return true;
628  return !llvm::is_contained(block->getPredecessors(),
629  operand.getParentBlock());
630  };
631 
632  if (!isValidSuccessorArg(leaderBlock, lhsOperand) ||
633  !isValidSuccessorArg(mergeBlock, rhsOperand))
634  return failure();
635 
636  mismatchedOperands.emplace_back(opI, operand);
637  continue;
638  }
639 
640  // Otherwise, these operands must have the same logical order within the
641  // parent block.
642  if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand))
643  return failure();
644  }
645 
646  // If the lhs or rhs has external uses, the blocks cannot be merged as the
647  // merged version of this operation will not be either the lhs or rhs
648  // alone (thus semantically incorrect), but some mix dependending on which
649  // block preceeded this.
650  // TODO allow merging of operations when one block does not dominate the
651  // other
652  if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
653  lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
654  return failure();
655  }
656  }
657  // Make sure that the block sizes are equivalent.
658  if (lhsIt != lhsE || rhsIt != rhsE)
659  return failure();
660 
661  // If we get here, the blocks are equivalent and can be merged.
662  operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
663  blocksToMerge.insert(blockData.block);
664  return success();
665 }
666 
667 /// Returns true if the predecessor terminators of the given block can not have
668 /// their operands updated.
669 static bool ableToUpdatePredOperands(Block *block) {
670  for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
671  if (!isa<BranchOpInterface>((*it)->getTerminator()))
672  return false;
673  }
674  return true;
675 }
676 
677 LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
678  // Don't consider clusters that don't have blocks to merge.
679  if (blocksToMerge.empty())
680  return failure();
681 
682  Block *leaderBlock = leaderData.block;
683  if (!operandsToMerge.empty()) {
684  // If the cluster has operands to merge, verify that the predecessor
685  // terminators of each of the blocks can have their successor operands
686  // updated.
687  // TODO: We could try and sub-partition this cluster if only some blocks
688  // cause the mismatch.
689  if (!ableToUpdatePredOperands(leaderBlock) ||
690  !llvm::all_of(blocksToMerge, ableToUpdatePredOperands))
691  return failure();
692 
693  // Collect the iterators for each of the blocks to merge. We will walk all
694  // of the iterators at once to avoid operand index invalidation.
695  SmallVector<Block::iterator, 2> blockIterators;
696  blockIterators.reserve(blocksToMerge.size() + 1);
697  blockIterators.push_back(leaderBlock->begin());
698  for (Block *mergeBlock : blocksToMerge)
699  blockIterators.push_back(mergeBlock->begin());
700 
701  // Update each of the predecessor terminators with the new arguments.
702  SmallVector<SmallVector<Value, 8>, 2> newArguments(
703  1 + blocksToMerge.size(),
704  SmallVector<Value, 8>(operandsToMerge.size()));
705  unsigned curOpIndex = 0;
706  for (const auto &it : llvm::enumerate(operandsToMerge)) {
707  unsigned nextOpOffset = it.value().first - curOpIndex;
708  curOpIndex = it.value().first;
709 
710  // Process the operand for each of the block iterators.
711  for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
712  Block::iterator &blockIter = blockIterators[i];
713  std::advance(blockIter, nextOpOffset);
714  auto &operand = blockIter->getOpOperand(it.value().second);
715  newArguments[i][it.index()] = operand.get();
716 
717  // Update the operand and insert an argument if this is the leader.
718  if (i == 0) {
719  Value operandVal = operand.get();
720  operand.set(leaderBlock->addArgument(operandVal.getType(),
721  operandVal.getLoc()));
722  }
723  }
724  }
725  // Update the predecessors for each of the blocks.
726  auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
727  for (auto predIt = block->pred_begin(), predE = block->pred_end();
728  predIt != predE; ++predIt) {
729  auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
730  unsigned succIndex = predIt.getSuccessorIndex();
731  branch.getSuccessorOperands(succIndex).append(
732  newArguments[clusterIndex]);
733  }
734  };
735  updatePredecessors(leaderBlock, /*clusterIndex=*/0);
736  for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
737  updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1);
738  }
739 
740  // Replace all uses of the merged blocks with the leader and erase them.
741  for (Block *block : blocksToMerge) {
742  block->replaceAllUsesWith(leaderBlock);
743  rewriter.eraseBlock(block);
744  }
745  return success();
746 }
747 
748 /// Identify identical blocks within the given region and merge them, inserting
749 /// new block arguments as necessary. Returns success if any blocks were merged,
750 /// failure otherwise.
751 static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
752  Region &region) {
753  if (region.empty() || llvm::hasSingleElement(region))
754  return failure();
755 
756  // Identify sets of blocks, other than the entry block, that branch to the
757  // same successors. We will use these groups to create clusters of equivalent
758  // blocks.
760  for (Block &block : llvm::drop_begin(region, 1))
761  matchingSuccessors[block.getSuccessors()].push_back(&block);
762 
763  bool mergedAnyBlocks = false;
764  for (ArrayRef<Block *> blocks : llvm::make_second_range(matchingSuccessors)) {
765  if (blocks.size() == 1)
766  continue;
767 
769  for (Block *block : blocks) {
770  BlockEquivalenceData data(block);
771 
772  // Don't allow merging if this block has any regions.
773  // TODO: Add support for regions if necessary.
774  bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) {
775  return llvm::any_of(op.getRegions(),
776  [](Region &region) { return !region.empty(); });
777  });
778  if (hasNonEmptyRegion)
779  continue;
780 
781  // Try to add this block to an existing cluster.
782  bool addedToCluster = false;
783  for (auto &cluster : clusters)
784  if ((addedToCluster = succeeded(cluster.addToCluster(data))))
785  break;
786  if (!addedToCluster)
787  clusters.emplace_back(std::move(data));
788  }
789  for (auto &cluster : clusters)
790  mergedAnyBlocks |= succeeded(cluster.merge(rewriter));
791  }
792 
793  return success(mergedAnyBlocks);
794 }
795 
796 /// Identify identical blocks within the given regions and merge them, inserting
797 /// new block arguments as necessary.
798 static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
799  MutableArrayRef<Region> regions) {
800  llvm::SmallSetVector<Region *, 1> worklist;
801  for (auto &region : regions)
802  worklist.insert(&region);
803  bool anyChanged = false;
804  while (!worklist.empty()) {
805  Region *region = worklist.pop_back_val();
806  if (succeeded(mergeIdenticalBlocks(rewriter, *region))) {
807  worklist.insert(region);
808  anyChanged = true;
809  }
810 
811  // Add any nested regions to the worklist.
812  for (Block &block : *region)
813  for (auto &op : block)
814  for (auto &nestedRegion : op.getRegions())
815  worklist.insert(&nestedRegion);
816  }
817 
818  return success(anyChanged);
819 }
820 
821 //===----------------------------------------------------------------------===//
822 // Region Simplification
823 //===----------------------------------------------------------------------===//
824 
825 /// Run a set of structural simplifications over the given regions. This
826 /// includes transformations like unreachable block elimination, dead argument
827 /// elimination, as well as some other DCE. This function returns success if any
828 /// of the regions were simplified, failure otherwise.
829 LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
830  MutableArrayRef<Region> regions,
831  bool mergeBlocks) {
832  bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
833  bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
834  bool mergedIdenticalBlocks = false;
835  if (mergeBlocks)
836  mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
837  return success(eliminatedBlocks || eliminatedOpsOrArgs ||
838  mergedIdenticalBlocks);
839 }
static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, Region &region)
Identify identical blocks within the given region and merge them, inserting new block arguments as necessary.
static void propagateLiveness(Region &region, LiveMap &liveMap)
static void processValue(Value value, LiveMap &liveMap)
static bool ableToUpdatePredOperands(Block *block)
Returns true if the predecessor terminators of the given block can have their operands updated.
static void eraseTerminatorSuccessorOperands(Operation *terminator, LiveMap &liveMap)
static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap)
static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap)
static LogicalResult deleteDeadness(RewriterBase &rewriter, MutableArrayRef< Region > regions, LiveMap &liveMap)
This class represents an argument of a Block.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:31
OpListType::iterator iterator
Definition: Block.h:138
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
pred_iterator pred_begin()
Definition: Block.h:231
SuccessorRange getSuccessors()
Definition: Block.h:265
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:235
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
BlockArgListType getArguments()
Definition: Block.h:85
iterator end()
Definition: Block.h:142
iterator begin()
Definition: Block.h:141
pred_iterator pred_end()
Definition: Block.h:234
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:559
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:434
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:441
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Block * getSuccessor(unsigned index)
Definition: Operation.h:704
unsigned getNumSuccessors()
Definition: Operation.h:702
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
SuccessorRange getSuccessors()
Definition: Operation.h:699
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition: Region.cpp:45
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition: Region.h:222
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
iterator begin()
Definition: Region.h:55
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
Definition: Region.h:285
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
void replaceOpUsesWithIf(Operation *from, ValueRange to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Definition: PatternMatch.h:679
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
This class models how operands are forwarded to block arguments in control flow.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
Definition: RegionUtils.cpp:27
bool computeTopologicalSorting(MutableArrayRef< Operation * > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Compute a topological ordering of the given ops.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter, MutableArrayRef< Region > regions)
Erase the unreachable blocks within the provided regions.
SmallVector< Value > makeRegionIsolatedFromAbove(RewriterBase &rewriter, Region &region, llvm::function_ref< bool(Operation *)> cloneOperationIntoRegion=[](Operation *) { return false;})
Make a region isolated from above.
Definition: RegionUtils.cpp:79
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:62
LogicalResult runRegionDCE(RewriterBase &rewriter, MutableArrayRef< Region > regions)
This function returns success if any operations or arguments were deleted, failure otherwise.
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions, bool mergeBlocks=true)
Run a set of structural simplifications over the given regions.
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
Definition: RegionUtils.cpp:35
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.