MLIR 22.0.0git
RemoveDeadValues.cpp
Go to the documentation of this file.
1//===- RemoveDeadValues.cpp - Remove Dead Values --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The goal of this pass is optimization (reducing runtime) by removing
10// unnecessary instructions. Unlike other passes that rely on local information
11// gathered from patterns to accomplish optimization, this pass uses a full
12// analysis of the IR, specifically, liveness analysis, and is thus more
13// powerful.
14//
15// Currently, this pass performs the following optimizations:
16// (A) Removes function arguments that are not live,
17// (B) Removes function return values that are not live across all callers of
18// the function,
// (C) Removes unnecessary operands, results, region arguments, and region
20// terminator operands of region branch ops, and,
21// (D) Removes simple and region branch ops that have all non-live results and
22// don't affect memory in any way,
23//
24// iff
25//
26// the IR doesn't have any non-function symbol ops, non-call symbol user ops and
27// branch ops.
28//
29// Here, a "simple op" refers to an op that isn't a symbol op, symbol-user op,
30// region branch op, branch op, region branch terminator op, or return-like.
31//
32//===----------------------------------------------------------------------===//
33
37#include "mlir/IR/Builders.h"
39#include "mlir/IR/Dialect.h"
40#include "mlir/IR/Operation.h"
42#include "mlir/IR/SymbolTable.h"
43#include "mlir/IR/Value.h"
44#include "mlir/IR/ValueRange.h"
45#include "mlir/IR/Visitors.h"
50#include "mlir/Pass/Pass.h"
51#include "mlir/Support/LLVM.h"
54#include "llvm/ADT/STLExtras.h"
55#include "llvm/Support/Debug.h"
56#include "llvm/Support/DebugLog.h"
57#include <cassert>
58#include <cstddef>
59#include <memory>
60#include <optional>
61#include <vector>
62
63#define DEBUG_TYPE "remove-dead-values"
64
65namespace mlir {
66#define GEN_PASS_DEF_REMOVEDEADVALUES
67#include "mlir/Transforms/Passes.h.inc"
68} // namespace mlir
69
70using namespace mlir;
71using namespace mlir::dataflow;
72
73//===----------------------------------------------------------------------===//
74// RemoveDeadValues Pass
75//===----------------------------------------------------------------------===//
76
77namespace {
78
79// Set of structures below to be filled with operations and arguments to erase.
80// This is done to separate analysis and tree modification phases,
81// otherwise analysis is operating on half-deleted tree which is incorrect.
82
83struct FunctionToCleanUp {
84 FunctionOpInterface funcOp;
85 BitVector nonLiveArgs;
86 BitVector nonLiveRets;
87};
88
89struct ResultsToCleanup {
90 Operation *op;
91 BitVector nonLive;
92};
93
94struct OperandsToCleanup {
95 Operation *op;
96 BitVector nonLive;
97 Operation *callee =
98 nullptr; // Optional: For CallOpInterface ops, stores the callee function
99};
100
101struct BlockArgsToCleanup {
102 Block *b;
103 BitVector nonLiveArgs;
104};
105
106struct SuccessorOperandsToCleanup {
107 BranchOpInterface branch;
108 unsigned successorIndex;
109 BitVector nonLiveOperands;
110};
111
112struct RDVFinalCleanupList {
113 SmallVector<Operation *> operations;
114 SmallVector<FunctionToCleanUp> functions;
115 SmallVector<OperandsToCleanup> operands;
116 SmallVector<ResultsToCleanup> results;
117 SmallVector<BlockArgsToCleanup> blocks;
118 SmallVector<SuccessorOperandsToCleanup> successorOperands;
119};
120
121// Some helper functions...
122
123/// Return true iff at least one value in `values` is live, given the liveness
124/// information in `la`.
125static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
127 for (Value value : values) {
128 if (nonLiveSet.contains(value)) {
129 LDBG() << "Value " << value << " is already marked non-live (dead)";
130 continue;
131 }
132
133 const Liveness *liveness = la.getLiveness(value);
134 if (!liveness) {
135 LDBG() << "Value " << value
136 << " has no liveness info, conservatively considered live";
137 return true;
138 }
139 if (liveness->isLive) {
140 LDBG() << "Value " << value << " is live according to liveness analysis";
141 return true;
142 } else {
143 LDBG() << "Value " << value << " is dead according to liveness analysis";
144 }
145 }
146 return false;
147}
148
149/// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
150/// i-th value in `values` is live, given the liveness information in `la`.
151static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
153 BitVector lives(values.size(), true);
154
155 for (auto [index, value] : llvm::enumerate(values)) {
156 if (nonLiveSet.contains(value)) {
157 lives.reset(index);
158 LDBG() << "Value " << value
159 << " is already marked non-live (dead) at index " << index;
160 continue;
161 }
162
163 const Liveness *liveness = la.getLiveness(value);
164 // It is important to note that when `liveness` is null, we can't tell if
165 // `value` is live or not. So, the safe option is to consider it live. Also,
166 // the execution of this pass might create new SSA values when erasing some
167 // of the results of an op and we know that these new values are live
168 // (because they weren't erased) and also their liveness is null because
169 // liveness analysis ran before their creation.
170 if (!liveness) {
171 LDBG() << "Value " << value << " at index " << index
172 << " has no liveness info, conservatively considered live";
173 continue;
174 }
175 if (!liveness->isLive) {
176 lives.reset(index);
177 LDBG() << "Value " << value << " at index " << index
178 << " is dead according to liveness analysis";
179 } else {
180 LDBG() << "Value " << value << " at index " << index
181 << " is live according to liveness analysis";
182 }
183 }
184
185 return lives;
186}
187
188/// Collects values marked as "non-live" in the provided range and inserts them
189/// into the nonLiveSet. A value is considered "non-live" if the corresponding
190/// index in the `nonLive` bit vector is set.
191static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range,
192 const BitVector &nonLive) {
193 for (auto [index, result] : llvm::enumerate(range)) {
194 if (!nonLive[index])
195 continue;
196 nonLiveSet.insert(result);
197 LDBG() << "Marking value " << result << " as non-live (dead) at index "
198 << index;
199 }
200}
201
202/// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
203/// is 1.
204static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
205 assert(op->getNumResults() == toErase.size() &&
206 "expected the number of results in `op` and the size of `toErase` to "
207 "be the same");
208 for (auto idx : toErase.set_bits())
209 op->getResult(idx).dropAllUses();
210 IRRewriter rewriter(op);
211 rewriter.eraseOpResults(op, toErase);
212}
213
214/// Convert a list of `Operand`s to a list of `OpOperand`s.
216 OpOperand *values = operands.getBase();
217 SmallVector<OpOperand *> opOperands;
218 for (unsigned i = 0, e = operands.size(); i < e; i++)
219 opOperands.push_back(&values[i]);
220 return opOperands;
221}
222
223/// Process a simple operation `op` using the liveness analysis `la`.
224/// If the operation has no memory effects and none of its results are live:
225/// 1. Add the operation to a list for future removal, and
226/// 2. Mark all its results as non-live values
227///
228/// The operation `op` is assumed to be simple. A simple operation is one that
229/// is NOT:
230/// - Function-like
231/// - Call-like
232/// - A region branch operation
233/// - A branch operation
234/// - A region branch terminator
235/// - Return-like
236static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
237 DenseSet<Value> &nonLiveSet,
238 RDVFinalCleanupList &cl) {
239 // Operations that have dead operands can be erased regardless of their
240 // side effects. The liveness analysis would not have marked an SSA value as
241 // "dead" if it had a side-effecting user that is reachable.
242 bool hasDeadOperand =
243 markLives(op->getOperands(), nonLiveSet, la).flip().any();
244 if (hasDeadOperand) {
245 LDBG() << "Simple op has dead operands, so the op must be dead: "
246 << OpWithFlags(op,
247 OpPrintingFlags().skipRegions().printGenericOpForm());
248 assert(!hasLive(op->getResults(), nonLiveSet, la) &&
249 "expected the op to have no live results");
250 cl.operations.push_back(op);
251 collectNonLiveValues(nonLiveSet, op->getResults(),
252 BitVector(op->getNumResults(), true));
253 return;
254 }
255
256 if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
257 LDBG() << "Simple op is not memory effect free or has live results, "
258 "preserving it: "
259 << OpWithFlags(op,
260 OpPrintingFlags().skipRegions().printGenericOpForm());
261 return;
262 }
263
264 LDBG()
265 << "Simple op has all dead results and is memory effect free, scheduling "
266 "for removal: "
267 << OpWithFlags(op, OpPrintingFlags().skipRegions().printGenericOpForm());
268 cl.operations.push_back(op);
269 collectNonLiveValues(nonLiveSet, op->getResults(),
270 BitVector(op->getNumResults(), true));
271}
272
273/// Process a function-like operation `funcOp` using the liveness analysis `la`
274/// and the IR in `module`. If it is not public or external:
275/// (1) Adding its non-live arguments to a list for future removal.
276/// (2) Marking their corresponding operands in its callers for removal.
277/// (3) Identifying and enqueueing unnecessary terminator operands
278/// (return values that are non-live across all callers) for removal.
279/// (4) Enqueueing the non-live arguments and return values for removal.
280/// (5) Collecting the uses of these return values in its callers for future
281/// removal.
282/// (6) Marking all its results as non-live values.
283static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
284 RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
285 RDVFinalCleanupList &cl) {
286 LDBG() << "Processing function op: "
287 << OpWithFlags(funcOp,
288 OpPrintingFlags().skipRegions().printGenericOpForm());
289 if (funcOp.isPublic() || funcOp.isExternal()) {
290 LDBG() << "Function is public or external, skipping: "
291 << funcOp.getOperation()->getName();
292 return;
293 }
294
295 // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
296 SmallVector<Value> arguments(funcOp.getArguments());
297 BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
298 nonLiveArgs = nonLiveArgs.flip();
299
300 // Do (1).
301 for (auto [index, arg] : llvm::enumerate(arguments))
302 if (arg && nonLiveArgs[index])
303 nonLiveSet.insert(arg);
304
305 // Do (2). (Skip creating generic operand cleanup entries for call ops.
306 // Call arguments will be removed in the call-site specific segment-aware
307 // cleanup, avoiding generic eraseOperands bitvector mechanics.)
308 SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
309 for (SymbolTable::SymbolUse use : uses) {
310 Operation *callOp = use.getUser();
311 assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
312 // Push an empty operand cleanup entry so that call-site specific logic in
313 // cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is
314 // intentionally all false to avoid generic erasure.
315 // Store the funcOp as the callee to avoid expensive symbol lookup later.
316 cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false),
317 funcOp.getOperation()});
318 }
319
320 // Do (3).
321 // Get the list of unnecessary terminator operands (return values that are
322 // non-live across all callers) in `nonLiveRets`. There is a very important
323 // subtlety here. Unnecessary terminator operands are NOT the operands of the
324 // terminator that are non-live. Instead, these are the return values of the
325 // callers such that a given return value is non-live across all callers. Such
326 // corresponding operands in the terminator could be live. An example to
327 // demonstrate this:
328 // func.func private @f(%arg0: memref<i32>) -> (i32, i32) {
329 // %c0_i32 = arith.constant 0 : i32
330 // %0 = arith.addi %c0_i32, %c0_i32 : i32
331 // memref.store %0, %arg0[] : memref<i32>
332 // return %c0_i32, %0 : i32, i32
333 // }
334 // func.func @main(%arg0: i32, %arg1: memref<i32>) -> (i32) {
335 // %1:2 = call @f(%arg1) : (memref<i32>) -> i32
336 // return %1#0 : i32
337 // }
338 // Here, we can see that %1#1 is never used. It is non-live. Thus, @f doesn't
339 // need to return %0. But, %0 is live. And, still, we want to stop it from
340 // being returned, in order to optimize our IR. So, this demonstrates how we
341 // can make our optimization strong by even removing a live return value (%0),
342 // since it forwards only to non-live value(s) (%1#1).
343 size_t numReturns = funcOp.getNumResults();
344 BitVector nonLiveRets(numReturns, true);
345 for (SymbolTable::SymbolUse use : uses) {
346 Operation *callOp = use.getUser();
347 assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
348 BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
349 nonLiveRets &= liveCallRets.flip();
350 }
351
352 // Note that in the absence of control flow ops forcing the control to go from
353 // the entry (first) block to the other blocks, the control never reaches any
354 // block other than the entry block, because every block has a terminator.
355 for (Block &block : funcOp.getBlocks()) {
356 Operation *returnOp = block.getTerminator();
357 if (!returnOp->hasTrait<OpTrait::ReturnLike>())
358 continue;
359 if (returnOp && returnOp->getNumOperands() == numReturns)
360 cl.operands.push_back({returnOp, nonLiveRets});
361 }
362
363 // Do (4).
364 cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
365
366 // Do (5) and (6).
367 if (numReturns == 0)
368 return;
369 for (SymbolTable::SymbolUse use : uses) {
370 Operation *callOp = use.getUser();
371 assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
372 cl.results.push_back({callOp, nonLiveRets});
373 collectNonLiveValues(nonLiveSet, callOp->getResults(), nonLiveRets);
374 }
375}
376
377/// Process a region branch operation `regionBranchOp` using the liveness
378/// information in `la`. The processing involves two scenarios:
379///
380/// Scenario 1: If the operation has no memory effects and none of its results
381/// are live:
382/// (1') Enqueue all its uses for deletion.
383/// (2') Enqueue the branch itself for deletion.
384///
385/// Scenario 2: Otherwise:
386/// (1) Collect its unnecessary operands (operands forwarded to unnecessary
387/// results or arguments).
388/// (2) Process each of its regions.
389/// (3) Collect the uses of its unnecessary results (results forwarded from
390/// unnecessary operands
391/// or terminator operands).
392/// (4) Add these results to the deletion list.
393///
394/// Processing a region includes:
395/// (a) Collecting the uses of its unnecessary arguments (arguments forwarded
396/// from unnecessary operands
397/// or terminator operands).
398/// (b) Collecting these unnecessary arguments.
399/// (c) Collecting its unnecessary terminator operands (terminator operands
400/// forwarded to unnecessary results
401/// or arguments).
402///
403/// Value Flow Note: In this operation, values flow as follows:
404/// - From operands and terminator operands (successor operands)
405/// - To arguments and results (successor inputs).
406static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
408 DenseSet<Value> &nonLiveSet,
409 RDVFinalCleanupList &cl) {
410 LDBG() << "Processing region branch op: "
411 << OpWithFlags(regionBranchOp,
412 OpPrintingFlags().skipRegions().printGenericOpForm());
413
414 // Scenario 1. This is the only case where the entire `regionBranchOp`
415 // is removed. It will not happen in any other scenario. Note that in this
416 // case, a non-forwarded operand of `regionBranchOp` could be live/non-live.
417 // It could never be live because of this op but its liveness could have been
418 // attributed to something else.
419 // Do (1') and (2').
420 if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
421 !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
422 cl.operations.push_back(regionBranchOp.getOperation());
423 return;
424 }
425
426 // Mark live results of `regionBranchOp` in `liveResults`.
427 auto markLiveResults = [&](BitVector &liveResults) {
428 liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
429 };
430
431 // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
432 auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
433 for (Region &region : regionBranchOp->getRegions()) {
434 if (region.empty())
435 continue;
436 SmallVector<Value> arguments(region.front().getArguments());
437 BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
438 liveArgs[&region] = regionLiveArgs;
439 }
440 };
441
442 // Return the successors of `region` if the latter is not null. Else return
443 // the successors of `regionBranchOp`.
444 auto getSuccessors = [&](RegionBranchPoint point) {
446 regionBranchOp.getSuccessorRegions(point, successors);
447 return successors;
448 };
449
450 // Return the operands of `terminator` that are forwarded to `successor` if
451 // the former is not null. Else return the operands of `regionBranchOp`
452 // forwarded to `successor`.
453 auto getForwardedOpOperands = [&](RegionBranchPoint src,
454 const RegionSuccessor &successor) {
456 regionBranchOp.getSuccessorOperands(src, successor));
457 return opOperands;
458 };
459
460 // Mark the non-forwarded operands of `regionBranchOp` in
461 // `nonForwardedOperands`.
462 auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
463 nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
464 for (const RegionSuccessor &successor :
465 getSuccessors(RegionBranchPoint::parent())) {
466 for (OpOperand *opOperand :
467 getForwardedOpOperands(RegionBranchPoint::parent(), successor))
468 nonForwardedOperands.reset(opOperand->getOperandNumber());
469 }
470 };
471
472 // Mark the non-forwarded terminator operands of the various regions of
473 // `regionBranchOp` in `nonForwardedRets`.
474 auto markNonForwardedReturnValues =
475 [&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
476 for (Region &region : regionBranchOp->getRegions()) {
477 if (region.empty())
478 continue;
479 // TODO: this isn't correct in face of multiple terminators.
480 auto terminator = cast<RegionBranchTerminatorOpInterface>(
481 region.front().getTerminator());
482 nonForwardedRets[terminator] =
483 BitVector(terminator->getNumOperands(), true);
484 for (const RegionSuccessor &successor : getSuccessors(terminator)) {
485 for (OpOperand *opOperand : getForwardedOpOperands(
486 RegionBranchPoint(terminator), successor))
487 nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
488 }
489 }
490 };
491
492 // Update `valuesToKeep` (which is expected to correspond to operands or
493 // terminator operands) based on `resultsToKeep` and `argsToKeep`, given
494 // `region`. When `valuesToKeep` correspond to operands, `region` is null.
495 // Else, `region` is the parent region of the terminator.
496 auto updateOperandsOrTerminatorOperandsToKeep =
497 [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
498 DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
499 Operation *terminator =
500 region ? region->front().getTerminator() : nullptr;
501 RegionBranchPoint point =
502 terminator
504 cast<RegionBranchTerminatorOpInterface>(terminator))
505 : RegionBranchPoint::parent();
506
507 for (const RegionSuccessor &successor : getSuccessors(point)) {
508 Region *successorRegion = successor.getSuccessor();
509 for (auto [opOperand, input] :
510 llvm::zip(getForwardedOpOperands(point, successor),
511 successor.getSuccessorInputs())) {
512 size_t operandNum = opOperand->getOperandNumber();
513 bool updateBasedOn =
514 successorRegion
515 ? argsToKeep[successorRegion]
516 [cast<BlockArgument>(input).getArgNumber()]
517 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
518 valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
519 }
520 }
521 };
522
523 // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
524 // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a
525 // value is modified, else, false.
526 auto recomputeResultsAndArgsToKeep =
527 [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
528 BitVector &operandsToKeep,
529 DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
530 bool &resultsOrArgsToKeepChanged) {
531 resultsOrArgsToKeepChanged = false;
532
533 // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
534 for (const RegionSuccessor &successor :
535 getSuccessors(RegionBranchPoint::parent())) {
536 Region *successorRegion = successor.getSuccessor();
537 for (auto [opOperand, input] :
538 llvm::zip(getForwardedOpOperands(RegionBranchPoint::parent(),
539 successor),
540 successor.getSuccessorInputs())) {
541 bool recomputeBasedOn =
542 operandsToKeep[opOperand->getOperandNumber()];
543 bool toRecompute =
544 successorRegion
545 ? argsToKeep[successorRegion]
546 [cast<BlockArgument>(input).getArgNumber()]
547 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
548 if (!toRecompute && recomputeBasedOn)
549 resultsOrArgsToKeepChanged = true;
550 if (successorRegion) {
551 argsToKeep[successorRegion][cast<BlockArgument>(input)
552 .getArgNumber()] =
553 argsToKeep[successorRegion]
554 [cast<BlockArgument>(input).getArgNumber()] |
555 recomputeBasedOn;
556 } else {
557 resultsToKeep[cast<OpResult>(input).getResultNumber()] =
558 resultsToKeep[cast<OpResult>(input).getResultNumber()] |
559 recomputeBasedOn;
560 }
561 }
562 }
563
564 // Recompute `resultsToKeep` and `argsToKeep` based on
565 // `terminatorOperandsToKeep`.
566 for (Region &region : regionBranchOp->getRegions()) {
567 if (region.empty())
568 continue;
569 auto terminator = cast<RegionBranchTerminatorOpInterface>(
570 region.front().getTerminator());
571 for (const RegionSuccessor &successor : getSuccessors(terminator)) {
572 Region *successorRegion = successor.getSuccessor();
573 for (auto [opOperand, input] :
574 llvm::zip(getForwardedOpOperands(RegionBranchPoint(terminator),
575 successor),
576 successor.getSuccessorInputs())) {
577 bool recomputeBasedOn =
578 terminatorOperandsToKeep[region.back().getTerminator()]
579 [opOperand->getOperandNumber()];
580 bool toRecompute =
581 successorRegion
582 ? argsToKeep[successorRegion]
583 [cast<BlockArgument>(input).getArgNumber()]
584 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
585 if (!toRecompute && recomputeBasedOn)
586 resultsOrArgsToKeepChanged = true;
587 if (successorRegion) {
588 argsToKeep[successorRegion][cast<BlockArgument>(input)
589 .getArgNumber()] =
590 argsToKeep[successorRegion]
591 [cast<BlockArgument>(input).getArgNumber()] |
592 recomputeBasedOn;
593 } else {
594 resultsToKeep[cast<OpResult>(input).getResultNumber()] =
595 resultsToKeep[cast<OpResult>(input).getResultNumber()] |
596 recomputeBasedOn;
597 }
598 }
599 }
600 }
601 };
602
603 // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
604 // `operandsToKeep`, and `terminatorOperandsToKeep`.
605 auto markValuesToKeep =
606 [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
607 BitVector &operandsToKeep,
608 DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
609 bool resultsOrArgsToKeepChanged = true;
610 // We keep updating and recomputing the values until we reach a point
611 // where they stop changing.
612 while (resultsOrArgsToKeepChanged) {
613 // Update the operands that need to be kept.
614 updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
615 resultsToKeep, argsToKeep);
616
617 // Update the terminator operands that need to be kept.
618 for (Region &region : regionBranchOp->getRegions()) {
619 if (region.empty())
620 continue;
621 updateOperandsOrTerminatorOperandsToKeep(
622 terminatorOperandsToKeep[region.back().getTerminator()],
623 resultsToKeep, argsToKeep, &region);
624 }
625
626 // Recompute the results and arguments that need to be kept.
627 recomputeResultsAndArgsToKeep(
628 resultsToKeep, argsToKeep, operandsToKeep,
629 terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
630 }
631 };
632
633 // Scenario 2.
634 // At this point, we know that every non-forwarded operand of `regionBranchOp`
635 // is live.
636
637 // Stores the results of `regionBranchOp` that we want to keep.
638 BitVector resultsToKeep;
639 // Stores the mapping from regions of `regionBranchOp` to their arguments that
640 // we want to keep.
642 // Stores the operands of `regionBranchOp` that we want to keep.
643 BitVector operandsToKeep;
644 // Stores the mapping from region terminators in `regionBranchOp` to their
645 // operands that we want to keep.
646 DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
647
648 // Initializing the above variables...
649
650 // The live results of `regionBranchOp` definitely need to be kept.
651 markLiveResults(resultsToKeep);
652 // Similarly, the live arguments of the regions in `regionBranchOp` definitely
653 // need to be kept.
654 markLiveArgs(argsToKeep);
655 // The non-forwarded operands of `regionBranchOp` definitely need to be kept.
656 // A live forwarded operand can be removed but no non-forwarded operand can be
657 // removed since it "controls" the flow of data in this control flow op.
658 markNonForwardedOperands(operandsToKeep);
659 // Similarly, the non-forwarded terminator operands of the regions in
660 // `regionBranchOp` definitely need to be kept.
661 markNonForwardedReturnValues(terminatorOperandsToKeep);
662
663 // Mark the values (results, arguments, operands, and terminator operands)
664 // that we want to keep.
665 markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
666 terminatorOperandsToKeep);
667
668 // Do (1).
669 cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
670
671 // Do (2.a) and (2.b).
672 for (Region &region : regionBranchOp->getRegions()) {
673 if (region.empty())
674 continue;
675 BitVector argsToRemove = argsToKeep[&region].flip();
676 cl.blocks.push_back({&region.front(), argsToRemove});
677 collectNonLiveValues(nonLiveSet, region.front().getArguments(),
678 argsToRemove);
679 }
680
681 // Do (2.c).
682 for (Region &region : regionBranchOp->getRegions()) {
683 if (region.empty())
684 continue;
685 Operation *terminator = region.front().getTerminator();
686 cl.operands.push_back(
687 {terminator, terminatorOperandsToKeep[terminator].flip()});
688 }
689
690 // Do (3) and (4).
691 BitVector resultsToRemove = resultsToKeep.flip();
692 collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
693 resultsToRemove);
694 cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
695}
696
697/// Steps to process a `BranchOpInterface` operation:
698///
699/// When a non-forwarded operand is dead (e.g., the condition value of a
700/// conditional branch op), the entire operation is dead.
701///
702/// Otherwise, iterate through each successor block of `branchOp`.
703/// (1) For each successor block, gather all operands from all successors.
704/// (2) Fetch their associated liveness analysis data and collect for future
705/// removal.
706/// (3) Identify and collect the dead operands from the successor block
707/// as well as their corresponding arguments.
708
709static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
710 DenseSet<Value> &nonLiveSet,
711 RDVFinalCleanupList &cl) {
712 LDBG() << "Processing branch op: " << *branchOp;
713
714 // Check for dead non-forwarded operands.
715 BitVector deadNonForwardedOperands =
716 markLives(branchOp->getOperands(), nonLiveSet, la).flip();
717 unsigned numSuccessors = branchOp->getNumSuccessors();
718 for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
719 SuccessorOperands successorOperands =
720 branchOp.getSuccessorOperands(succIdx);
721 // Remove all non-forwarded operands from the bit vector.
722 for (OpOperand &opOperand : successorOperands.getMutableForwardedOperands())
723 deadNonForwardedOperands[opOperand.getOperandNumber()] = false;
724 }
725 if (deadNonForwardedOperands.any()) {
726 cl.operations.push_back(branchOp.getOperation());
727 return;
728 }
729
730 for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
731 Block *successorBlock = branchOp->getSuccessor(succIdx);
732
733 // Do (1)
734 SuccessorOperands successorOperands =
735 branchOp.getSuccessorOperands(succIdx);
736 SmallVector<Value> operandValues;
737 for (unsigned operandIdx = 0; operandIdx < successorOperands.size();
738 ++operandIdx) {
739 operandValues.push_back(successorOperands[operandIdx]);
740 }
741
742 // Do (2)
743 BitVector successorNonLive =
744 markLives(operandValues, nonLiveSet, la).flip();
745 collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
746 successorNonLive);
747
748 // Do (3)
749 cl.blocks.push_back({successorBlock, successorNonLive});
750 cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
751 }
752}
753
754/// Removes dead values collected in RDVFinalCleanupList.
755/// To be run once when all dead values have been collected.
756static void cleanUpDeadVals(RDVFinalCleanupList &list) {
757 LDBG() << "Starting cleanup of dead values...";
758
759 // 1. Blocks, We must remove the block arguments and successor operands before
760 // deleting the operation, as they may reside in the region operation.
761 LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
762 for (auto &b : list.blocks) {
763 // blocks that are accessed via multiple codepaths processed once
764 if (b.b->getNumArguments() != b.nonLiveArgs.size())
765 continue;
766 LDBG_OS([&](raw_ostream &os) {
767 os << "Erasing non-live arguments [";
768 llvm::interleaveComma(b.nonLiveArgs.set_bits(), os);
769 os << "] from block #" << b.b->computeBlockNumber() << " in region #"
770 << b.b->getParent()->getRegionNumber() << " of operation "
771 << OpWithFlags(b.b->getParent()->getParentOp(),
772 OpPrintingFlags().skipRegions().printGenericOpForm());
773 });
774 // Note: Iterate from the end to make sure that that indices of not yet
775 // processes arguments do not change.
776 for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
777 if (!b.nonLiveArgs[i])
778 continue;
779 b.b->getArgument(i).dropAllUses();
780 b.b->eraseArgument(i);
781 }
782 }
783
784 // 2. Successor Operands
785 LDBG() << "Cleaning up " << list.successorOperands.size()
786 << " successor operand lists";
787 for (auto &op : list.successorOperands) {
788 SuccessorOperands successorOperands =
789 op.branch.getSuccessorOperands(op.successorIndex);
790 // blocks that are accessed via multiple codepaths processed once
791 if (successorOperands.size() != op.nonLiveOperands.size())
792 continue;
793 LDBG_OS([&](raw_ostream &os) {
794 os << "Erasing non-live successor operands [";
795 llvm::interleaveComma(op.nonLiveOperands.set_bits(), os);
796 os << "] from successor " << op.successorIndex << " of branch: "
797 << OpWithFlags(op.branch.getOperation(),
798 OpPrintingFlags().skipRegions().printGenericOpForm());
799 });
800 // it iterates backwards because erase invalidates all successor indexes
801 for (int i = successorOperands.size() - 1; i >= 0; --i) {
802 if (!op.nonLiveOperands[i])
803 continue;
804 successorOperands.erase(i);
805 }
806 }
807
808 // 3. Operations
809 LDBG() << "Cleaning up " << list.operations.size() << " operations";
810 for (Operation *op : list.operations) {
811 LDBG() << "Erasing operation: "
812 << OpWithFlags(op,
813 OpPrintingFlags().skipRegions().printGenericOpForm());
814 if (op->hasTrait<OpTrait::IsTerminator>()) {
815 // When erasing a terminator, insert an unreachable op in its place.
816 OpBuilder b(op);
817 ub::UnreachableOp::create(b, op->getLoc());
818 }
819 op->dropAllUses();
820 op->erase();
821 }
822
823 // 4. Functions
824 LDBG() << "Cleaning up " << list.functions.size() << " functions";
825 // Record which function arguments were erased so we can shrink call-site
826 // argument segments for CallOpInterface operations (e.g. ops using
827 // AttrSizedOperandSegments) in the next phase.
829 for (auto &f : list.functions) {
830 LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName()
831 << " (" << f.funcOp.getOperation() << ")";
832 LDBG_OS([&](raw_ostream &os) {
833 os << " Erasing non-live arguments [";
834 llvm::interleaveComma(f.nonLiveArgs.set_bits(), os);
835 os << "]\n";
836 os << " Erasing non-live return values [";
837 llvm::interleaveComma(f.nonLiveRets.set_bits(), os);
838 os << "]";
839 });
840 // Drop all uses of the dead arguments.
841 for (auto deadIdx : f.nonLiveArgs.set_bits())
842 f.funcOp.getArgument(deadIdx).dropAllUses();
843 // Some functions may not allow erasing arguments or results. These calls
844 // return failure in such cases without modifying the function, so it's okay
845 // to proceed.
846 if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
847 // Record only if we actually erased something.
848 if (f.nonLiveArgs.any())
849 erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs);
850 }
851 (void)f.funcOp.eraseResults(f.nonLiveRets);
852 }
853
854 // 5. Operands
855 LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
856 for (OperandsToCleanup &o : list.operands) {
857 // Handle call-specific cleanup only when we have a cached callee reference.
858 // This avoids expensive symbol lookup and is defensive against future
859 // changes.
860 bool handledAsCall = false;
861 if (o.callee && isa<CallOpInterface>(o.op)) {
862 auto call = cast<CallOpInterface>(o.op);
863 auto it = erasedFuncArgs.find(o.callee);
864 if (it != erasedFuncArgs.end()) {
865 const BitVector &deadArgIdxs = it->second;
866 MutableOperandRange args = call.getArgOperandsMutable();
867 // First, erase the call arguments corresponding to erased callee
868 // args. We iterate backwards to preserve indices.
869 for (unsigned argIdx : llvm::reverse(deadArgIdxs.set_bits()))
870 args.erase(argIdx);
871 // If this operand cleanup entry also has a generic nonLive bitvector,
872 // clear bits for call arguments we already erased above to avoid
873 // double-erasing (which could impact other segments of ops with
874 // AttrSizedOperandSegments).
875 if (o.nonLive.any()) {
876 // Map the argument logical index to the operand number(s) recorded.
877 int operandOffset = call.getArgOperands().getBeginOperandIndex();
878 for (int argIdx : deadArgIdxs.set_bits()) {
879 int operandNumber = operandOffset + argIdx;
880 if (operandNumber < static_cast<int>(o.nonLive.size()))
881 o.nonLive.reset(operandNumber);
882 }
883 }
884 handledAsCall = true;
885 }
886 }
887 // Perform generic operand erasure for:
888 // - Non-call operations
889 // - Call operations without cached callee (where handledAsCall is false)
890 // But skip call operations that were already handled via segment-aware path
891 if (!handledAsCall && o.nonLive.any()) {
892 LDBG_OS([&](raw_ostream &os) {
893 os << "Erasing non-live operands [";
894 llvm::interleaveComma(o.nonLive.set_bits(), os);
895 os << "] from operation: "
896 << OpWithFlags(o.op,
897 OpPrintingFlags().skipRegions().printGenericOpForm());
898 });
899 o.op->eraseOperands(o.nonLive);
900 }
901 }
902
903 // 6. Results
904 LDBG() << "Cleaning up " << list.results.size() << " result lists";
905 for (auto &r : list.results) {
906 LDBG_OS([&](raw_ostream &os) {
907 os << "Erasing non-live results [";
908 llvm::interleaveComma(r.nonLive.set_bits(), os);
909 os << "] from operation: "
910 << OpWithFlags(r.op,
911 OpPrintingFlags().skipRegions().printGenericOpForm());
912 });
913 dropUsesAndEraseResults(r.op, r.nonLive);
914 }
915 LDBG() << "Finished cleanup of dead values";
916}
917
918struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
919 void runOnOperation() override;
920};
921} // namespace
922
923void RemoveDeadValues::runOnOperation() {
924 auto &la = getAnalysis<RunLivenessAnalysis>();
925 Operation *module = getOperation();
926
927 // Tracks values eligible for erasure - complements liveness analysis to
928 // identify "droppable" values.
929 DenseSet<Value> deadVals;
930
931 // Maintains a list of Ops, values, branches, etc., slated for cleanup at the
932 // end of this pass.
933 RDVFinalCleanupList finalCleanupList;
934
935 module->walk([&](Operation *op) {
936 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
937 processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
938 } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
939 processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
940 } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
941 processBranchOp(branchOp, la, deadVals, finalCleanupList);
942 } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
943 // Nothing to do here because this is a terminator op and it should be
944 // honored with respect to its parent
945 } else if (isa<CallOpInterface>(op)) {
946 // Nothing to do because this op is associated with a function op and gets
947 // cleaned when the latter is cleaned.
948 } else {
949 processSimpleOp(op, la, deadVals, finalCleanupList);
950 }
951 });
952
953 cleanUpDeadVals(finalCleanupList);
954}
955
956std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
957 return std::make_unique<RemoveDeadValues>();
958}
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgListType getArguments()
Definition Block.h:97
Block * getSuccessor(unsigned i)
Definition Block.cpp:274
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
This class helps build Operations.
Definition Builders.h:207
This class represents an operand of an operation.
Definition Value.h:257
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be terminators.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:834
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
Definition Operation.h:360
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
unsigned getNumOperands()
Definition Operation.h:346
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
This class models how operands are forwarded to block arguments in control flow.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
MutableOperandRange getMutableForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
unsigned size() const
Returns the number of operands passed to the successor.
This class represents a specific symbol use.
This class implements a range of SymbolRef uses.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void dropAllUses()
Drop all uses of this object from their respective owners.
Definition Value.h:144
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::unique_ptr< Pass > createRemoveDeadValuesPass()
Creates an optimization pass to remove dead values.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
This trait indicates that a terminator operation is "return-like".
This lattice represents, for a given value, whether or not it is "live".
Runs liveness analysis on the IR defined by op.
const Liveness * getLiveness(Value val)