MLIR  22.0.0git
RemoveDeadValues.cpp
Go to the documentation of this file.
1 //===- RemoveDeadValues.cpp - Remove Dead Values --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The goal of this pass is optimization (reducing runtime) by removing
10 // unnecessary instructions. Unlike other passes that rely on local information
11 // gathered from patterns to accomplish optimization, this pass uses a full
12 // analysis of the IR, specifically, liveness analysis, and is thus more
13 // powerful.
14 //
15 // Currently, this pass performs the following optimizations:
16 // (A) Removes function arguments that are not live,
17 // (B) Removes function return values that are not live across all callers of
18 // the function,
19 // (C) Removes unnecessary operands, results, region arguments, and region
20 // terminator operands of region branch ops, and,
21 // (D) Removes simple and region branch ops that have all non-live results and
22 // don't affect memory in any way,
23 //
24 // iff
25 //
26 // the IR doesn't have any non-function symbol ops, non-call symbol user ops and
27 // branch ops.
28 //
29 // Here, a "simple op" refers to an op that isn't a symbol op, symbol-user op,
30 // region branch op, branch op, region branch terminator op, or return-like.
31 //
32 //===----------------------------------------------------------------------===//
33 
36 #include "mlir/IR/Builders.h"
38 #include "mlir/IR/Dialect.h"
39 #include "mlir/IR/Operation.h"
41 #include "mlir/IR/SymbolTable.h"
42 #include "mlir/IR/Value.h"
43 #include "mlir/IR/ValueRange.h"
44 #include "mlir/IR/Visitors.h"
49 #include "mlir/Pass/Pass.h"
50 #include "mlir/Support/LLVM.h"
52 #include "mlir/Transforms/Passes.h"
53 #include "llvm/ADT/STLExtras.h"
54 #include "llvm/Support/Debug.h"
55 #include "llvm/Support/DebugLog.h"
56 #include <cassert>
57 #include <cstddef>
58 #include <memory>
59 #include <optional>
60 #include <vector>
61 
62 #define DEBUG_TYPE "remove-dead-values"
63 
64 namespace mlir {
65 #define GEN_PASS_DEF_REMOVEDEADVALUES
66 #include "mlir/Transforms/Passes.h.inc"
67 } // namespace mlir
68 
69 using namespace mlir;
70 using namespace mlir::dataflow;
71 
72 //===----------------------------------------------------------------------===//
73 // RemoveDeadValues Pass
74 //===----------------------------------------------------------------------===//
75 
76 namespace {
77 
78 // Set of structures below to be filled with operations and arguments to erase.
79 // This is done to separate analysis and tree modification phases,
80 // otherwise analysis is operating on half-deleted tree which is incorrect.
81 
82 struct FunctionToCleanUp {
83  FunctionOpInterface funcOp;
84  BitVector nonLiveArgs;
85  BitVector nonLiveRets;
86 };
87 
88 struct OperationToCleanup {
89  Operation *op;
90  BitVector nonLive;
91  Operation *callee =
92  nullptr; // Optional: For CallOpInterface ops, stores the callee function
93 };
94 
95 struct BlockArgsToCleanup {
96  Block *b;
97  BitVector nonLiveArgs;
98 };
99 
100 struct SuccessorOperandsToCleanup {
101  BranchOpInterface branch;
102  unsigned successorIndex;
103  BitVector nonLiveOperands;
104 };
105 
106 struct RDVFinalCleanupList {
107  SmallVector<Operation *> operations;
108  SmallVector<Value> values;
113  SmallVector<SuccessorOperandsToCleanup> successorOperands;
114 };
115 
116 // Some helper functions...
117 
118 /// Return true iff at least one value in `values` is live, given the liveness
119 /// information in `la`.
120 static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
121  RunLivenessAnalysis &la) {
122  for (Value value : values) {
123  if (nonLiveSet.contains(value)) {
124  LDBG() << "Value " << value << " is already marked non-live (dead)";
125  continue;
126  }
127 
128  const Liveness *liveness = la.getLiveness(value);
129  if (!liveness) {
130  LDBG() << "Value " << value
131  << " has no liveness info, conservatively considered live";
132  return true;
133  }
134  if (liveness->isLive) {
135  LDBG() << "Value " << value << " is live according to liveness analysis";
136  return true;
137  } else {
138  LDBG() << "Value " << value << " is dead according to liveness analysis";
139  }
140  }
141  return false;
142 }
143 
144 /// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
145 /// i-th value in `values` is live, given the liveness information in `la`.
146 static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
147  RunLivenessAnalysis &la) {
148  BitVector lives(values.size(), true);
149 
150  for (auto [index, value] : llvm::enumerate(values)) {
151  if (nonLiveSet.contains(value)) {
152  lives.reset(index);
153  LDBG() << "Value " << value
154  << " is already marked non-live (dead) at index " << index;
155  continue;
156  }
157 
158  const Liveness *liveness = la.getLiveness(value);
159  // It is important to note that when `liveness` is null, we can't tell if
160  // `value` is live or not. So, the safe option is to consider it live. Also,
161  // the execution of this pass might create new SSA values when erasing some
162  // of the results of an op and we know that these new values are live
163  // (because they weren't erased) and also their liveness is null because
164  // liveness analysis ran before their creation.
165  if (!liveness) {
166  LDBG() << "Value " << value << " at index " << index
167  << " has no liveness info, conservatively considered live";
168  continue;
169  }
170  if (!liveness->isLive) {
171  lives.reset(index);
172  LDBG() << "Value " << value << " at index " << index
173  << " is dead according to liveness analysis";
174  } else {
175  LDBG() << "Value " << value << " at index " << index
176  << " is live according to liveness analysis";
177  }
178  }
179 
180  return lives;
181 }
182 
183 /// Collects values marked as "non-live" in the provided range and inserts them
184 /// into the nonLiveSet. A value is considered "non-live" if the corresponding
185 /// index in the `nonLive` bit vector is set.
186 static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range,
187  const BitVector &nonLive) {
188  for (auto [index, result] : llvm::enumerate(range)) {
189  if (!nonLive[index])
190  continue;
191  nonLiveSet.insert(result);
192  LDBG() << "Marking value " << result << " as non-live (dead) at index "
193  << index;
194  }
195 }
196 
197 /// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
198 /// is 1.
199 static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
200  assert(op->getNumResults() == toErase.size() &&
201  "expected the number of results in `op` and the size of `toErase` to "
202  "be the same");
203 
204  std::vector<Type> newResultTypes;
205  for (OpResult result : op->getResults())
206  if (!toErase[result.getResultNumber()])
207  newResultTypes.push_back(result.getType());
208  OpBuilder builder(op);
209  builder.setInsertionPointAfter(op);
210  OperationState state(op->getLoc(), op->getName().getStringRef(),
211  op->getOperands(), newResultTypes, op->getAttrs());
212  for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
213  state.addRegion();
214  Operation *newOp = builder.create(state);
215  for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
216  Region &newRegion = newOp->getRegion(index);
217  // Move all blocks of `region` into `newRegion`.
218  Block *temp = new Block();
219  newRegion.push_back(temp);
220  while (!region.empty())
221  region.front().moveBefore(temp);
222  temp->erase();
223  }
224 
225  unsigned indexOfNextNewCallOpResultToReplace = 0;
226  for (auto [index, result] : llvm::enumerate(op->getResults())) {
227  assert(result && "expected result to be non-null");
228  if (toErase[index]) {
229  result.dropAllUses();
230  } else {
231  result.replaceAllUsesWith(
232  newOp->getResult(indexOfNextNewCallOpResultToReplace++));
233  }
234  }
235  op->erase();
236 }
237 
238 /// Convert a list of `Operand`s to a list of `OpOperand`s.
240  OpOperand *values = operands.getBase();
241  SmallVector<OpOperand *> opOperands;
242  for (unsigned i = 0, e = operands.size(); i < e; i++)
243  opOperands.push_back(&values[i]);
244  return opOperands;
245 }
246 
247 /// Process a simple operation `op` using the liveness analysis `la`.
248 /// If the operation has no memory effects and none of its results are live:
249 /// 1. Add the operation to a list for future removal, and
250 /// 2. Mark all its results as non-live values
251 ///
252 /// The operation `op` is assumed to be simple. A simple operation is one that
253 /// is NOT:
254 /// - Function-like
255 /// - Call-like
256 /// - A region branch operation
257 /// - A branch operation
258 /// - A region branch terminator
259 /// - Return-like
260 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
261  DenseSet<Value> &nonLiveSet,
262  RDVFinalCleanupList &cl) {
263  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
264  LDBG() << "Simple op is not memory effect free or has live results, "
265  "preserving it: "
266  << OpWithFlags(op, OpPrintingFlags().skipRegions());
267  return;
268  }
269 
270  LDBG()
271  << "Simple op has all dead results and is memory effect free, scheduling "
272  "for removal: "
273  << OpWithFlags(op, OpPrintingFlags().skipRegions());
274  cl.operations.push_back(op);
275  collectNonLiveValues(nonLiveSet, op->getResults(),
276  BitVector(op->getNumResults(), true));
277 }
278 
279 /// Process a function-like operation `funcOp` using the liveness analysis `la`
280 /// and the IR in `module`. If it is not public or external:
281 /// (1) Adding its non-live arguments to a list for future removal.
282 /// (2) Marking their corresponding operands in its callers for removal.
283 /// (3) Identifying and enqueueing unnecessary terminator operands
284 /// (return values that are non-live across all callers) for removal.
285 /// (4) Enqueueing the non-live arguments and return values for removal.
286 /// (5) Collecting the uses of these return values in its callers for future
287 /// removal.
288 /// (6) Marking all its results as non-live values.
289 static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
290  RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
291  RDVFinalCleanupList &cl) {
292  LDBG() << "Processing function op: "
293  << OpWithFlags(funcOp, OpPrintingFlags().skipRegions());
294  if (funcOp.isPublic() || funcOp.isExternal()) {
295  LDBG() << "Function is public or external, skipping: "
296  << funcOp.getOperation()->getName();
297  return;
298  }
299 
300  // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
301  SmallVector<Value> arguments(funcOp.getArguments());
302  BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
303  nonLiveArgs = nonLiveArgs.flip();
304 
305  // Do (1).
306  for (auto [index, arg] : llvm::enumerate(arguments))
307  if (arg && nonLiveArgs[index]) {
308  cl.values.push_back(arg);
309  nonLiveSet.insert(arg);
310  }
311 
312  // Do (2). (Skip creating generic operand cleanup entries for call ops.
313  // Call arguments will be removed in the call-site specific segment-aware
314  // cleanup, avoiding generic eraseOperands bitvector mechanics.)
315  SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
316  for (SymbolTable::SymbolUse use : uses) {
317  Operation *callOp = use.getUser();
318  assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
319  // Push an empty operand cleanup entry so that call-site specific logic in
320  // cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is
321  // intentionally all false to avoid generic erasure.
322  // Store the funcOp as the callee to avoid expensive symbol lookup later.
323  cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false),
324  funcOp.getOperation()});
325  }
326 
327  // Do (3).
328  // Get the list of unnecessary terminator operands (return values that are
329  // non-live across all callers) in `nonLiveRets`. There is a very important
330  // subtlety here. Unnecessary terminator operands are NOT the operands of the
331  // terminator that are non-live. Instead, these are the return values of the
332  // callers such that a given return value is non-live across all callers. Such
333  // corresponding operands in the terminator could be live. An example to
334  // demonstrate this:
335  // func.func private @f(%arg0: memref<i32>) -> (i32, i32) {
336  // %c0_i32 = arith.constant 0 : i32
337  // %0 = arith.addi %c0_i32, %c0_i32 : i32
338  // memref.store %0, %arg0[] : memref<i32>
339  // return %c0_i32, %0 : i32, i32
340  // }
341  // func.func @main(%arg0: i32, %arg1: memref<i32>) -> (i32) {
342  // %1:2 = call @f(%arg1) : (memref<i32>) -> i32
343  // return %1#0 : i32
344  // }
345  // Here, we can see that %1#1 is never used. It is non-live. Thus, @f doesn't
346  // need to return %0. But, %0 is live. And, still, we want to stop it from
347  // being returned, in order to optimize our IR. So, this demonstrates how we
348  // can make our optimization strong by even removing a live return value (%0),
349  // since it forwards only to non-live value(s) (%1#1).
350  size_t numReturns = funcOp.getNumResults();
351  BitVector nonLiveRets(numReturns, true);
352  for (SymbolTable::SymbolUse use : uses) {
353  Operation *callOp = use.getUser();
354  assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
355  BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
356  nonLiveRets &= liveCallRets.flip();
357  }
358 
359  // Note that in the absence of control flow ops forcing the control to go from
360  // the entry (first) block to the other blocks, the control never reaches any
361  // block other than the entry block, because every block has a terminator.
362  for (Block &block : funcOp.getBlocks()) {
363  Operation *returnOp = block.getTerminator();
364  if (returnOp && returnOp->getNumOperands() == numReturns)
365  cl.operands.push_back({returnOp, nonLiveRets});
366  }
367 
368  // Do (4).
369  cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
370 
371  // Do (5) and (6).
372  if (numReturns == 0)
373  return;
374  for (SymbolTable::SymbolUse use : uses) {
375  Operation *callOp = use.getUser();
376  assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
377  cl.results.push_back({callOp, nonLiveRets});
378  collectNonLiveValues(nonLiveSet, callOp->getResults(), nonLiveRets);
379  }
380 }
381 
382 /// Process a region branch operation `regionBranchOp` using the liveness
383 /// information in `la`. The processing involves two scenarios:
384 ///
385 /// Scenario 1: If the operation has no memory effects and none of its results
386 /// are live:
387 /// (1') Enqueue all its uses for deletion.
388 /// (2') Enqueue the branch itself for deletion.
389 ///
390 /// Scenario 2: Otherwise:
391 /// (1) Collect its unnecessary operands (operands forwarded to unnecessary
392 /// results or arguments).
393 /// (2) Process each of its regions.
394 /// (3) Collect the uses of its unnecessary results (results forwarded from
395 /// unnecessary operands
396 /// or terminator operands).
397 /// (4) Add these results to the deletion list.
398 ///
399 /// Processing a region includes:
400 /// (a) Collecting the uses of its unnecessary arguments (arguments forwarded
401 /// from unnecessary operands
402 /// or terminator operands).
403 /// (b) Collecting these unnecessary arguments.
404 /// (c) Collecting its unnecessary terminator operands (terminator operands
405 /// forwarded to unnecessary results
406 /// or arguments).
407 ///
408 /// Value Flow Note: In this operation, values flow as follows:
409 /// - From operands and terminator operands (successor operands)
410 /// - To arguments and results (successor inputs).
411 static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
413  DenseSet<Value> &nonLiveSet,
414  RDVFinalCleanupList &cl) {
415  LDBG() << "Processing region branch op: "
416  << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
417  // Mark live results of `regionBranchOp` in `liveResults`.
418  auto markLiveResults = [&](BitVector &liveResults) {
419  liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
420  };
421 
422  // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
423  auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
424  for (Region &region : regionBranchOp->getRegions()) {
425  if (region.empty())
426  continue;
427  SmallVector<Value> arguments(region.front().getArguments());
428  BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
429  liveArgs[&region] = regionLiveArgs;
430  }
431  };
432 
433  // Return the successors of `region` if the latter is not null. Else return
434  // the successors of `regionBranchOp`.
435  auto getSuccessors = [&](RegionBranchPoint point) {
436  SmallVector<RegionSuccessor> successors;
437  regionBranchOp.getSuccessorRegions(point, successors);
438  return successors;
439  };
440 
441  // Return the operands of `terminator` that are forwarded to `successor` if
442  // the former is not null. Else return the operands of `regionBranchOp`
443  // forwarded to `successor`.
444  auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
445  Operation *terminator = nullptr) {
446  OperandRange operands =
447  terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
448  .getSuccessorOperands(successor)
449  : regionBranchOp.getEntrySuccessorOperands(successor);
450  SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
451  return opOperands;
452  };
453 
454  // Mark the non-forwarded operands of `regionBranchOp` in
455  // `nonForwardedOperands`.
456  auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
457  nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
458  for (const RegionSuccessor &successor :
459  getSuccessors(RegionBranchPoint::parent())) {
460  for (OpOperand *opOperand : getForwardedOpOperands(successor))
461  nonForwardedOperands.reset(opOperand->getOperandNumber());
462  }
463  };
464 
465  // Mark the non-forwarded terminator operands of the various regions of
466  // `regionBranchOp` in `nonForwardedRets`.
467  auto markNonForwardedReturnValues =
468  [&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
469  for (Region &region : regionBranchOp->getRegions()) {
470  if (region.empty())
471  continue;
472  // TODO: this isn't correct in face of multiple terminators.
473  Operation *terminator = region.front().getTerminator();
474  nonForwardedRets[terminator] =
475  BitVector(terminator->getNumOperands(), true);
476  for (const RegionSuccessor &successor :
477  getSuccessors(RegionBranchPoint(
478  cast<RegionBranchTerminatorOpInterface>(terminator)))) {
479  for (OpOperand *opOperand :
480  getForwardedOpOperands(successor, terminator))
481  nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
482  }
483  }
484  };
485 
486  // Update `valuesToKeep` (which is expected to correspond to operands or
487  // terminator operands) based on `resultsToKeep` and `argsToKeep`, given
488  // `region`. When `valuesToKeep` correspond to operands, `region` is null.
489  // Else, `region` is the parent region of the terminator.
490  auto updateOperandsOrTerminatorOperandsToKeep =
491  [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
492  DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
493  Operation *terminator =
494  region ? region->front().getTerminator() : nullptr;
495  RegionBranchPoint point =
496  terminator
498  cast<RegionBranchTerminatorOpInterface>(terminator))
499  : RegionBranchPoint::parent();
500 
501  for (const RegionSuccessor &successor : getSuccessors(point)) {
502  Region *successorRegion = successor.getSuccessor();
503  for (auto [opOperand, input] :
504  llvm::zip(getForwardedOpOperands(successor, terminator),
505  successor.getSuccessorInputs())) {
506  size_t operandNum = opOperand->getOperandNumber();
507  bool updateBasedOn =
508  successorRegion
509  ? argsToKeep[successorRegion]
510  [cast<BlockArgument>(input).getArgNumber()]
511  : resultsToKeep[cast<OpResult>(input).getResultNumber()];
512  valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
513  }
514  }
515  };
516 
517  // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
518  // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a
519  // value is modified, else, false.
520  auto recomputeResultsAndArgsToKeep =
521  [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
522  BitVector &operandsToKeep,
523  DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
524  bool &resultsOrArgsToKeepChanged) {
525  resultsOrArgsToKeepChanged = false;
526 
527  // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
528  for (const RegionSuccessor &successor :
529  getSuccessors(RegionBranchPoint::parent())) {
530  Region *successorRegion = successor.getSuccessor();
531  for (auto [opOperand, input] :
532  llvm::zip(getForwardedOpOperands(successor),
533  successor.getSuccessorInputs())) {
534  bool recomputeBasedOn =
535  operandsToKeep[opOperand->getOperandNumber()];
536  bool toRecompute =
537  successorRegion
538  ? argsToKeep[successorRegion]
539  [cast<BlockArgument>(input).getArgNumber()]
540  : resultsToKeep[cast<OpResult>(input).getResultNumber()];
541  if (!toRecompute && recomputeBasedOn)
542  resultsOrArgsToKeepChanged = true;
543  if (successorRegion) {
544  argsToKeep[successorRegion][cast<BlockArgument>(input)
545  .getArgNumber()] =
546  argsToKeep[successorRegion]
547  [cast<BlockArgument>(input).getArgNumber()] |
548  recomputeBasedOn;
549  } else {
550  resultsToKeep[cast<OpResult>(input).getResultNumber()] =
551  resultsToKeep[cast<OpResult>(input).getResultNumber()] |
552  recomputeBasedOn;
553  }
554  }
555  }
556 
557  // Recompute `resultsToKeep` and `argsToKeep` based on
558  // `terminatorOperandsToKeep`.
559  for (Region &region : regionBranchOp->getRegions()) {
560  if (region.empty())
561  continue;
562  Operation *terminator = region.front().getTerminator();
563  for (const RegionSuccessor &successor :
564  getSuccessors(RegionBranchPoint(
565  cast<RegionBranchTerminatorOpInterface>(terminator)))) {
566  Region *successorRegion = successor.getSuccessor();
567  for (auto [opOperand, input] :
568  llvm::zip(getForwardedOpOperands(successor, terminator),
569  successor.getSuccessorInputs())) {
570  bool recomputeBasedOn =
571  terminatorOperandsToKeep[region.back().getTerminator()]
572  [opOperand->getOperandNumber()];
573  bool toRecompute =
574  successorRegion
575  ? argsToKeep[successorRegion]
576  [cast<BlockArgument>(input).getArgNumber()]
577  : resultsToKeep[cast<OpResult>(input).getResultNumber()];
578  if (!toRecompute && recomputeBasedOn)
579  resultsOrArgsToKeepChanged = true;
580  if (successorRegion) {
581  argsToKeep[successorRegion][cast<BlockArgument>(input)
582  .getArgNumber()] =
583  argsToKeep[successorRegion]
584  [cast<BlockArgument>(input).getArgNumber()] |
585  recomputeBasedOn;
586  } else {
587  resultsToKeep[cast<OpResult>(input).getResultNumber()] =
588  resultsToKeep[cast<OpResult>(input).getResultNumber()] |
589  recomputeBasedOn;
590  }
591  }
592  }
593  }
594  };
595 
596  // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
597  // `operandsToKeep`, and `terminatorOperandsToKeep`.
598  auto markValuesToKeep =
599  [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
600  BitVector &operandsToKeep,
601  DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
602  bool resultsOrArgsToKeepChanged = true;
603  // We keep updating and recomputing the values until we reach a point
604  // where they stop changing.
605  while (resultsOrArgsToKeepChanged) {
606  // Update the operands that need to be kept.
607  updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
608  resultsToKeep, argsToKeep);
609 
610  // Update the terminator operands that need to be kept.
611  for (Region &region : regionBranchOp->getRegions()) {
612  if (region.empty())
613  continue;
614  updateOperandsOrTerminatorOperandsToKeep(
615  terminatorOperandsToKeep[region.back().getTerminator()],
616  resultsToKeep, argsToKeep, &region);
617  }
618 
619  // Recompute the results and arguments that need to be kept.
620  recomputeResultsAndArgsToKeep(
621  resultsToKeep, argsToKeep, operandsToKeep,
622  terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
623  }
624  };
625 
626  // Scenario 1. This is the only case where the entire `regionBranchOp`
627  // is removed. It will not happen in any other scenario. Note that in this
628  // case, a non-forwarded operand of `regionBranchOp` could be live/non-live.
629  // It could never be live because of this op but its liveness could have been
630  // attributed to something else.
631  // Do (1') and (2').
632  if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
633  !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
634  cl.operations.push_back(regionBranchOp.getOperation());
635  return;
636  }
637 
638  // Scenario 2.
639  // At this point, we know that every non-forwarded operand of `regionBranchOp`
640  // is live.
641 
642  // Stores the results of `regionBranchOp` that we want to keep.
643  BitVector resultsToKeep;
644  // Stores the mapping from regions of `regionBranchOp` to their arguments that
645  // we want to keep.
647  // Stores the operands of `regionBranchOp` that we want to keep.
648  BitVector operandsToKeep;
649  // Stores the mapping from region terminators in `regionBranchOp` to their
650  // operands that we want to keep.
651  DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
652 
653  // Initializing the above variables...
654 
655  // The live results of `regionBranchOp` definitely need to be kept.
656  markLiveResults(resultsToKeep);
657  // Similarly, the live arguments of the regions in `regionBranchOp` definitely
658  // need to be kept.
659  markLiveArgs(argsToKeep);
660  // The non-forwarded operands of `regionBranchOp` definitely need to be kept.
661  // A live forwarded operand can be removed but no non-forwarded operand can be
662  // removed since it "controls" the flow of data in this control flow op.
663  markNonForwardedOperands(operandsToKeep);
664  // Similarly, the non-forwarded terminator operands of the regions in
665  // `regionBranchOp` definitely need to be kept.
666  markNonForwardedReturnValues(terminatorOperandsToKeep);
667 
668  // Mark the values (results, arguments, operands, and terminator operands)
669  // that we want to keep.
670  markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
671  terminatorOperandsToKeep);
672 
673  // Do (1).
674  cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
675 
676  // Do (2.a) and (2.b).
677  for (Region &region : regionBranchOp->getRegions()) {
678  if (region.empty())
679  continue;
680  BitVector argsToRemove = argsToKeep[&region].flip();
681  cl.blocks.push_back({&region.front(), argsToRemove});
682  collectNonLiveValues(nonLiveSet, region.front().getArguments(),
683  argsToRemove);
684  }
685 
686  // Do (2.c).
687  for (Region &region : regionBranchOp->getRegions()) {
688  if (region.empty())
689  continue;
690  Operation *terminator = region.front().getTerminator();
691  cl.operands.push_back(
692  {terminator, terminatorOperandsToKeep[terminator].flip()});
693  }
694 
695  // Do (3) and (4).
696  BitVector resultsToRemove = resultsToKeep.flip();
697  collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
698  resultsToRemove);
699  cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
700 }
701 
702 /// Steps to process a `BranchOpInterface` operation:
703 /// Iterate through each successor block of `branchOp`.
704 /// (1) For each successor block, gather all operands from all successors.
705 /// (2) Fetch their associated liveness analysis data and collect for future
706 /// removal.
707 /// (3) Identify and collect the dead operands from the successor block
708 /// as well as their corresponding arguments.
709 
710 static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
711  DenseSet<Value> &nonLiveSet,
712  RDVFinalCleanupList &cl) {
713  LDBG() << "Processing branch op: " << *branchOp;
714  unsigned numSuccessors = branchOp->getNumSuccessors();
715 
716  for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
717  Block *successorBlock = branchOp->getSuccessor(succIdx);
718 
719  // Do (1)
720  SuccessorOperands successorOperands =
721  branchOp.getSuccessorOperands(succIdx);
722  SmallVector<Value> operandValues;
723  for (unsigned operandIdx = 0; operandIdx < successorOperands.size();
724  ++operandIdx) {
725  operandValues.push_back(successorOperands[operandIdx]);
726  }
727 
728  // Do (2)
729  BitVector successorNonLive =
730  markLives(operandValues, nonLiveSet, la).flip();
731  collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
732  successorNonLive);
733 
734  // Do (3)
735  cl.blocks.push_back({successorBlock, successorNonLive});
736  cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
737  }
738 }
739 
740 /// Removes dead values collected in RDVFinalCleanupList.
741 /// To be run once when all dead values have been collected.
742 static void cleanUpDeadVals(RDVFinalCleanupList &list) {
743  LDBG() << "Starting cleanup of dead values...";
744 
745  // 1. Operations
746  LDBG() << "Cleaning up " << list.operations.size() << " operations";
747  for (auto &op : list.operations) {
748  LDBG() << "Erasing operation: "
749  << OpWithFlags(op, OpPrintingFlags().skipRegions());
750  op->dropAllUses();
751  op->erase();
752  }
753 
754  // 2. Values
755  LDBG() << "Cleaning up " << list.values.size() << " values";
756  for (auto &v : list.values) {
757  LDBG() << "Dropping all uses of value: " << v;
758  v.dropAllUses();
759  }
760 
761  // 3. Functions
762  LDBG() << "Cleaning up " << list.functions.size() << " functions";
763  // Record which function arguments were erased so we can shrink call-site
764  // argument segments for CallOpInterface operations (e.g. ops using
765  // AttrSizedOperandSegments) in the next phase.
766  DenseMap<Operation *, BitVector> erasedFuncArgs;
767  for (auto &f : list.functions) {
768  LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
769  LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments";
770  LDBG() << " Erasing " << f.nonLiveRets.count()
771  << " non-live return values";
772  // Some functions may not allow erasing arguments or results. These calls
773  // return failure in such cases without modifying the function, so it's okay
774  // to proceed.
775  if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
776  // Record only if we actually erased something.
777  if (f.nonLiveArgs.any())
778  erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs);
779  }
780  (void)f.funcOp.eraseResults(f.nonLiveRets);
781  }
782 
783  // 4. Operands
784  LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
785  for (OperationToCleanup &o : list.operands) {
786  // Handle call-specific cleanup only when we have a cached callee reference.
787  // This avoids expensive symbol lookup and is defensive against future
788  // changes.
789  bool handledAsCall = false;
790  if (o.callee && isa<CallOpInterface>(o.op)) {
791  auto call = cast<CallOpInterface>(o.op);
792  auto it = erasedFuncArgs.find(o.callee);
793  if (it != erasedFuncArgs.end()) {
794  const BitVector &deadArgIdxs = it->second;
795  MutableOperandRange args = call.getArgOperandsMutable();
796  // First, erase the call arguments corresponding to erased callee
797  // args. We iterate backwards to preserve indices.
798  for (unsigned argIdx : llvm::reverse(deadArgIdxs.set_bits()))
799  args.erase(argIdx);
800  // If this operand cleanup entry also has a generic nonLive bitvector,
801  // clear bits for call arguments we already erased above to avoid
802  // double-erasing (which could impact other segments of ops with
803  // AttrSizedOperandSegments).
804  if (o.nonLive.any()) {
805  // Map the argument logical index to the operand number(s) recorded.
806  int operandOffset = call.getArgOperands().getBeginOperandIndex();
807  for (int argIdx : deadArgIdxs.set_bits()) {
808  int operandNumber = operandOffset + argIdx;
809  if (operandNumber < static_cast<int>(o.nonLive.size()))
810  o.nonLive.reset(operandNumber);
811  }
812  }
813  handledAsCall = true;
814  }
815  }
816  // Perform generic operand erasure for:
817  // - Non-call operations
818  // - Call operations without cached callee (where handledAsCall is false)
819  // But skip call operations that were already handled via segment-aware path
820  if (!handledAsCall && o.nonLive.any()) {
821  o.op->eraseOperands(o.nonLive);
822  }
823  }
824 
825  // 5. Results
826  LDBG() << "Cleaning up " << list.results.size() << " result lists";
827  for (auto &r : list.results) {
828  LDBG() << "Erasing " << r.nonLive.count()
829  << " non-live results from operation: "
830  << OpWithFlags(r.op, OpPrintingFlags().skipRegions());
831  dropUsesAndEraseResults(r.op, r.nonLive);
832  }
833 
834  // 6. Blocks
835  LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
836  for (auto &b : list.blocks) {
837  // blocks that are accessed via multiple codepaths processed once
838  if (b.b->getNumArguments() != b.nonLiveArgs.size())
839  continue;
840  LDBG() << "Erasing " << b.nonLiveArgs.count()
841  << " non-live arguments from block: " << b.b;
842  // it iterates backwards because erase invalidates all successor indexes
843  for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
844  if (!b.nonLiveArgs[i])
845  continue;
846  LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i);
847  b.b->getArgument(i).dropAllUses();
848  b.b->eraseArgument(i);
849  }
850  }
851 
852  // 7. Successor Operands
853  LDBG() << "Cleaning up " << list.successorOperands.size()
854  << " successor operand lists";
855  for (auto &op : list.successorOperands) {
856  SuccessorOperands successorOperands =
857  op.branch.getSuccessorOperands(op.successorIndex);
858  // blocks that are accessed via multiple codepaths processed once
859  if (successorOperands.size() != op.nonLiveOperands.size())
860  continue;
861  LDBG() << "Erasing " << op.nonLiveOperands.count()
862  << " non-live successor operands from successor "
863  << op.successorIndex << " of branch: "
864  << OpWithFlags(op.branch, OpPrintingFlags().skipRegions());
865  // it iterates backwards because erase invalidates all successor indexes
866  for (int i = successorOperands.size() - 1; i >= 0; --i) {
867  if (!op.nonLiveOperands[i])
868  continue;
869  LDBG() << " Erasing successor operand " << i << ": "
870  << successorOperands[i];
871  successorOperands.erase(i);
872  }
873  }
874 
875  LDBG() << "Finished cleanup of dead values";
876 }
877 
878 struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
879  void runOnOperation() override;
880 };
881 } // namespace
882 
883 void RemoveDeadValues::runOnOperation() {
884  auto &la = getAnalysis<RunLivenessAnalysis>();
885  Operation *module = getOperation();
886 
887  // Tracks values eligible for erasure - complements liveness analysis to
888  // identify "droppable" values.
889  DenseSet<Value> deadVals;
890 
891  // Maintains a list of Ops, values, branches, etc., slated for cleanup at the
892  // end of this pass.
893  RDVFinalCleanupList finalCleanupList;
894 
895  module->walk([&](Operation *op) {
896  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
897  processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
898  } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
899  processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
900  } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
901  processBranchOp(branchOp, la, deadVals, finalCleanupList);
902  } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
903  // Nothing to do here because this is a terminator op and it should be
904  // honored with respect to its parent
905  } else if (isa<CallOpInterface>(op)) {
906  // Nothing to do because this op is associated with a function op and gets
907  // cleaned when the latter is cleaned.
908  } else {
909  processSimpleOp(op, la, deadVals, finalCleanupList);
910  }
911  });
912 
913  cleanUpDeadVals(finalCleanupList);
914 }
915 
916 std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
917  return std::make_unique<RemoveDeadValues>();
918 }
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
Block represents an ordered list of Operations.
Definition: Block.h:33
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:66
BlockArgListType getArguments()
Definition: Block.h:87
Block * getSuccessor(unsigned i)
Definition: Block.cpp:269
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
This class helps build Operations.
Definition: Builders.h:207
This class represents an operand of an operation.
Definition: Value.h:257
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:773
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition: Operation.h:1111
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:834
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
ValueRange getSuccessorInputs() const
Return the inputs to the successor that are remapped by the exit values of the current region.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
This class models how operands are forwarded to block arguments in control flow.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents a specific symbol use.
Definition: SymbolTable.h:183
This class implements a range of SymbolRef uses.
Definition: SymbolTable.h:203
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::unique_ptr< Pass > createRemoveDeadValuesPass()
Creates an optimization pass to remove dead values.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
This lattice represents, for a given value, whether or not it is "live".
Runs liveness analysis on the IR defined by op.
const Liveness * getLiveness(Value val)