MLIR  18.0.0git
Mem2Reg.cpp
//===- Mem2Reg.cpp - Promotes memory slots into values ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

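#include "mlir/Transforms/Mem2Reg.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/GenericIteratedDominanceFrontier.h"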

namespace mlir {
#define GEN_PASS_DEF_MEM2REG
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "mem2reg"

using namespace mlir;

/// mem2reg
///
/// This pass turns unnecessary uses of automatically allocated memory slots
/// into direct Value-based operations. For example, storing a constant into a
/// memory slot and immediately loading it back is simplified into a direct use
/// of that constant. In other words, given a memory slot addressed by a
/// non-aliased "pointer" Value, mem2reg removes all the uses of that pointer.
///
/// Within a block, this is done by following the chain of stores and loads of
/// the slot and replacing the results of loads with the values previously
/// stored. If a load happens before any store, a default value provided by the
/// allocator (typically poison) is used instead.
///
/// Control flow can create situations where a load could be replaced by
/// multiple possible stores depending on the control flow path taken. As a
/// result, this pass must introduce new block arguments in some blocks to
/// accommodate the multiple possible definitions. Each predecessor will
/// populate the block argument with the definition reached at its end. With
/// this, the value stored can be well defined at block boundaries, allowing
/// the propagation of replacement through blocks.
///
/// This pass computes this transformation in four main steps. The first two
/// steps are performed during an analysis phase that does not mutate IR.
///
/// The two steps of the analysis phase are the following:
/// - A first step computes the list of operations that transitively use the
/// memory slot we would like to promote. The purpose of this phase is to
/// identify which uses must be removed to promote the slot, either by rewiring
/// the user or deleting it. Naturally, direct uses of the slot must be removed.
/// Sometimes additional uses must also be removed: this is notably the case
/// when a direct user of the slot cannot rewire its use and must delete itself,
/// and thus must make its users no longer use it. If any of those uses cannot
/// be removed by their users in any way, promotion cannot continue: this is
/// decided at this step.
/// - A second step computes the list of blocks where a block argument will be
/// needed ("merge points") without mutating the IR. These blocks are the blocks
/// leading to a definition clash between two predecessors. Such blocks happen
/// to be the Iterated Dominance Frontier (IDF) of the set of blocks containing
/// a store, as they represent the point where a clear defining dominator stops
/// existing. The IDF is further restricted to blocks where the slot's value is
/// live-in, so block arguments are only created where the value may actually
/// be read. Computing this information in advance allows making sure the
/// terminators that will forward values are capable of doing so (inability to
/// do so aborts promotion at this step).
///
/// At this point, promotion is guaranteed to happen, and the mutation phase can
/// begin with the following steps:
/// - A third step computes the reaching definition of the memory slot at each
/// blocking user. This is the core of the mem2reg algorithm, also known as
/// load-store forwarding. This analyses loads and stores and propagates which
/// value must be stored in the slot at each blocking user. This is achieved by
/// doing a depth-first walk of the dominator tree of the region containing the
/// slot. This is sufficient because the reaching definition at the beginning
/// of a block is either its new block argument if it is a merge block, or the
/// definition reaching the end of its immediate dominator (parent in the
/// dominator tree). We can therefore propagate this information down the
/// dominator tree to proceed with renaming within blocks.
/// - The fourth and final step uses the reaching definitions to remove blocking
/// uses in topological order.
///
/// For further reading, chapter three of SSA-based Compiler Design [1]
/// showcases SSA construction, where mem2reg is an adaptation of the same
/// process.
///
/// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022),
/// Springer.

namespace {

using BlockingUsesMap =
    llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;

/// Information computed during promotion analysis used to perform actual
/// promotion.
struct MemorySlotPromotionInfo {
  /// Blocks for which at least two definitions of the slot value clash.
  SmallPtrSet<Block *, 8> mergePoints;
  /// Contains, for each operation, which uses must be eliminated by promotion.
  /// This is a DAG structure because if an operation must eliminate some of
  /// its uses, it is because the defining ops of the blocking uses requested
  /// it. The defining ops therefore must also have blocking uses or be the
  /// starting point of the blocking uses.
  BlockingUsesMap userToBlockingUses;
};

/// Computes information for basic slot promotion. This will check that direct
/// slot promotion can be performed, and provide the information to execute the
/// promotion. This does not mutate IR.
class MemorySlotPromotionAnalyzer {
public:
  MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance)
      : slot(slot), dominance(dominance) {}

  /// Computes the information for slot promotion if promotion is possible,
  /// returns nothing otherwise.
  std::optional<MemorySlotPromotionInfo> computeInfo();

private:
  /// Computes the transitive uses of the slot that block promotion. This finds
  /// uses that would block the promotion, checks that the operation has a
  /// solution to remove the blocking use, and potentially forwards the analysis
  /// if the operation needs further blocking uses resolved to resolve its own
  /// uses (typically, removing its users because it will delete itself to
  /// resolve its own blocking uses). This will fail if one of the transitive
  /// users cannot remove a requested use, and should prevent promotion.
  LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);

  /// Computes in which blocks the value stored in the slot is actually used,
  /// meaning blocks leading to a load. This method uses `definingBlocks`, the
  /// set of blocks containing a store to the slot (defining the value of the
  /// slot).
  SmallPtrSet<Block *, 16>
  computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks);

  /// Computes the points in which multiple re-definitions of the slot's value
  /// (stores) may conflict.
  void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints);

  /// Ensures predecessors of merge points can properly provide their current
  /// definition of the value stored in the slot to the merge point. This can
  /// notably be an issue if the terminator used does not have the ability to
  /// forward values through block operands.
  bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints);

  MemorySlot slot;
  DominanceInfo &dominance;
};

/// The MemorySlotPromoter handles the state of promoting a memory slot. It
/// wraps a slot and its associated allocator. This will perform the mutation of
/// IR.
class MemorySlotPromoter {
public:
  MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
                     RewriterBase &rewriter, DominanceInfo &dominance,
                     MemorySlotPromotionInfo info,
                     const Mem2RegStatistics &statistics);

  /// Actually promotes the slot by mutating IR. Promoting a slot DOES
  /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
  /// promotion info should NOT be performed in batches.
  void promoteSlot();

private:
  /// Computes the reaching definition for all the operations that require
  /// promotion. `reachingDef` is the value the slot should contain at the
  /// beginning of the block. This method returns the reached definition at the
  /// end of the block. This method must be called at most once per block.
  Value computeReachingDefInBlock(Block *block, Value reachingDef);

  /// Computes the reaching definition for all the operations that require
  /// promotion. `reachingDef` corresponds to the initial value the slot will
  /// contain before any write: typically a poison value, or a null Value, in
  /// which case the default value is created lazily when first needed.
  /// This method must be called at most once per region.
  void computeReachingDefInRegion(Region *region, Value reachingDef);

  /// Removes the blocking uses of the slot, in topological order.
  void removeBlockingUses();

  /// Lazily-constructed default value representing the content of the slot when
  /// no store has been executed. This function may mutate IR.
  Value getLazyDefaultValue();

  MemorySlot slot;
  PromotableAllocationOpInterface allocator;
  RewriterBase &rewriter;
  /// Potentially non-initialized default value. Use `getLazyDefaultValue` to
  /// initialize it on demand.
  Value defaultValue;
  /// Contains the reaching definition at this operation. Reaching definitions
  /// are only computed for promotable memory operations with blocking uses.
  DenseMap<PromotableMemOpInterface, Value> reachingDefs;
  DominanceInfo &dominance;
  MemorySlotPromotionInfo info;
  const Mem2RegStatistics &statistics;
};

} // namespace

MemorySlotPromoter::MemorySlotPromoter(
    MemorySlot slot, PromotableAllocationOpInterface allocator,
    RewriterBase &rewriter, DominanceInfo &dominance,
    MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
    : slot(slot), allocator(allocator), rewriter(rewriter),
      dominance(dominance), info(std::move(info)), statistics(statistics) {
#ifndef NDEBUG
  auto isResultOrNewBlockArgument = [&]() {
    if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
      return arg.getOwner()->getParentOp() == allocator;
    return slot.ptr.getDefiningOp() == allocator;
  };

  assert(isResultOrNewBlockArgument() &&
         "a slot must be a result of the allocator or an argument of the child "
         "regions of the allocator");
#endif // NDEBUG
}

Value MemorySlotPromoter::getLazyDefaultValue() {
  if (defaultValue)
    return defaultValue;

  RewriterBase::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(slot.ptr.getParentBlock());
  return defaultValue = allocator.getDefaultValue(slot, rewriter);
}

LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
    BlockingUsesMap &userToBlockingUses) {
  // The promotion of an operation may require the promotion of further
  // operations (typically, removing operations that use an operation that must
  // delete itself). We thus need to start from the use of the slot pointer and
  // propagate further requests through the forward slice.

  // First, record that all immediate users of the slot pointer must stop using
  // it.
  for (OpOperand &use : slot.ptr.getUses()) {
    SmallPtrSet<OpOperand *, 4> &blockingUses =
        userToBlockingUses[use.getOwner()];
    blockingUses.insert(&use);
  }

  // Then, propagate the requirements for the removal of uses. The
  // topologically-sorted forward slice allows for all blocking uses of an
  // operation to have been computed before it is reached. Operations are
  // traversed in topological order of their uses, starting from the slot
  // pointer.
  SetVector<Operation *> forwardSlice;
  mlir::getForwardSlice(slot.ptr, &forwardSlice);
  for (Operation *user : forwardSlice) {
    // If the next operation has no blocking uses, everything is fine.
    if (!userToBlockingUses.contains(user))
      continue;

    SmallPtrSet<OpOperand *, 4> &blockingUses = userToBlockingUses[user];

    SmallVector<OpOperand *> newBlockingUses;
    // If the operation decides it cannot deal with removing the blocking uses,
    // promotion must fail.
    if (auto promotable = dyn_cast<PromotableOpInterface>(user)) {
      if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses))
        return failure();
    } else if (auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
      if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses))
        return failure();
    } else {
      // An operation that has blocking uses must be promoted. If it is not
      // promotable, promotion must fail.
      return failure();
    }

    // Then, register any new blocking uses for coming operations.
    for (OpOperand *blockingUse : newBlockingUses) {
      assert(llvm::is_contained(user->getResults(), blockingUse->get()));

      SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
          userToBlockingUses[blockingUse->getOwner()];
      newUserBlockingUseSet.insert(blockingUse);
    }
  }

  // Because this pass currently only supports analysing the parent region of
  // the slot pointer, if a promotable memory op that needs promotion is outside
  // of this region, promotion must fail because it will be impossible to
  // provide a valid `reachingDef` for it.
  for (auto &[toPromote, _] : userToBlockingUses)
    if (isa<PromotableMemOpInterface>(toPromote) &&
        toPromote->getParentRegion() != slot.ptr.getParentRegion())
      return failure();

  return success();
}

SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(
    SmallPtrSetImpl<Block *> &definingBlocks) {
  SmallPtrSet<Block *, 16> liveIn;

  // The worklist contains blocks in which it is known that the slot value is
  // live-in. The further blocks where this value is live-in will be inferred
  // from these.
  SmallVector<Block *> liveInWorkList;

  // Blocks with a load before any other store to the slot are the starting
  // points of the analysis. The slot value is definitely live-in in those
  // blocks.
  SmallPtrSet<Block *, 16> visited;
  for (Operation *user : slot.ptr.getUsers()) {
    if (visited.contains(user->getBlock()))
      continue;
    visited.insert(user->getBlock());

    for (Operation &op : user->getBlock()->getOperations()) {
      if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
        // If this operation loads the slot, it is loading from it before
        // ever writing to it, so the value is live-in in this block.
        if (memOp.loadsFrom(slot)) {
          liveInWorkList.push_back(user->getBlock());
          break;
        }

        // If we store to the slot, further loads will see that value.
        // Because we did not meet any load before, the value is not live-in.
        if (memOp.storesTo(slot))
          break;
      }
    }
  }

  // The information is then propagated to the predecessors until a def site
  // (store) is found.
  while (!liveInWorkList.empty()) {
    Block *liveInBlock = liveInWorkList.pop_back_val();

    if (!liveIn.insert(liveInBlock).second)
      continue;

    // If a predecessor is a defining block, either:
    // - It has a load before its first store, in which case it is live-in but
    //   has already been processed in the initialisation step.
    // - It has a store before any load, in which case it is not live-in.
    // We can thus at this stage insert to the worklist only predecessors that
    // are not defining blocks.
    for (Block *pred : liveInBlock->getPredecessors())
      if (!definingBlocks.contains(pred))
        liveInWorkList.push_back(pred);
  }

  return liveIn;
}

using IDFCalculator = llvm::IDFCalculatorBase<Block, false>;
void MemorySlotPromotionAnalyzer::computeMergePoints(
    SmallPtrSetImpl<Block *> &mergePoints) {
  if (slot.ptr.getParentRegion()->hasOneBlock())
    return;

  IDFCalculator idfCalculator(dominance.getDomTree(slot.ptr.getParentRegion()));

  SmallPtrSet<Block *, 16> definingBlocks;
  for (Operation *user : slot.ptr.getUsers())
    if (auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
      if (storeOp.storesTo(slot))
        definingBlocks.insert(user->getBlock());

  idfCalculator.setDefiningBlocks(definingBlocks);

  SmallPtrSet<Block *, 16> liveIn = computeSlotLiveIn(definingBlocks);
  idfCalculator.setLiveInBlocks(liveIn);

  SmallVector<Block *> mergePointsVec;
  idfCalculator.calculate(mergePointsVec);

  mergePoints.insert(mergePointsVec.begin(), mergePointsVec.end());
}

bool MemorySlotPromotionAnalyzer::areMergePointsUsable(
    SmallPtrSetImpl<Block *> &mergePoints) {
  for (Block *mergePoint : mergePoints)
    for (Block *pred : mergePoint->getPredecessors())
      if (!isa<BranchOpInterface>(pred->getTerminator()))
        return false;

  return true;
}

std::optional<MemorySlotPromotionInfo>
MemorySlotPromotionAnalyzer::computeInfo() {
  MemorySlotPromotionInfo info;

  // First, find the set of operations that will need to be changed for the
  // promotion to happen. These operations need to resolve some of their uses,
  // either by rewiring them or simply deleting themselves. If any of them
  // cannot find a way to resolve their blocking uses, we abort the promotion.
  if (failed(computeBlockingUses(info.userToBlockingUses)))
    return {};

  // Then, compute blocks in which two or more definitions of the allocated
  // variable may conflict. These blocks will need a new block argument to
  // accommodate this.
  computeMergePoints(info.mergePoints);

  // The slot can be promoted if the block arguments to be created can
  // actually be populated with values, which may not be possible depending
  // on their predecessors.
  if (!areMergePointsUsable(info.mergePoints))
    return {};

  return info;
}

Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
                                                    Value reachingDef) {
  // Snapshot the operations first: `getStored` may insert new operations into
  // the block while it is being walked.
  SmallVector<Operation *> blockOps;
  for (Operation &op : block->getOperations())
    blockOps.push_back(&op);
  for (Operation *op : blockOps) {
    if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
      if (info.userToBlockingUses.contains(memOp))
        reachingDefs.insert({memOp, reachingDef});

      if (memOp.storesTo(slot)) {
        rewriter.setInsertionPointAfter(memOp);
        Value stored = memOp.getStored(slot, rewriter);
        assert(stored && "a memory operation storing to a slot must provide a "
                         "new definition of the slot");
        reachingDef = stored;
      }
    }
  }

  return reachingDef;
}

void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
                                                    Value reachingDef) {
  if (region->hasOneBlock()) {
    computeReachingDefInBlock(&region->front(), reachingDef);
    return;
  }

  struct DfsJob {
    llvm::DomTreeNodeBase<Block> *block;
    Value reachingDef;
  };

  SmallVector<DfsJob> dfsStack;

  auto &domTree = dominance.getDomTree(slot.ptr.getParentRegion());

  dfsStack.emplace_back<DfsJob>(
      {domTree.getNode(&region->front()), reachingDef});

  while (!dfsStack.empty()) {
    DfsJob job = dfsStack.pop_back_val();
    Block *block = job.block->getBlock();

    if (info.mergePoints.contains(block)) {
      // If the block is a merge point, we need to add a block argument to hold
      // the selected reaching definition. This has to be a bit complicated
      // because of RewriterBase limitations: we need to create a new block with
      // the extra block argument, move the content of the block to the new
      // block, and replace the block with the new block in the merge point set.
      SmallVector<Type> argTypes;
      SmallVector<Location> argLocs;
      for (BlockArgument arg : block->getArguments()) {
        argTypes.push_back(arg.getType());
        argLocs.push_back(arg.getLoc());
      }
      argTypes.push_back(slot.elemType);
      argLocs.push_back(slot.ptr.getLoc());
      Block *newBlock = rewriter.createBlock(block, argTypes, argLocs);

      info.mergePoints.erase(block);
      info.mergePoints.insert(newBlock);

      rewriter.replaceAllUsesWith(block, newBlock);
      rewriter.mergeBlocks(block, newBlock,
                           newBlock->getArguments().drop_back());

      block = newBlock;

      BlockArgument blockArgument = block->getArguments().back();
      rewriter.setInsertionPointToStart(block);
      allocator.handleBlockArgument(slot, blockArgument, rewriter);
      job.reachingDef = blockArgument;

      if (statistics.newBlockArgumentAmount)
        (*statistics.newBlockArgumentAmount)++;
    }

    job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);

    if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
      for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
        if (info.mergePoints.contains(blockOperand.get())) {
          if (!job.reachingDef)
            job.reachingDef = getLazyDefaultValue();
          rewriter.updateRootInPlace(terminator, [&]() {
            terminator.getSuccessorOperands(blockOperand.getOperandNumber())
                .append(job.reachingDef);
          });
        }
      }
    }

    for (auto *child : job.block->children())
      dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
  }
}

/// Sorts `ops` according to dominance. Relies on the topological order of basic
/// blocks to get a deterministic ordering.
static void dominanceSort(SmallVector<Operation *> &ops, Region &region) {
  // Produce a topological block order and construct a map to look up the
  // indices of blocks.
  DenseMap<Block *, size_t> topoBlockIndices;
  SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks(region);
  for (auto [index, block] : llvm::enumerate(topologicalOrder))
    topoBlockIndices[block] = index;

  // Combining the topological order of the basic blocks with the
  // block-internal operation order guarantees a deterministic,
  // dominance-respecting order.
  llvm::sort(ops, [&](Operation *lhs, Operation *rhs) {
    size_t lhsBlockIndex = topoBlockIndices.at(lhs->getBlock());
    size_t rhsBlockIndex = topoBlockIndices.at(rhs->getBlock());
    if (lhsBlockIndex == rhsBlockIndex)
      return lhs->isBeforeInBlock(rhs);
    return lhsBlockIndex < rhsBlockIndex;
  });
}

void MemorySlotPromoter::removeBlockingUses() {
  llvm::SmallVector<Operation *> usersToRemoveUses(
      llvm::make_first_range(info.userToBlockingUses));

  // Sort according to dominance.
  dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent());

  llvm::SmallVector<Operation *> toErase;
  // Walk in reverse dominance order so that users are processed before the
  // operations whose results they use.
  for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
    if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
      Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
      // If no reaching definition is known, either no store reaches this
      // operation or it lies in a block not visited by the dominator-tree walk
      // (e.g. unreachable code). The default value should thus be used.
      if (!reachingDef)
        reachingDef = getLazyDefaultValue();

      rewriter.setInsertionPointAfter(toPromote);
      if (toPromoteMemOp.removeBlockingUses(
              slot, info.userToBlockingUses[toPromote], rewriter,
              reachingDef) == DeletionKind::Delete)
        toErase.push_back(toPromote);

      continue;
    }

    auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
    rewriter.setInsertionPointAfter(toPromote);
    if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
                                          rewriter) == DeletionKind::Delete)
      toErase.push_back(toPromote);
  }

  for (Operation *toEraseOp : toErase)
    rewriter.eraseOp(toEraseOp);

  assert(slot.ptr.use_empty() &&
         "after promotion, the slot pointer should not be used anymore");
}

void MemorySlotPromoter::promoteSlot() {
  computeReachingDefInRegion(slot.ptr.getParentRegion(), {});

  // Now that reaching definitions are known, remove all users.
  removeBlockingUses();

  // Update terminators in dead branches to forward the default value if they
  // branch to a merge point.
  for (Block *mergePoint : info.mergePoints) {
    for (BlockOperand &use : mergePoint->getUses()) {
      auto user = cast<BranchOpInterface>(use.getOwner());
      SuccessorOperands succOperands =
          user.getSuccessorOperands(use.getOperandNumber());
      assert(succOperands.size() == mergePoint->getNumArguments() ||
             succOperands.size() + 1 == mergePoint->getNumArguments());
      if (succOperands.size() + 1 == mergePoint->getNumArguments())
        rewriter.updateRootInPlace(
            user, [&]() { succOperands.append(getLazyDefaultValue()); });
    }
  }

  LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
                          << "\n");

  if (statistics.promotedAmount)
    (*statistics.promotedAmount)++;

  allocator.handlePromotionComplete(slot, defaultValue, rewriter);
}

LogicalResult mlir::tryToPromoteMemorySlots(
    ArrayRef<PromotableAllocationOpInterface> allocators,
    RewriterBase &rewriter, Mem2RegStatistics statistics) {
  bool promotedAny = false;

  for (PromotableAllocationOpInterface allocator : allocators) {
    for (MemorySlot slot : allocator.getPromotableSlots()) {
      if (slot.ptr.use_empty())
        continue;

      // Dominance is recomputed for each slot: promoting a slot may create new
      // blocks (see computeReachingDefInRegion), which would make previously
      // computed dominance information stale.
      DominanceInfo dominance;
      MemorySlotPromotionAnalyzer analyzer(slot, dominance);
      std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
      if (info) {
        MemorySlotPromoter(slot, allocator, rewriter, dominance,
                           std::move(*info), statistics)
            .promoteSlot();
        promotedAny = true;
      }
    }
  }

  return success(promotedAny);
}

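LogicalResult
Mem2RegPattern::matchAndRewrite(PromotableAllocationOpInterface allocator,
                                PatternRewriter &rewriter) const {
  // Note: this only queries the pattern's bounded-recursion flag and discards
  // the result; it does not change how the pattern is applied.
  hasBoundedRewriteRecursion();
  return tryToPromoteMemorySlots({allocator}, rewriter, statistics);
}
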
namespace {

struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
  using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;

  void runOnOperation() override {
    Operation *scopeOp = getOperation();

    Mem2RegStatistics statistics{&promotedAmount, &newBlockArgumentAmount};

    GreedyRewriteConfig config;
    config.enableRegionSimplification = enableRegionSimplification;

    RewritePatternSet rewritePatterns(&getContext());
    rewritePatterns.add<Mem2RegPattern>(&getContext(), statistics);
    FrozenRewritePatternSet frozen(std::move(rewritePatterns));

    if (failed(applyPatternsAndFoldGreedily(scopeOp, frozen, config)))
      signalPassFailure();
  }
};

} // namespace
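
Within a single block, the load-store forwarding described in the header comment (and driven by `computeReachingDefInBlock`) reduces to a linear scan that carries the current reaching definition. The standalone sketch below models it on a hypothetical toy representation rather than MLIR operations: `SlotOp` and `forwardLoads` are illustrative names, and the string `"poison"` merely stands in for the allocator-provided default value.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical toy op (not an MLIR API): either a store of a named SSA value
// into the slot, or a load from the slot.
struct SlotOp {
  enum Kind { Store, Load } kind;
  std::string value; // Stored value for Store; unused for Load.
};

// Walks one block in order. Each load is mapped to the definition reaching it,
// and each store becomes the new reaching definition. `incoming` is the
// definition reaching the block entry; "poison" stands in for the default
// value used when nothing has been stored yet.
std::vector<std::string> forwardLoads(const std::vector<SlotOp> &block,
                                      std::optional<std::string> incoming) {
  std::string reachingDef = incoming.value_or("poison");
  std::vector<std::string> loadReplacements;
  for (const SlotOp &op : block) {
    if (op.kind == SlotOp::Load)
      loadReplacements.push_back(reachingDef);
    else
      reachingDef = op.value;
  }
  return loadReplacements;
}
```

In the pass itself, loads are not rewritten during this walk: the reaching definition is only recorded per operation in `reachingDefs` and consumed later by `removeBlockingUses`.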
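
The first analysis step (`computeBlockingUses`) is a single pass over the forward slice in topological order: each operation with blocking uses either resolves them itself, asks its own users to stop using its result, or makes promotion fail. The sketch below models that propagation under simplifying assumptions: operations are hypothetical `SliceOp` records identified by their index in an already topologically sorted slice, the three-way `Policy` enum is an illustrative simplification of what `canUsesBeRemoved` may decide, and `-1` encodes the slot pointer.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <set>
#include <vector>

// Illustrative simplification of what an operation may answer when asked to
// remove its blocking uses.
enum class Policy {
  Rewire,     // resolves its uses itself (e.g. a load or a store)
  DeleteSelf, // deletes itself, so its users must stop using its result
  Opaque      // not promotable: a blocking use on it aborts promotion
};

// Hypothetical operation in the slot pointer's forward slice.
struct SliceOp {
  Policy policy;
  std::vector<int> users; // indices of the ops using this op's result
};

// `ops` is assumed to be topologically sorted; `slotUsers` are the direct
// users of the slot pointer (encoded as -1). Returns, for each op that must
// change, the producers whose results it must stop using, or std::nullopt
// if promotion is impossible.
std::optional<std::map<int, std::set<int>>>
computeBlockingUsesToy(const std::vector<SliceOp> &ops,
                       const std::vector<int> &slotUsers) {
  std::map<int, std::set<int>> blocking;
  for (int user : slotUsers)
    blocking[user].insert(-1);
  for (int id = 0; id < (int)ops.size(); ++id) {
    // Every requester precedes `id`, so its blocking uses are complete here.
    if (!blocking.count(id))
      continue;
    switch (ops[id].policy) {
    case Policy::Opaque:
      return std::nullopt;
    case Policy::Rewire:
      break;
    case Policy::DeleteSelf:
      for (int user : ops[id].users)
        blocking[user].insert(id);
      break;
    }
  }
  return blocking;
}
```

For example, with op 0 acting like a cast of the pointer that is used by op 1 (a load) and op 2 storing through the pointer directly, all three ops end up with blocking uses; making op 1 opaque makes the analysis fail, as it would abort promotion in the pass.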
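
The second analysis step places block arguments at the iterated dominance frontier of the storing blocks, pruned to blocks where the slot is live-in; `computeMergePoints` delegates this to `llvm::IDFCalculatorBase`. The standalone sketch below recomputes the same set on a toy CFG given as successor lists with block 0 as entry. It assumes every block is reachable and that the entry block has no predecessors (as for an MLIR region's entry block), and it uses a naive iterative dominator computation plus the Cooper-Harvey-Kennedy dominance-frontier walk rather than LLVM's implementation; `immediateDominators` and `mergePoints` are illustrative names.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <set>
#include <vector>

// Toy CFG: successor lists, block 0 is the entry block.
using Graph = std::vector<std::vector<int>>;

static std::vector<std::vector<int>> predecessors(const Graph &succs) {
  std::vector<std::vector<int>> preds(succs.size());
  for (int b = 0; b < (int)succs.size(); ++b)
    for (int s : succs[b])
      preds[s].push_back(b);
  return preds;
}

// Naive iterative dominator sets; idom(b) is then the strict dominator of b
// whose own dominator set is exactly one element smaller.
std::vector<int> immediateDominators(const Graph &succs) {
  int n = succs.size();
  auto preds = predecessors(succs);
  std::set<int> all;
  for (int b = 0; b < n; ++b)
    all.insert(b);
  std::vector<std::set<int>> dom(n, all);
  dom[0] = {0};
  for (bool changed = true; changed;) {
    changed = false;
    for (int b = 1; b < n; ++b) {
      std::set<int> meet = all;
      for (int p : preds[b]) {
        std::set<int> tmp;
        std::set_intersection(meet.begin(), meet.end(), dom[p].begin(),
                              dom[p].end(), std::inserter(tmp, tmp.end()));
        meet.swap(tmp);
      }
      meet.insert(b);
      if (meet != dom[b]) {
        dom[b] = meet;
        changed = true;
      }
    }
  }
  std::vector<int> idom(n, -1);
  for (int b = 1; b < n; ++b)
    for (int d : dom[b])
      if (d != b && dom[d].size() + 1 == dom[b].size())
        idom[b] = d;
  return idom;
}

// Iterated dominance frontier of `defBlocks`, keeping only live-in blocks.
std::set<int> mergePoints(const Graph &succs, const std::set<int> &defBlocks,
                          const std::set<int> &liveIn) {
  std::vector<int> idom = immediateDominators(succs);
  auto preds = predecessors(succs);
  std::vector<std::set<int>> df(succs.size());
  for (int b = 0; b < (int)succs.size(); ++b) {
    if (preds[b].size() < 2)
      continue;
    for (int p : preds[b])
      for (int runner = p; runner != idom[b]; runner = idom[runner])
        df[runner].insert(b);
  }
  std::set<int> result;
  std::vector<int> worklist(defBlocks.begin(), defBlocks.end());
  while (!worklist.empty()) {
    int x = worklist.back();
    worklist.pop_back();
    for (int y : df[x]) {
      if (!liveIn.count(y) || !result.insert(y).second)
        continue;
      if (!defBlocks.count(y)) // defining blocks are already seeded
        worklist.push_back(y);
    }
  }
  return result;
}
```

In a diamond where both arms store and the join block loads, the join is the only merge point; if no block is live-in, pruning removes it. This is why the live-in set from `computeSlotLiveIn` is handed to the IDF calculator before `calculate` runs.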
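
The third step walks the dominator tree: a block starts either from its new block argument (if it is a merge point) or from the definition reaching the end of its immediate dominator, and every predecessor of a merge point forwards the definition live at its end. The sketch below mirrors the structure of `computeReachingDefInRegion` on a hypothetical toy IR. A block's ops are strings where an empty string is a load and any other string is the name of a stored value. Immediate dominators and merge points are taken as inputs. `"default"` stands in for the lazily created default value and `"arg<N>"` for the block argument added to merge point `N`; `ToyBlock`, `RenameResult` and `renameReachingDefs` are illustrative names.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy block: "" is a load of the slot, any other string is a
// store of the SSA value with that name.
struct ToyBlock {
  std::vector<std::string> ops;
  std::vector<int> succs;
};

struct RenameResult {
  // (block, op index) -> value replacing that load.
  std::map<std::pair<int, int>, std::string> loadValue;
  // (predecessor, merge point) -> value forwarded to the new block argument.
  std::map<std::pair<int, int>, std::string> edgeOperand;
};

// Depth-first walk of the dominator tree rooted at block 0 (idom[0] is
// unused), carrying the definition that reaches each point.
RenameResult renameReachingDefs(const std::vector<ToyBlock> &blocks,
                                const std::vector<int> &idom,
                                const std::set<int> &mergePoints) {
  std::vector<std::vector<int>> children(blocks.size());
  for (int b = 1; b < (int)blocks.size(); ++b)
    children[idom[b]].push_back(b);

  RenameResult result;
  std::vector<std::pair<int, std::string>> stack = {{0, "default"}};
  while (!stack.empty()) {
    auto [block, def] = stack.back();
    stack.pop_back();
    // A merge point starts from its new block argument instead.
    if (mergePoints.count(block))
      def = "arg" + std::to_string(block);
    for (int i = 0; i < (int)blocks[block].ops.size(); ++i) {
      const std::string &op = blocks[block].ops[i];
      if (op.empty())
        result.loadValue[{block, i}] = def; // load: use the reaching def
      else
        def = op; // store: new reaching def
    }
    // Forward the end-of-block definition to successor merge points.
    for (int succ : blocks[block].succs)
      if (mergePoints.count(succ))
        result.edgeOperand[{block, succ}] = def;
    for (int child : children[block])
      stack.push_back({child, def});
  }
  return result;
}
```

Predecessors that the walk never visits (unreachable blocks) get no entry in `edgeOperand`; in the pass, `promoteSlot` patches such branches afterwards so that they forward the default value.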
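
Finally, `removeBlockingUses` orders the affected operations by the topological index of their block and then by their position inside that block (`dominanceSort`), and visits the result in reverse. A standalone sketch of that ordering, where `OpRef` and `dominanceSortToy` are hypothetical stand-ins and the block order is supplied by the caller instead of being computed by `getTopologicallySortedBlocks`:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Hypothetical stand-in for an operation: its block and its position in it.
struct OpRef {
  int block;
  int position;
};

// Sorts `ops` by the topological index of their block, then by position in
// the block. `topoOrder` lists block ids in a topological order and must
// contain every block referenced by `ops`.
void dominanceSortToy(std::vector<OpRef> &ops,
                      const std::vector<int> &topoOrder) {
  std::map<int, std::size_t> topoIndex;
  for (std::size_t i = 0; i < topoOrder.size(); ++i)
    topoIndex[topoOrder[i]] = i;
  std::sort(ops.begin(), ops.end(), [&](const OpRef &lhs, const OpRef &rhs) {
    std::size_t lhsIndex = topoIndex.at(lhs.block);
    std::size_t rhsIndex = topoIndex.at(rhs.block);
    if (lhsIndex == rhsIndex)
      return lhs.position < rhs.position;
    return lhsIndex < rhsIndex;
  });
}
```

Iterating the sorted list backwards handles an operation's users before the operation itself, so users have already dropped their uses of a result by the time its producer removes its own blocking uses.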