MLIR  21.0.0git
RemoveDeadValues.cpp
Go to the documentation of this file.
1 //===- RemoveDeadValues.cpp - Remove Dead Values --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The goal of this pass is optimization (reducing runtime) by removing
10 // unnecessary instructions. Unlike other passes that rely on local information
11 // gathered from patterns to accomplish optimization, this pass uses a full
12 // analysis of the IR, specifically, liveness analysis, and is thus more
13 // powerful.
14 //
15 // Currently, this pass performs the following optimizations:
16 // (A) Removes function arguments that are not live,
17 // (B) Removes function return values that are not live across all callers of
18 // the function,
19 // (C) Removes unnecessary operands, results, region arguments, and region
20 // terminator operands of region branch ops, and,
21 // (D) Removes simple and region branch ops that have all non-live results and
22 // don't affect memory in any way,
23 //
24 // iff
25 //
26 // the IR doesn't have any non-function symbol ops, non-call symbol user ops and
27 // branch ops.
28 //
29 // Here, a "simple op" refers to an op that isn't a symbol op, symbol-user op,
30 // region branch op, branch op, region branch terminator op, or return-like.
31 //
32 //===----------------------------------------------------------------------===//
33 
36 #include "mlir/IR/Attributes.h"
37 #include "mlir/IR/Builders.h"
39 #include "mlir/IR/Dialect.h"
40 #include "mlir/IR/IRMapping.h"
42 #include "mlir/IR/SymbolTable.h"
43 #include "mlir/IR/Value.h"
44 #include "mlir/IR/ValueRange.h"
45 #include "mlir/IR/Visitors.h"
50 #include "mlir/Pass/Pass.h"
51 #include "mlir/Support/LLVM.h"
53 #include "mlir/Transforms/Passes.h"
54 #include "llvm/ADT/STLExtras.h"
55 #include <cassert>
56 #include <cstddef>
57 #include <memory>
58 #include <optional>
59 #include <vector>
60 
61 namespace mlir {
62 #define GEN_PASS_DEF_REMOVEDEADVALUES
63 #include "mlir/Transforms/Passes.h.inc"
64 } // namespace mlir
65 
66 using namespace mlir;
67 using namespace mlir::dataflow;
68 
69 //===----------------------------------------------------------------------===//
70 // RemoveDeadValues Pass
71 //===----------------------------------------------------------------------===//
72 
73 namespace {
74 
75 // Set of structures below to be filled with operations and arguments to erase.
76 // This is done to separate analysis and tree modification phases,
77 // otherwise analysis is operating on half-deleted tree which is incorrect.
78 
79 struct FunctionToCleanUp {
80  FunctionOpInterface funcOp;
81  BitVector nonLiveArgs;
82  BitVector nonLiveRets;
83 };
84 
85 struct OperationToCleanup {
86  Operation *op;
87  BitVector nonLive;
88 };
89 
90 struct BlockArgsToCleanup {
91  Block *b;
92  BitVector nonLiveArgs;
93 };
94 
95 struct SuccessorOperandsToCleanup {
96  BranchOpInterface branch;
97  unsigned successorIndex;
98  BitVector nonLiveOperands;
99 };
100 
101 struct RDVFinalCleanupList {
102  SmallVector<Operation *> operations;
103  SmallVector<Value> values;
108  SmallVector<SuccessorOperandsToCleanup> successorOperands;
109 };
110 
111 // Some helper functions...
112 
113 /// Return true iff at least one value in `values` is live, given the liveness
114 /// information in `la`.
115 static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
116  RunLivenessAnalysis &la) {
117  for (Value value : values) {
118  if (nonLiveSet.contains(value))
119  continue;
120 
121  const Liveness *liveness = la.getLiveness(value);
122  if (!liveness || liveness->isLive)
123  return true;
124  }
125  return false;
126 }
127 
128 /// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
129 /// i-th value in `values` is live, given the liveness information in `la`.
130 static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
131  RunLivenessAnalysis &la) {
132  BitVector lives(values.size(), true);
133 
134  for (auto [index, value] : llvm::enumerate(values)) {
135  if (nonLiveSet.contains(value)) {
136  lives.reset(index);
137  continue;
138  }
139 
140  const Liveness *liveness = la.getLiveness(value);
141  // It is important to note that when `liveness` is null, we can't tell if
142  // `value` is live or not. So, the safe option is to consider it live. Also,
143  // the execution of this pass might create new SSA values when erasing some
144  // of the results of an op and we know that these new values are live
145  // (because they weren't erased) and also their liveness is null because
146  // liveness analysis ran before their creation.
147  if (liveness && !liveness->isLive)
148  lives.reset(index);
149  }
150 
151  return lives;
152 }
153 
154 /// Collects values marked as "non-live" in the provided range and inserts them
155 /// into the nonLiveSet. A value is considered "non-live" if the corresponding
156 /// index in the `nonLive` bit vector is set.
157 static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range,
158  const BitVector &nonLive) {
159  for (auto [index, result] : llvm::enumerate(range)) {
160  if (!nonLive[index])
161  continue;
162  nonLiveSet.insert(result);
163  }
164 }
165 
166 /// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
167 /// is 1.
168 static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
169  assert(op->getNumResults() == toErase.size() &&
170  "expected the number of results in `op` and the size of `toErase` to "
171  "be the same");
172 
173  std::vector<Type> newResultTypes;
174  for (OpResult result : op->getResults())
175  if (!toErase[result.getResultNumber()])
176  newResultTypes.push_back(result.getType());
177  OpBuilder builder(op);
178  builder.setInsertionPointAfter(op);
179  OperationState state(op->getLoc(), op->getName().getStringRef(),
180  op->getOperands(), newResultTypes, op->getAttrs());
181  for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
182  state.addRegion();
183  Operation *newOp = builder.create(state);
184  for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
185  Region &newRegion = newOp->getRegion(index);
186  // Move all blocks of `region` into `newRegion`.
187  Block *temp = new Block();
188  newRegion.push_back(temp);
189  while (!region.empty())
190  region.front().moveBefore(temp);
191  temp->erase();
192  }
193 
194  unsigned indexOfNextNewCallOpResultToReplace = 0;
195  for (auto [index, result] : llvm::enumerate(op->getResults())) {
196  assert(result && "expected result to be non-null");
197  if (toErase[index]) {
198  result.dropAllUses();
199  } else {
200  result.replaceAllUsesWith(
201  newOp->getResult(indexOfNextNewCallOpResultToReplace++));
202  }
203  }
204  op->erase();
205 }
206 
207 /// Convert a list of `Operand`s to a list of `OpOperand`s.
209  OpOperand *values = operands.getBase();
210  SmallVector<OpOperand *> opOperands;
211  for (unsigned i = 0, e = operands.size(); i < e; i++)
212  opOperands.push_back(&values[i]);
213  return opOperands;
214 }
215 
216 /// Process a simple operation `op` using the liveness analysis `la`.
217 /// If the operation has no memory effects and none of its results are live:
218 /// 1. Add the operation to a list for future removal, and
219 /// 2. Mark all its results as non-live values
220 ///
221 /// The operation `op` is assumed to be simple. A simple operation is one that
222 /// is NOT:
223 /// - Function-like
224 /// - Call-like
225 /// - A region branch operation
226 /// - A branch operation
227 /// - A region branch terminator
228 /// - Return-like
229 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
230  DenseSet<Value> &nonLiveSet,
231  RDVFinalCleanupList &cl) {
232  if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la))
233  return;
234 
235  cl.operations.push_back(op);
236  collectNonLiveValues(nonLiveSet, op->getResults(),
237  BitVector(op->getNumResults(), true));
238 }
239 
240 /// Process a function-like operation `funcOp` using the liveness analysis `la`
241 /// and the IR in `module`. If it is not public or external:
242 /// (1) Adding its non-live arguments to a list for future removal.
243 /// (2) Marking their corresponding operands in its callers for removal.
244 /// (3) Identifying and enqueueing unnecessary terminator operands
245 /// (return values that are non-live across all callers) for removal.
246 /// (4) Enqueueing the non-live arguments and return values for removal.
247 /// (5) Collecting the uses of these return values in its callers for future
248 /// removal.
249 /// (6) Marking all its results as non-live values.
250 static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
251  RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
252  RDVFinalCleanupList &cl) {
253  if (funcOp.isPublic() || funcOp.isExternal())
254  return;
255 
256  // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
257  SmallVector<Value> arguments(funcOp.getArguments());
258  BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
259  nonLiveArgs = nonLiveArgs.flip();
260 
261  // Do (1).
262  for (auto [index, arg] : llvm::enumerate(arguments))
263  if (arg && nonLiveArgs[index]) {
264  cl.values.push_back(arg);
265  nonLiveSet.insert(arg);
266  }
267 
268  // Do (2).
269  SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
270  for (SymbolTable::SymbolUse use : uses) {
271  Operation *callOp = use.getUser();
272  assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
273  // The number of operands in the call op may not match the number of
274  // arguments in the func op.
275  BitVector nonLiveCallOperands(callOp->getNumOperands(), false);
276  SmallVector<OpOperand *> callOpOperands =
277  operandsToOpOperands(cast<CallOpInterface>(callOp).getArgOperands());
278  for (int index : nonLiveArgs.set_bits())
279  nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
280  cl.operands.push_back({callOp, nonLiveCallOperands});
281  }
282 
283  // Do (3).
284  // Get the list of unnecessary terminator operands (return values that are
285  // non-live across all callers) in `nonLiveRets`. There is a very important
286  // subtlety here. Unnecessary terminator operands are NOT the operands of the
287  // terminator that are non-live. Instead, these are the return values of the
288  // callers such that a given return value is non-live across all callers. Such
289  // corresponding operands in the terminator could be live. An example to
290  // demonstrate this:
291  // func.func private @f(%arg0: memref<i32>) -> (i32, i32) {
292  // %c0_i32 = arith.constant 0 : i32
293  // %0 = arith.addi %c0_i32, %c0_i32 : i32
294  // memref.store %0, %arg0[] : memref<i32>
295  // return %c0_i32, %0 : i32, i32
296  // }
297  // func.func @main(%arg0: i32, %arg1: memref<i32>) -> (i32) {
298  // %1:2 = call @f(%arg1) : (memref<i32>) -> i32
299  // return %1#0 : i32
300  // }
301  // Here, we can see that %1#1 is never used. It is non-live. Thus, @f doesn't
302  // need to return %0. But, %0 is live. And, still, we want to stop it from
303  // being returned, in order to optimize our IR. So, this demonstrates how we
304  // can make our optimization strong by even removing a live return value (%0),
305  // since it forwards only to non-live value(s) (%1#1).
306  Operation *lastReturnOp = funcOp.back().getTerminator();
307  size_t numReturns = lastReturnOp->getNumOperands();
308  BitVector nonLiveRets(numReturns, true);
309  for (SymbolTable::SymbolUse use : uses) {
310  Operation *callOp = use.getUser();
311  assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
312  BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
313  nonLiveRets &= liveCallRets.flip();
314  }
315 
316  // Note that in the absence of control flow ops forcing the control to go from
317  // the entry (first) block to the other blocks, the control never reaches any
318  // block other than the entry block, because every block has a terminator.
319  for (Block &block : funcOp.getBlocks()) {
320  Operation *returnOp = block.getTerminator();
321  if (returnOp && returnOp->getNumOperands() == numReturns)
322  cl.operands.push_back({returnOp, nonLiveRets});
323  }
324 
325  // Do (4).
326  cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
327 
328  // Do (5) and (6).
329  for (SymbolTable::SymbolUse use : uses) {
330  Operation *callOp = use.getUser();
331  assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
332  cl.results.push_back({callOp, nonLiveRets});
333  collectNonLiveValues(nonLiveSet, callOp->getResults(), nonLiveRets);
334  }
335 }
336 
337 /// Process a region branch operation `regionBranchOp` using the liveness
338 /// information in `la`. The processing involves two scenarios:
339 ///
340 /// Scenario 1: If the operation has no memory effects and none of its results
341 /// are live:
342 /// (1') Enqueue all its uses for deletion.
343 /// (2') Enqueue the branch itself for deletion.
344 ///
345 /// Scenario 2: Otherwise:
346 /// (1) Collect its unnecessary operands (operands forwarded to unnecessary
347 /// results or arguments).
348 /// (2) Process each of its regions.
349 /// (3) Collect the uses of its unnecessary results (results forwarded from
350 /// unnecessary operands
351 /// or terminator operands).
352 /// (4) Add these results to the deletion list.
353 ///
354 /// Processing a region includes:
355 /// (a) Collecting the uses of its unnecessary arguments (arguments forwarded
356 /// from unnecessary operands
357 /// or terminator operands).
358 /// (b) Collecting these unnecessary arguments.
359 /// (c) Collecting its unnecessary terminator operands (terminator operands
360 /// forwarded to unnecessary results
361 /// or arguments).
362 ///
363 /// Value Flow Note: In this operation, values flow as follows:
364 /// - From operands and terminator operands (successor operands)
365 /// - To arguments and results (successor inputs).
366 static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
368  DenseSet<Value> &nonLiveSet,
369  RDVFinalCleanupList &cl) {
370  // Mark live results of `regionBranchOp` in `liveResults`.
371  auto markLiveResults = [&](BitVector &liveResults) {
372  liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
373  };
374 
375  // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
376  auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
377  for (Region &region : regionBranchOp->getRegions()) {
378  if (region.empty())
379  continue;
380  SmallVector<Value> arguments(region.front().getArguments());
381  BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
382  liveArgs[&region] = regionLiveArgs;
383  }
384  };
385 
386  // Return the successors of `region` if the latter is not null. Else return
387  // the successors of `regionBranchOp`.
388  auto getSuccessors = [&](Region *region = nullptr) {
389  auto point = region ? region : RegionBranchPoint::parent();
390  SmallVector<Attribute> operandAttributes(regionBranchOp->getNumOperands(),
391  nullptr);
392  SmallVector<RegionSuccessor> successors;
393  regionBranchOp.getSuccessorRegions(point, successors);
394  return successors;
395  };
396 
397  // Return the operands of `terminator` that are forwarded to `successor` if
398  // the former is not null. Else return the operands of `regionBranchOp`
399  // forwarded to `successor`.
400  auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
401  Operation *terminator = nullptr) {
402  OperandRange operands =
403  terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
404  .getSuccessorOperands(successor)
405  : regionBranchOp.getEntrySuccessorOperands(successor);
406  SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
407  return opOperands;
408  };
409 
410  // Mark the non-forwarded operands of `regionBranchOp` in
411  // `nonForwardedOperands`.
412  auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
413  nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
414  for (const RegionSuccessor &successor : getSuccessors()) {
415  for (OpOperand *opOperand : getForwardedOpOperands(successor))
416  nonForwardedOperands.reset(opOperand->getOperandNumber());
417  }
418  };
419 
420  // Mark the non-forwarded terminator operands of the various regions of
421  // `regionBranchOp` in `nonForwardedRets`.
422  auto markNonForwardedReturnValues =
423  [&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
424  for (Region &region : regionBranchOp->getRegions()) {
425  if (region.empty())
426  continue;
427  Operation *terminator = region.front().getTerminator();
428  nonForwardedRets[terminator] =
429  BitVector(terminator->getNumOperands(), true);
430  for (const RegionSuccessor &successor : getSuccessors(&region)) {
431  for (OpOperand *opOperand :
432  getForwardedOpOperands(successor, terminator))
433  nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
434  }
435  }
436  };
437 
438  // Update `valuesToKeep` (which is expected to correspond to operands or
439  // terminator operands) based on `resultsToKeep` and `argsToKeep`, given
440  // `region`. When `valuesToKeep` correspond to operands, `region` is null.
441  // Else, `region` is the parent region of the terminator.
442  auto updateOperandsOrTerminatorOperandsToKeep =
443  [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
444  DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
445  Operation *terminator =
446  region ? region->front().getTerminator() : nullptr;
447 
448  for (const RegionSuccessor &successor : getSuccessors(region)) {
449  Region *successorRegion = successor.getSuccessor();
450  for (auto [opOperand, input] :
451  llvm::zip(getForwardedOpOperands(successor, terminator),
452  successor.getSuccessorInputs())) {
453  size_t operandNum = opOperand->getOperandNumber();
454  bool updateBasedOn =
455  successorRegion
456  ? argsToKeep[successorRegion]
457  [cast<BlockArgument>(input).getArgNumber()]
458  : resultsToKeep[cast<OpResult>(input).getResultNumber()];
459  valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
460  }
461  }
462  };
463 
464  // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
465  // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a
466  // value is modified, else, false.
467  auto recomputeResultsAndArgsToKeep =
468  [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
469  BitVector &operandsToKeep,
470  DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
471  bool &resultsOrArgsToKeepChanged) {
472  resultsOrArgsToKeepChanged = false;
473 
474  // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
475  for (const RegionSuccessor &successor : getSuccessors()) {
476  Region *successorRegion = successor.getSuccessor();
477  for (auto [opOperand, input] :
478  llvm::zip(getForwardedOpOperands(successor),
479  successor.getSuccessorInputs())) {
480  bool recomputeBasedOn =
481  operandsToKeep[opOperand->getOperandNumber()];
482  bool toRecompute =
483  successorRegion
484  ? argsToKeep[successorRegion]
485  [cast<BlockArgument>(input).getArgNumber()]
486  : resultsToKeep[cast<OpResult>(input).getResultNumber()];
487  if (!toRecompute && recomputeBasedOn)
488  resultsOrArgsToKeepChanged = true;
489  if (successorRegion) {
490  argsToKeep[successorRegion][cast<BlockArgument>(input)
491  .getArgNumber()] =
492  argsToKeep[successorRegion]
493  [cast<BlockArgument>(input).getArgNumber()] |
494  recomputeBasedOn;
495  } else {
496  resultsToKeep[cast<OpResult>(input).getResultNumber()] =
497  resultsToKeep[cast<OpResult>(input).getResultNumber()] |
498  recomputeBasedOn;
499  }
500  }
501  }
502 
503  // Recompute `resultsToKeep` and `argsToKeep` based on
504  // `terminatorOperandsToKeep`.
505  for (Region &region : regionBranchOp->getRegions()) {
506  if (region.empty())
507  continue;
508  Operation *terminator = region.front().getTerminator();
509  for (const RegionSuccessor &successor : getSuccessors(&region)) {
510  Region *successorRegion = successor.getSuccessor();
511  for (auto [opOperand, input] :
512  llvm::zip(getForwardedOpOperands(successor, terminator),
513  successor.getSuccessorInputs())) {
514  bool recomputeBasedOn =
515  terminatorOperandsToKeep[region.back().getTerminator()]
516  [opOperand->getOperandNumber()];
517  bool toRecompute =
518  successorRegion
519  ? argsToKeep[successorRegion]
520  [cast<BlockArgument>(input).getArgNumber()]
521  : resultsToKeep[cast<OpResult>(input).getResultNumber()];
522  if (!toRecompute && recomputeBasedOn)
523  resultsOrArgsToKeepChanged = true;
524  if (successorRegion) {
525  argsToKeep[successorRegion][cast<BlockArgument>(input)
526  .getArgNumber()] =
527  argsToKeep[successorRegion]
528  [cast<BlockArgument>(input).getArgNumber()] |
529  recomputeBasedOn;
530  } else {
531  resultsToKeep[cast<OpResult>(input).getResultNumber()] =
532  resultsToKeep[cast<OpResult>(input).getResultNumber()] |
533  recomputeBasedOn;
534  }
535  }
536  }
537  }
538  };
539 
540  // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
541  // `operandsToKeep`, and `terminatorOperandsToKeep`.
542  auto markValuesToKeep =
543  [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
544  BitVector &operandsToKeep,
545  DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
546  bool resultsOrArgsToKeepChanged = true;
547  // We keep updating and recomputing the values until we reach a point
548  // where they stop changing.
549  while (resultsOrArgsToKeepChanged) {
550  // Update the operands that need to be kept.
551  updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
552  resultsToKeep, argsToKeep);
553 
554  // Update the terminator operands that need to be kept.
555  for (Region &region : regionBranchOp->getRegions()) {
556  if (region.empty())
557  continue;
558  updateOperandsOrTerminatorOperandsToKeep(
559  terminatorOperandsToKeep[region.back().getTerminator()],
560  resultsToKeep, argsToKeep, &region);
561  }
562 
563  // Recompute the results and arguments that need to be kept.
564  recomputeResultsAndArgsToKeep(
565  resultsToKeep, argsToKeep, operandsToKeep,
566  terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
567  }
568  };
569 
570  // Scenario 1. This is the only case where the entire `regionBranchOp`
571  // is removed. It will not happen in any other scenario. Note that in this
572  // case, a non-forwarded operand of `regionBranchOp` could be live/non-live.
573  // It could never be live because of this op but its liveness could have been
574  // attributed to something else.
575  // Do (1') and (2').
576  if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
577  !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
578  cl.operations.push_back(regionBranchOp.getOperation());
579  return;
580  }
581 
582  // Scenario 2.
583  // At this point, we know that every non-forwarded operand of `regionBranchOp`
584  // is live.
585 
586  // Stores the results of `regionBranchOp` that we want to keep.
587  BitVector resultsToKeep;
588  // Stores the mapping from regions of `regionBranchOp` to their arguments that
589  // we want to keep.
591  // Stores the operands of `regionBranchOp` that we want to keep.
592  BitVector operandsToKeep;
593  // Stores the mapping from region terminators in `regionBranchOp` to their
594  // operands that we want to keep.
595  DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
596 
597  // Initializing the above variables...
598 
599  // The live results of `regionBranchOp` definitely need to be kept.
600  markLiveResults(resultsToKeep);
601  // Similarly, the live arguments of the regions in `regionBranchOp` definitely
602  // need to be kept.
603  markLiveArgs(argsToKeep);
604  // The non-forwarded operands of `regionBranchOp` definitely need to be kept.
605  // A live forwarded operand can be removed but no non-forwarded operand can be
606  // removed since it "controls" the flow of data in this control flow op.
607  markNonForwardedOperands(operandsToKeep);
608  // Similarly, the non-forwarded terminator operands of the regions in
609  // `regionBranchOp` definitely need to be kept.
610  markNonForwardedReturnValues(terminatorOperandsToKeep);
611 
612  // Mark the values (results, arguments, operands, and terminator operands)
613  // that we want to keep.
614  markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
615  terminatorOperandsToKeep);
616 
617  // Do (1).
618  cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
619 
620  // Do (2.a) and (2.b).
621  for (Region &region : regionBranchOp->getRegions()) {
622  if (region.empty())
623  continue;
624  BitVector argsToRemove = argsToKeep[&region].flip();
625  cl.blocks.push_back({&region.front(), argsToRemove});
626  collectNonLiveValues(nonLiveSet, region.front().getArguments(),
627  argsToRemove);
628  }
629 
630  // Do (2.c).
631  for (Region &region : regionBranchOp->getRegions()) {
632  if (region.empty())
633  continue;
634  Operation *terminator = region.front().getTerminator();
635  cl.operands.push_back(
636  {terminator, terminatorOperandsToKeep[terminator].flip()});
637  }
638 
639  // Do (3) and (4).
640  BitVector resultsToRemove = resultsToKeep.flip();
641  collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
642  resultsToRemove);
643  cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
644 }
645 
646 /// Steps to process a `BranchOpInterface` operation:
647 /// Iterate through each successor block of `branchOp`.
648 /// (1) For each successor block, gather all operands from all successors.
649 /// (2) Fetch their associated liveness analysis data and collect for future
650 /// removal.
651 /// (3) Identify and collect the dead operands from the successor block
652 /// as well as their corresponding arguments.
653 
654 static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
655  DenseSet<Value> &nonLiveSet,
656  RDVFinalCleanupList &cl) {
657  unsigned numSuccessors = branchOp->getNumSuccessors();
658 
659  for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
660  Block *successorBlock = branchOp->getSuccessor(succIdx);
661 
662  // Do (1)
663  SuccessorOperands successorOperands =
664  branchOp.getSuccessorOperands(succIdx);
665  SmallVector<Value> operandValues;
666  for (unsigned operandIdx = 0; operandIdx < successorOperands.size();
667  ++operandIdx) {
668  operandValues.push_back(successorOperands[operandIdx]);
669  }
670 
671  // Do (2)
672  BitVector successorNonLive =
673  markLives(operandValues, nonLiveSet, la).flip();
674  collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
675  successorNonLive);
676 
677  // Do (3)
678  cl.blocks.push_back({successorBlock, successorNonLive});
679  cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
680  }
681 }
682 
683 /// Removes dead values collected in RDVFinalCleanupList.
684 /// To be run once when all dead values have been collected.
685 static void cleanUpDeadVals(RDVFinalCleanupList &list) {
686  // 1. Operations
687  for (auto &op : list.operations) {
688  op->dropAllUses();
689  op->erase();
690  }
691 
692  // 2. Values
693  for (auto &v : list.values) {
694  v.dropAllUses();
695  }
696 
697  // 3. Functions
698  for (auto &f : list.functions) {
699  f.funcOp.eraseArguments(f.nonLiveArgs);
700  f.funcOp.eraseResults(f.nonLiveRets);
701  }
702 
703  // 4. Operands
704  for (auto &o : list.operands) {
705  o.op->eraseOperands(o.nonLive);
706  }
707 
708  // 5. Results
709  for (auto &r : list.results) {
710  dropUsesAndEraseResults(r.op, r.nonLive);
711  }
712 
713  // 6. Blocks
714  for (auto &b : list.blocks) {
715  // blocks that are accessed via multiple codepaths processed once
716  if (b.b->getNumArguments() != b.nonLiveArgs.size())
717  continue;
718  // it iterates backwards because erase invalidates all successor indexes
719  for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
720  if (!b.nonLiveArgs[i])
721  continue;
722  b.b->getArgument(i).dropAllUses();
723  b.b->eraseArgument(i);
724  }
725  }
726 
727  // 7. Successor Operands
728  for (auto &op : list.successorOperands) {
729  SuccessorOperands successorOperands =
730  op.branch.getSuccessorOperands(op.successorIndex);
731  // blocks that are accessed via multiple codepaths processed once
732  if (successorOperands.size() != op.nonLiveOperands.size())
733  continue;
734  // it iterates backwards because erase invalidates all successor indexes
735  for (int i = successorOperands.size() - 1; i >= 0; --i) {
736  if (!op.nonLiveOperands[i])
737  continue;
738  successorOperands.erase(i);
739  }
740  }
741 }
742 
743 struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
744  void runOnOperation() override;
745 };
746 } // namespace
747 
748 void RemoveDeadValues::runOnOperation() {
749  auto &la = getAnalysis<RunLivenessAnalysis>();
750  Operation *module = getOperation();
751 
752  // Tracks values eligible for erasure - complements liveness analysis to
753  // identify "droppable" values.
754  DenseSet<Value> deadVals;
755 
756  // Maintains a list of Ops, values, branches, etc., slated for cleanup at the
757  // end of this pass.
758  RDVFinalCleanupList finalCleanupList;
759 
760  module->walk([&](Operation *op) {
761  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
762  processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
763  } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
764  processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
765  } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
766  processBranchOp(branchOp, la, deadVals, finalCleanupList);
767  } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
768  // Nothing to do here because this is a terminator op and it should be
769  // honored with respect to its parent
770  } else if (isa<CallOpInterface>(op)) {
771  // Nothing to do because this op is associated with a function op and gets
772  // cleaned when the latter is cleaned.
773  } else {
774  processSimpleOp(op, la, deadVals, finalCleanupList);
775  }
776  });
777 
778  cleanUpDeadVals(finalCleanupList);
779 }
780 
781 std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
782  return std::make_unique<RemoveDeadValues>();
783 }
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
Block represents an ordered list of Operations.
Definition: Block.h:33
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:68
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
Block * getSuccessor(unsigned i)
Definition: Block.cpp:261
This class helps build Operations.
Definition: Builders.h:205
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:765
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
void dropAllUses()
Drop all uses of results of this operation.
Definition: Operation.h:835
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
ValueRange getSuccessorInputs() const
Return the inputs to the successor that are remapped by the exit values of the current region.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
This class models how operands are forwarded to block arguments in control flow.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents a specific symbol use.
Definition: SymbolTable.h:183
This class implements a range of SymbolRef uses.
Definition: SymbolTable.h:203
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::unique_ptr< Pass > createRemoveDeadValuesPass()
Creates an optimization pass to remove dead values.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
This lattice represents, for a given value, whether or not it is "live".
Runs liveness analysis on the IR defined by op.
const Liveness * getLiveness(Value val)