MLIR  22.0.0git
RemoveDeadValues.cpp
Go to the documentation of this file.
1 //===- RemoveDeadValues.cpp - Remove Dead Values --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The goal of this pass is optimization (reducing runtime) by removing
10 // unnecessary instructions. Unlike other passes that rely on local information
11 // gathered from patterns to accomplish optimization, this pass uses a full
12 // analysis of the IR, specifically, liveness analysis, and is thus more
13 // powerful.
14 //
15 // Currently, this pass performs the following optimizations:
16 // (A) Removes function arguments that are not live,
17 // (B) Removes function return values that are not live across all callers of
18 // the function,
19 // (C) Removes unnecessary operands, results, region arguments, and region
20 // terminator operands of region branch ops, and,
21 // (D) Removes simple and region branch ops that have all non-live results and
22 // don't affect memory in any way,
23 //
24 // iff
25 //
26 // the IR doesn't have any non-function symbol ops, non-call symbol user ops and
27 // branch ops.
28 //
29 // Here, a "simple op" refers to an op that isn't a symbol op, symbol-user op,
30 // region branch op, branch op, region branch terminator op, or return-like.
31 //
32 //===----------------------------------------------------------------------===//
33 
36 #include "mlir/IR/Builders.h"
38 #include "mlir/IR/Dialect.h"
39 #include "mlir/IR/Operation.h"
41 #include "mlir/IR/SymbolTable.h"
42 #include "mlir/IR/Value.h"
43 #include "mlir/IR/ValueRange.h"
44 #include "mlir/IR/Visitors.h"
49 #include "mlir/Pass/Pass.h"
50 #include "mlir/Support/LLVM.h"
52 #include "mlir/Transforms/Passes.h"
53 #include "llvm/ADT/STLExtras.h"
54 #include "llvm/Support/Debug.h"
55 #include "llvm/Support/DebugLog.h"
56 #include <cassert>
57 #include <cstddef>
58 #include <memory>
59 #include <optional>
60 #include <vector>
61 
62 #define DEBUG_TYPE "remove-dead-values"
63 
64 namespace mlir {
65 #define GEN_PASS_DEF_REMOVEDEADVALUES
66 #include "mlir/Transforms/Passes.h.inc"
67 } // namespace mlir
68 
69 using namespace mlir;
70 using namespace mlir::dataflow;
71 
72 //===----------------------------------------------------------------------===//
73 // RemoveDeadValues Pass
74 //===----------------------------------------------------------------------===//
75 
76 namespace {
77 
78 // Set of structures below to be filled with operations and arguments to erase.
79 // This is done to separate analysis and tree modification phases,
80 // otherwise analysis is operating on half-deleted tree which is incorrect.
81 
82 struct FunctionToCleanUp {
83  FunctionOpInterface funcOp;
84  BitVector nonLiveArgs;
85  BitVector nonLiveRets;
86 };
87 
88 struct OperationToCleanup {
89  Operation *op;
90  BitVector nonLive;
91 };
92 
93 struct BlockArgsToCleanup {
94  Block *b;
95  BitVector nonLiveArgs;
96 };
97 
98 struct SuccessorOperandsToCleanup {
99  BranchOpInterface branch;
100  unsigned successorIndex;
101  BitVector nonLiveOperands;
102 };
103 
104 struct RDVFinalCleanupList {
105  SmallVector<Operation *> operations;
106  SmallVector<Value> values;
111  SmallVector<SuccessorOperandsToCleanup> successorOperands;
112 };
113 
114 // Some helper functions...
115 
116 /// Return true iff at least one value in `values` is live, given the liveness
117 /// information in `la`.
118 static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
119  RunLivenessAnalysis &la) {
120  for (Value value : values) {
121  if (nonLiveSet.contains(value)) {
122  LDBG() << "Value " << value << " is already marked non-live (dead)";
123  continue;
124  }
125 
126  const Liveness *liveness = la.getLiveness(value);
127  if (!liveness) {
128  LDBG() << "Value " << value
129  << " has no liveness info, conservatively considered live";
130  return true;
131  }
132  if (liveness->isLive) {
133  LDBG() << "Value " << value << " is live according to liveness analysis";
134  return true;
135  } else {
136  LDBG() << "Value " << value << " is dead according to liveness analysis";
137  }
138  }
139  return false;
140 }
141 
142 /// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
143 /// i-th value in `values` is live, given the liveness information in `la`.
144 static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
145  RunLivenessAnalysis &la) {
146  BitVector lives(values.size(), true);
147 
148  for (auto [index, value] : llvm::enumerate(values)) {
149  if (nonLiveSet.contains(value)) {
150  lives.reset(index);
151  LDBG() << "Value " << value
152  << " is already marked non-live (dead) at index " << index;
153  continue;
154  }
155 
156  const Liveness *liveness = la.getLiveness(value);
157  // It is important to note that when `liveness` is null, we can't tell if
158  // `value` is live or not. So, the safe option is to consider it live. Also,
159  // the execution of this pass might create new SSA values when erasing some
160  // of the results of an op and we know that these new values are live
161  // (because they weren't erased) and also their liveness is null because
162  // liveness analysis ran before their creation.
163  if (!liveness) {
164  LDBG() << "Value " << value << " at index " << index
165  << " has no liveness info, conservatively considered live";
166  continue;
167  }
168  if (!liveness->isLive) {
169  lives.reset(index);
170  LDBG() << "Value " << value << " at index " << index
171  << " is dead according to liveness analysis";
172  } else {
173  LDBG() << "Value " << value << " at index " << index
174  << " is live according to liveness analysis";
175  }
176  }
177 
178  return lives;
179 }
180 
181 /// Collects values marked as "non-live" in the provided range and inserts them
182 /// into the nonLiveSet. A value is considered "non-live" if the corresponding
183 /// index in the `nonLive` bit vector is set.
184 static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range,
185  const BitVector &nonLive) {
186  for (auto [index, result] : llvm::enumerate(range)) {
187  if (!nonLive[index])
188  continue;
189  nonLiveSet.insert(result);
190  LDBG() << "Marking value " << result << " as non-live (dead) at index "
191  << index;
192  }
193 }
194 
195 /// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
196 /// is 1.
197 static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
198  assert(op->getNumResults() == toErase.size() &&
199  "expected the number of results in `op` and the size of `toErase` to "
200  "be the same");
201 
202  std::vector<Type> newResultTypes;
203  for (OpResult result : op->getResults())
204  if (!toErase[result.getResultNumber()])
205  newResultTypes.push_back(result.getType());
206  OpBuilder builder(op);
207  builder.setInsertionPointAfter(op);
208  OperationState state(op->getLoc(), op->getName().getStringRef(),
209  op->getOperands(), newResultTypes, op->getAttrs());
210  for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
211  state.addRegion();
212  Operation *newOp = builder.create(state);
213  for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
214  Region &newRegion = newOp->getRegion(index);
215  // Move all blocks of `region` into `newRegion`.
216  Block *temp = new Block();
217  newRegion.push_back(temp);
218  while (!region.empty())
219  region.front().moveBefore(temp);
220  temp->erase();
221  }
222 
223  unsigned indexOfNextNewCallOpResultToReplace = 0;
224  for (auto [index, result] : llvm::enumerate(op->getResults())) {
225  assert(result && "expected result to be non-null");
226  if (toErase[index]) {
227  result.dropAllUses();
228  } else {
229  result.replaceAllUsesWith(
230  newOp->getResult(indexOfNextNewCallOpResultToReplace++));
231  }
232  }
233  op->erase();
234 }
235 
236 /// Convert a list of `Operand`s to a list of `OpOperand`s.
238  OpOperand *values = operands.getBase();
239  SmallVector<OpOperand *> opOperands;
240  for (unsigned i = 0, e = operands.size(); i < e; i++)
241  opOperands.push_back(&values[i]);
242  return opOperands;
243 }
244 
245 /// Process a simple operation `op` using the liveness analysis `la`.
246 /// If the operation has no memory effects and none of its results are live:
247 /// 1. Add the operation to a list for future removal, and
248 /// 2. Mark all its results as non-live values
249 ///
250 /// The operation `op` is assumed to be simple. A simple operation is one that
251 /// is NOT:
252 /// - Function-like
253 /// - Call-like
254 /// - A region branch operation
255 /// - A branch operation
256 /// - A region branch terminator
257 /// - Return-like
258 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
259  DenseSet<Value> &nonLiveSet,
260  RDVFinalCleanupList &cl) {
261  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
262  LDBG() << "Simple op is not memory effect free or has live results, "
263  "preserving it: "
264  << OpWithFlags(op, OpPrintingFlags().skipRegions());
265  return;
266  }
267 
268  LDBG()
269  << "Simple op has all dead results and is memory effect free, scheduling "
270  "for removal: "
271  << OpWithFlags(op, OpPrintingFlags().skipRegions());
272  cl.operations.push_back(op);
273  collectNonLiveValues(nonLiveSet, op->getResults(),
274  BitVector(op->getNumResults(), true));
275 }
276 
277 /// Process a function-like operation `funcOp` using the liveness analysis `la`
278 /// and the IR in `module`. If it is not public or external:
279 /// (1) Adding its non-live arguments to a list for future removal.
280 /// (2) Marking their corresponding operands in its callers for removal.
281 /// (3) Identifying and enqueueing unnecessary terminator operands
282 /// (return values that are non-live across all callers) for removal.
283 /// (4) Enqueueing the non-live arguments and return values for removal.
284 /// (5) Collecting the uses of these return values in its callers for future
285 /// removal.
286 /// (6) Marking all its results as non-live values.
287 static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
288  RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
289  RDVFinalCleanupList &cl) {
290  LDBG() << "Processing function op: " << funcOp.getOperation()->getName();
291  if (funcOp.isPublic() || funcOp.isExternal()) {
292  LDBG() << "Function is public or external, skipping: "
293  << funcOp.getOperation()->getName();
294  return;
295  }
296 
297  // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
298  SmallVector<Value> arguments(funcOp.getArguments());
299  BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
300  nonLiveArgs = nonLiveArgs.flip();
301 
302  // Do (1).
303  for (auto [index, arg] : llvm::enumerate(arguments))
304  if (arg && nonLiveArgs[index]) {
305  cl.values.push_back(arg);
306  nonLiveSet.insert(arg);
307  }
308 
309  // Do (2).
310  SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
311  for (SymbolTable::SymbolUse use : uses) {
312  Operation *callOp = use.getUser();
313  assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
314  // The number of operands in the call op may not match the number of
315  // arguments in the func op.
316  BitVector nonLiveCallOperands(callOp->getNumOperands(), false);
317  SmallVector<OpOperand *> callOpOperands =
318  operandsToOpOperands(cast<CallOpInterface>(callOp).getArgOperands());
319  for (int index : nonLiveArgs.set_bits())
320  nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
321  cl.operands.push_back({callOp, nonLiveCallOperands});
322  }
323 
324  // Do (3).
325  // Get the list of unnecessary terminator operands (return values that are
326  // non-live across all callers) in `nonLiveRets`. There is a very important
327  // subtlety here. Unnecessary terminator operands are NOT the operands of the
328  // terminator that are non-live. Instead, these are the return values of the
329  // callers such that a given return value is non-live across all callers. Such
330  // corresponding operands in the terminator could be live. An example to
331  // demonstrate this:
332  // func.func private @f(%arg0: memref<i32>) -> (i32, i32) {
333  // %c0_i32 = arith.constant 0 : i32
334  // %0 = arith.addi %c0_i32, %c0_i32 : i32
335  // memref.store %0, %arg0[] : memref<i32>
336  // return %c0_i32, %0 : i32, i32
337  // }
338  // func.func @main(%arg0: i32, %arg1: memref<i32>) -> (i32) {
339  // %1:2 = call @f(%arg1) : (memref<i32>) -> i32
340  // return %1#0 : i32
341  // }
342  // Here, we can see that %1#1 is never used. It is non-live. Thus, @f doesn't
343  // need to return %0. But, %0 is live. And, still, we want to stop it from
344  // being returned, in order to optimize our IR. So, this demonstrates how we
345  // can make our optimization strong by even removing a live return value (%0),
346  // since it forwards only to non-live value(s) (%1#1).
347  size_t numReturns = funcOp.getNumResults();
348  BitVector nonLiveRets(numReturns, true);
349  for (SymbolTable::SymbolUse use : uses) {
350  Operation *callOp = use.getUser();
351  assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
352  BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
353  nonLiveRets &= liveCallRets.flip();
354  }
355 
356  // Note that in the absence of control flow ops forcing the control to go from
357  // the entry (first) block to the other blocks, the control never reaches any
358  // block other than the entry block, because every block has a terminator.
359  for (Block &block : funcOp.getBlocks()) {
360  Operation *returnOp = block.getTerminator();
361  if (returnOp && returnOp->getNumOperands() == numReturns)
362  cl.operands.push_back({returnOp, nonLiveRets});
363  }
364 
365  // Do (4).
366  cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
367 
368  // Do (5) and (6).
369  if (numReturns == 0)
370  return;
371  for (SymbolTable::SymbolUse use : uses) {
372  Operation *callOp = use.getUser();
373  assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
374  cl.results.push_back({callOp, nonLiveRets});
375  collectNonLiveValues(nonLiveSet, callOp->getResults(), nonLiveRets);
376  }
377 }
378 
379 /// Process a region branch operation `regionBranchOp` using the liveness
380 /// information in `la`. The processing involves two scenarios:
381 ///
382 /// Scenario 1: If the operation has no memory effects and none of its results
383 /// are live:
384 /// (1') Enqueue all its uses for deletion.
385 /// (2') Enqueue the branch itself for deletion.
386 ///
387 /// Scenario 2: Otherwise:
388 /// (1) Collect its unnecessary operands (operands forwarded to unnecessary
389 /// results or arguments).
390 /// (2) Process each of its regions.
391 /// (3) Collect the uses of its unnecessary results (results forwarded from
392 /// unnecessary operands
393 /// or terminator operands).
394 /// (4) Add these results to the deletion list.
395 ///
396 /// Processing a region includes:
397 /// (a) Collecting the uses of its unnecessary arguments (arguments forwarded
398 /// from unnecessary operands
399 /// or terminator operands).
400 /// (b) Collecting these unnecessary arguments.
401 /// (c) Collecting its unnecessary terminator operands (terminator operands
402 /// forwarded to unnecessary results
403 /// or arguments).
404 ///
405 /// Value Flow Note: In this operation, values flow as follows:
406 /// - From operands and terminator operands (successor operands)
407 /// - To arguments and results (successor inputs).
408 static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
410  DenseSet<Value> &nonLiveSet,
411  RDVFinalCleanupList &cl) {
412  LDBG() << "Processing region branch op: "
413  << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
414  // Mark live results of `regionBranchOp` in `liveResults`.
415  auto markLiveResults = [&](BitVector &liveResults) {
416  liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
417  };
418 
419  // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
420  auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
421  for (Region &region : regionBranchOp->getRegions()) {
422  if (region.empty())
423  continue;
424  SmallVector<Value> arguments(region.front().getArguments());
425  BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
426  liveArgs[&region] = regionLiveArgs;
427  }
428  };
429 
430  // Return the successors of `region` if the latter is not null. Else return
431  // the successors of `regionBranchOp`.
432  auto getSuccessors = [&](Region *region = nullptr) {
433  auto point = region ? region : RegionBranchPoint::parent();
434  SmallVector<RegionSuccessor> successors;
435  regionBranchOp.getSuccessorRegions(point, successors);
436  return successors;
437  };
438 
439  // Return the operands of `terminator` that are forwarded to `successor` if
440  // the former is not null. Else return the operands of `regionBranchOp`
441  // forwarded to `successor`.
442  auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
443  Operation *terminator = nullptr) {
444  OperandRange operands =
445  terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
446  .getSuccessorOperands(successor)
447  : regionBranchOp.getEntrySuccessorOperands(successor);
448  SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
449  return opOperands;
450  };
451 
452  // Mark the non-forwarded operands of `regionBranchOp` in
453  // `nonForwardedOperands`.
454  auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
455  nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
456  for (const RegionSuccessor &successor : getSuccessors()) {
457  for (OpOperand *opOperand : getForwardedOpOperands(successor))
458  nonForwardedOperands.reset(opOperand->getOperandNumber());
459  }
460  };
461 
462  // Mark the non-forwarded terminator operands of the various regions of
463  // `regionBranchOp` in `nonForwardedRets`.
464  auto markNonForwardedReturnValues =
465  [&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
466  for (Region &region : regionBranchOp->getRegions()) {
467  if (region.empty())
468  continue;
469  Operation *terminator = region.front().getTerminator();
470  nonForwardedRets[terminator] =
471  BitVector(terminator->getNumOperands(), true);
472  for (const RegionSuccessor &successor : getSuccessors(&region)) {
473  for (OpOperand *opOperand :
474  getForwardedOpOperands(successor, terminator))
475  nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
476  }
477  }
478  };
479 
480  // Update `valuesToKeep` (which is expected to correspond to operands or
481  // terminator operands) based on `resultsToKeep` and `argsToKeep`, given
482  // `region`. When `valuesToKeep` correspond to operands, `region` is null.
483  // Else, `region` is the parent region of the terminator.
484  auto updateOperandsOrTerminatorOperandsToKeep =
485  [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
486  DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
487  Operation *terminator =
488  region ? region->front().getTerminator() : nullptr;
489 
490  for (const RegionSuccessor &successor : getSuccessors(region)) {
491  Region *successorRegion = successor.getSuccessor();
492  for (auto [opOperand, input] :
493  llvm::zip(getForwardedOpOperands(successor, terminator),
494  successor.getSuccessorInputs())) {
495  size_t operandNum = opOperand->getOperandNumber();
496  bool updateBasedOn =
497  successorRegion
498  ? argsToKeep[successorRegion]
499  [cast<BlockArgument>(input).getArgNumber()]
500  : resultsToKeep[cast<OpResult>(input).getResultNumber()];
501  valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
502  }
503  }
504  };
505 
506  // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
507  // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a
508  // value is modified, else, false.
509  auto recomputeResultsAndArgsToKeep =
510  [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
511  BitVector &operandsToKeep,
512  DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
513  bool &resultsOrArgsToKeepChanged) {
514  resultsOrArgsToKeepChanged = false;
515 
516  // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
517  for (const RegionSuccessor &successor : getSuccessors()) {
518  Region *successorRegion = successor.getSuccessor();
519  for (auto [opOperand, input] :
520  llvm::zip(getForwardedOpOperands(successor),
521  successor.getSuccessorInputs())) {
522  bool recomputeBasedOn =
523  operandsToKeep[opOperand->getOperandNumber()];
524  bool toRecompute =
525  successorRegion
526  ? argsToKeep[successorRegion]
527  [cast<BlockArgument>(input).getArgNumber()]
528  : resultsToKeep[cast<OpResult>(input).getResultNumber()];
529  if (!toRecompute && recomputeBasedOn)
530  resultsOrArgsToKeepChanged = true;
531  if (successorRegion) {
532  argsToKeep[successorRegion][cast<BlockArgument>(input)
533  .getArgNumber()] =
534  argsToKeep[successorRegion]
535  [cast<BlockArgument>(input).getArgNumber()] |
536  recomputeBasedOn;
537  } else {
538  resultsToKeep[cast<OpResult>(input).getResultNumber()] =
539  resultsToKeep[cast<OpResult>(input).getResultNumber()] |
540  recomputeBasedOn;
541  }
542  }
543  }
544 
545  // Recompute `resultsToKeep` and `argsToKeep` based on
546  // `terminatorOperandsToKeep`.
547  for (Region &region : regionBranchOp->getRegions()) {
548  if (region.empty())
549  continue;
550  Operation *terminator = region.front().getTerminator();
551  for (const RegionSuccessor &successor : getSuccessors(&region)) {
552  Region *successorRegion = successor.getSuccessor();
553  for (auto [opOperand, input] :
554  llvm::zip(getForwardedOpOperands(successor, terminator),
555  successor.getSuccessorInputs())) {
556  bool recomputeBasedOn =
557  terminatorOperandsToKeep[region.back().getTerminator()]
558  [opOperand->getOperandNumber()];
559  bool toRecompute =
560  successorRegion
561  ? argsToKeep[successorRegion]
562  [cast<BlockArgument>(input).getArgNumber()]
563  : resultsToKeep[cast<OpResult>(input).getResultNumber()];
564  if (!toRecompute && recomputeBasedOn)
565  resultsOrArgsToKeepChanged = true;
566  if (successorRegion) {
567  argsToKeep[successorRegion][cast<BlockArgument>(input)
568  .getArgNumber()] =
569  argsToKeep[successorRegion]
570  [cast<BlockArgument>(input).getArgNumber()] |
571  recomputeBasedOn;
572  } else {
573  resultsToKeep[cast<OpResult>(input).getResultNumber()] =
574  resultsToKeep[cast<OpResult>(input).getResultNumber()] |
575  recomputeBasedOn;
576  }
577  }
578  }
579  }
580  };
581 
582  // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
583  // `operandsToKeep`, and `terminatorOperandsToKeep`.
584  auto markValuesToKeep =
585  [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
586  BitVector &operandsToKeep,
587  DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
588  bool resultsOrArgsToKeepChanged = true;
589  // We keep updating and recomputing the values until we reach a point
590  // where they stop changing.
591  while (resultsOrArgsToKeepChanged) {
592  // Update the operands that need to be kept.
593  updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
594  resultsToKeep, argsToKeep);
595 
596  // Update the terminator operands that need to be kept.
597  for (Region &region : regionBranchOp->getRegions()) {
598  if (region.empty())
599  continue;
600  updateOperandsOrTerminatorOperandsToKeep(
601  terminatorOperandsToKeep[region.back().getTerminator()],
602  resultsToKeep, argsToKeep, &region);
603  }
604 
605  // Recompute the results and arguments that need to be kept.
606  recomputeResultsAndArgsToKeep(
607  resultsToKeep, argsToKeep, operandsToKeep,
608  terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
609  }
610  };
611 
612  // Scenario 1. This is the only case where the entire `regionBranchOp`
613  // is removed. It will not happen in any other scenario. Note that in this
614  // case, a non-forwarded operand of `regionBranchOp` could be live/non-live.
615  // It could never be live because of this op but its liveness could have been
616  // attributed to something else.
617  // Do (1') and (2').
618  if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
619  !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
620  cl.operations.push_back(regionBranchOp.getOperation());
621  return;
622  }
623 
624  // Scenario 2.
625  // At this point, we know that every non-forwarded operand of `regionBranchOp`
626  // is live.
627 
628  // Stores the results of `regionBranchOp` that we want to keep.
629  BitVector resultsToKeep;
630  // Stores the mapping from regions of `regionBranchOp` to their arguments that
631  // we want to keep.
633  // Stores the operands of `regionBranchOp` that we want to keep.
634  BitVector operandsToKeep;
635  // Stores the mapping from region terminators in `regionBranchOp` to their
636  // operands that we want to keep.
637  DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
638 
639  // Initializing the above variables...
640 
641  // The live results of `regionBranchOp` definitely need to be kept.
642  markLiveResults(resultsToKeep);
643  // Similarly, the live arguments of the regions in `regionBranchOp` definitely
644  // need to be kept.
645  markLiveArgs(argsToKeep);
646  // The non-forwarded operands of `regionBranchOp` definitely need to be kept.
647  // A live forwarded operand can be removed but no non-forwarded operand can be
648  // removed since it "controls" the flow of data in this control flow op.
649  markNonForwardedOperands(operandsToKeep);
650  // Similarly, the non-forwarded terminator operands of the regions in
651  // `regionBranchOp` definitely need to be kept.
652  markNonForwardedReturnValues(terminatorOperandsToKeep);
653 
654  // Mark the values (results, arguments, operands, and terminator operands)
655  // that we want to keep.
656  markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
657  terminatorOperandsToKeep);
658 
659  // Do (1).
660  cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
661 
662  // Do (2.a) and (2.b).
663  for (Region &region : regionBranchOp->getRegions()) {
664  if (region.empty())
665  continue;
666  BitVector argsToRemove = argsToKeep[&region].flip();
667  cl.blocks.push_back({&region.front(), argsToRemove});
668  collectNonLiveValues(nonLiveSet, region.front().getArguments(),
669  argsToRemove);
670  }
671 
672  // Do (2.c).
673  for (Region &region : regionBranchOp->getRegions()) {
674  if (region.empty())
675  continue;
676  Operation *terminator = region.front().getTerminator();
677  cl.operands.push_back(
678  {terminator, terminatorOperandsToKeep[terminator].flip()});
679  }
680 
681  // Do (3) and (4).
682  BitVector resultsToRemove = resultsToKeep.flip();
683  collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
684  resultsToRemove);
685  cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
686 }
687 
688 /// Steps to process a `BranchOpInterface` operation:
689 /// Iterate through each successor block of `branchOp`.
690 /// (1) For each successor block, gather all operands from all successors.
691 /// (2) Fetch their associated liveness analysis data and collect for future
692 /// removal.
693 /// (3) Identify and collect the dead operands from the successor block
694 /// as well as their corresponding arguments.
695 
696 static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
697  DenseSet<Value> &nonLiveSet,
698  RDVFinalCleanupList &cl) {
699  LDBG() << "Processing branch op: " << *branchOp;
700  unsigned numSuccessors = branchOp->getNumSuccessors();
701 
702  for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
703  Block *successorBlock = branchOp->getSuccessor(succIdx);
704 
705  // Do (1)
706  SuccessorOperands successorOperands =
707  branchOp.getSuccessorOperands(succIdx);
708  SmallVector<Value> operandValues;
709  for (unsigned operandIdx = 0; operandIdx < successorOperands.size();
710  ++operandIdx) {
711  operandValues.push_back(successorOperands[operandIdx]);
712  }
713 
714  // Do (2)
715  BitVector successorNonLive =
716  markLives(operandValues, nonLiveSet, la).flip();
717  collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
718  successorNonLive);
719 
720  // Do (3)
721  cl.blocks.push_back({successorBlock, successorNonLive});
722  cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
723  }
724 }
725 
726 /// Removes dead values collected in RDVFinalCleanupList.
727 /// To be run once when all dead values have been collected.
728 static void cleanUpDeadVals(RDVFinalCleanupList &list) {
729  LDBG() << "Starting cleanup of dead values...";
730 
731  // 1. Operations
732  LDBG() << "Cleaning up " << list.operations.size() << " operations";
733  for (auto &op : list.operations) {
734  LDBG() << "Erasing operation: "
735  << OpWithFlags(op, OpPrintingFlags().skipRegions());
736  op->dropAllUses();
737  op->erase();
738  }
739 
740  // 2. Values
741  LDBG() << "Cleaning up " << list.values.size() << " values";
742  for (auto &v : list.values) {
743  LDBG() << "Dropping all uses of value: " << v;
744  v.dropAllUses();
745  }
746 
747  // 3. Functions
748  LDBG() << "Cleaning up " << list.functions.size() << " functions";
749  for (auto &f : list.functions) {
750  LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
751  LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments";
752  LDBG() << " Erasing " << f.nonLiveRets.count()
753  << " non-live return values";
754  // Some functions may not allow erasing arguments or results. These calls
755  // return failure in such cases without modifying the function, so it's okay
756  // to proceed.
757  (void)f.funcOp.eraseArguments(f.nonLiveArgs);
758  (void)f.funcOp.eraseResults(f.nonLiveRets);
759  }
760 
761  // 4. Operands
762  LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
763  for (OperationToCleanup &o : list.operands) {
764  if (o.op->getNumOperands() > 0) {
765  LDBG() << "Erasing " << o.nonLive.count()
766  << " non-live operands from operation: "
767  << OpWithFlags(o.op, OpPrintingFlags().skipRegions());
768  o.op->eraseOperands(o.nonLive);
769  }
770  }
771 
772  // 5. Results
773  LDBG() << "Cleaning up " << list.results.size() << " result lists";
774  for (auto &r : list.results) {
775  LDBG() << "Erasing " << r.nonLive.count()
776  << " non-live results from operation: "
777  << OpWithFlags(r.op, OpPrintingFlags().skipRegions());
778  dropUsesAndEraseResults(r.op, r.nonLive);
779  }
780 
781  // 6. Blocks
782  LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
783  for (auto &b : list.blocks) {
784  // blocks that are accessed via multiple codepaths processed once
785  if (b.b->getNumArguments() != b.nonLiveArgs.size())
786  continue;
787  LDBG() << "Erasing " << b.nonLiveArgs.count()
788  << " non-live arguments from block: " << b.b;
789  // it iterates backwards because erase invalidates all successor indexes
790  for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
791  if (!b.nonLiveArgs[i])
792  continue;
793  LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i);
794  b.b->getArgument(i).dropAllUses();
795  b.b->eraseArgument(i);
796  }
797  }
798 
799  // 7. Successor Operands
800  LDBG() << "Cleaning up " << list.successorOperands.size()
801  << " successor operand lists";
802  for (auto &op : list.successorOperands) {
803  SuccessorOperands successorOperands =
804  op.branch.getSuccessorOperands(op.successorIndex);
805  // blocks that are accessed via multiple codepaths processed once
806  if (successorOperands.size() != op.nonLiveOperands.size())
807  continue;
808  LDBG() << "Erasing " << op.nonLiveOperands.count()
809  << " non-live successor operands from successor "
810  << op.successorIndex << " of branch: "
811  << OpWithFlags(op.branch, OpPrintingFlags().skipRegions());
812  // it iterates backwards because erase invalidates all successor indexes
813  for (int i = successorOperands.size() - 1; i >= 0; --i) {
814  if (!op.nonLiveOperands[i])
815  continue;
816  LDBG() << " Erasing successor operand " << i << ": "
817  << successorOperands[i];
818  successorOperands.erase(i);
819  }
820  }
821 
822  LDBG() << "Finished cleanup of dead values";
823 }
824 
825 struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
826  void runOnOperation() override;
827 };
828 } // namespace
829 
830 void RemoveDeadValues::runOnOperation() {
831  auto &la = getAnalysis<RunLivenessAnalysis>();
832  Operation *module = getOperation();
833 
834  // Tracks values eligible for erasure - complements liveness analysis to
835  // identify "droppable" values.
836  DenseSet<Value> deadVals;
837 
838  // Maintains a list of Ops, values, branches, etc., slated for cleanup at the
839  // end of this pass.
840  RDVFinalCleanupList finalCleanupList;
841 
842  module->walk([&](Operation *op) {
843  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
844  processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
845  } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
846  processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
847  } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
848  processBranchOp(branchOp, la, deadVals, finalCleanupList);
849  } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
850  // Nothing to do here because this is a terminator op and it should be
851  // honored with respect to its parent
852  } else if (isa<CallOpInterface>(op)) {
853  // Nothing to do because this op is associated with a function op and gets
854  // cleaned when the latter is cleaned.
855  } else {
856  processSimpleOp(op, la, deadVals, finalCleanupList);
857  }
858  });
859 
860  cleanUpDeadVals(finalCleanupList);
861 }
862 
863 std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
864  return std::make_unique<RemoveDeadValues>();
865 }
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
Block represents an ordered list of Operations.
Definition: Block.h:33
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:66
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgListType getArguments()
Definition: Block.h:87
Block * getSuccessor(unsigned i)
Definition: Block.cpp:269
This class helps build Operations.
Definition: Builders.h:205
This class represents an operand of an operation.
Definition: Value.h:257
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::Print).
This is a value defined by a result of an operation.
Definition: Value.h:447
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:773
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition: Operation.h:1111
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:834
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:66
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
ValueRange getSuccessorInputs() const
Return the inputs to the successor that are remapped by the exit values of the current region.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
This class models how operands are forwarded to block arguments in control flow.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents a specific symbol use.
Definition: SymbolTable.h:183
This class implements a range of SymbolRef uses.
Definition: SymbolTable.h:203
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::unique_ptr< Pass > createRemoveDeadValuesPass()
Creates an optimization pass to remove dead values.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
This lattice represents, for a given value, whether or not it is "live".
Runs liveness analysis on the IR defined by op.
const Liveness * getLiveness(Value val)