MLIR 22.0.0git
RemoveDeadValues.cpp
Go to the documentation of this file.
1//===- RemoveDeadValues.cpp - Remove Dead Values --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The goal of this pass is optimization (reducing runtime) by removing
10// unnecessary instructions. Unlike other passes that rely on local information
11// gathered from patterns to accomplish optimization, this pass uses a full
12// analysis of the IR, specifically, liveness analysis, and is thus more
13// powerful.
14//
15// Currently, this pass performs the following optimizations:
16// (A) Removes function arguments that are not live,
17// (B) Removes function return values that are not live across all callers of
18// the function,
19// (C) Removes unnecessary operands, results, region arguments, and region
20// terminator operands of region branch ops, and,
21// (D) Removes simple and region branch ops that have all non-live results and
22// don't affect memory in any way,
23//
24// iff
25//
26// the IR doesn't have any non-function symbol ops, non-call symbol user ops and
27// branch ops.
28//
29// Here, a "simple op" refers to an op that isn't a symbol op, symbol-user op,
30// region branch op, branch op, region branch terminator op, or return-like.
31//
32//===----------------------------------------------------------------------===//
33
37#include "mlir/IR/Builders.h"
39#include "mlir/IR/Dialect.h"
40#include "mlir/IR/Operation.h"
42#include "mlir/IR/SymbolTable.h"
43#include "mlir/IR/Value.h"
44#include "mlir/IR/ValueRange.h"
45#include "mlir/IR/Visitors.h"
50#include "mlir/Pass/Pass.h"
51#include "mlir/Support/LLVM.h"
54#include "llvm/ADT/STLExtras.h"
55#include "llvm/Support/Debug.h"
56#include "llvm/Support/DebugLog.h"
57#include <cassert>
58#include <cstddef>
59#include <memory>
60#include <optional>
61#include <vector>
62
63#define DEBUG_TYPE "remove-dead-values"
64
65namespace mlir {
66#define GEN_PASS_DEF_REMOVEDEADVALUES
67#include "mlir/Transforms/Passes.h.inc"
68} // namespace mlir
69
70using namespace mlir;
71using namespace mlir::dataflow;
72
73//===----------------------------------------------------------------------===//
74// RemoveDeadValues Pass
75//===----------------------------------------------------------------------===//
76
77namespace {
78
79// Set of structures below to be filled with operations and arguments to erase.
80// This is done to separate analysis and tree modification phases,
81// otherwise analysis is operating on half-deleted tree which is incorrect.
82
83struct FunctionToCleanUp {
84 FunctionOpInterface funcOp;
85 BitVector nonLiveArgs;
86 BitVector nonLiveRets;
87};
88
89struct OperationToCleanup {
90 Operation *op;
91 BitVector nonLive;
92 Operation *callee =
93 nullptr; // Optional: For CallOpInterface ops, stores the callee function
94};
95
96struct BlockArgsToCleanup {
97 Block *b;
98 BitVector nonLiveArgs;
99};
100
101struct SuccessorOperandsToCleanup {
102 BranchOpInterface branch;
103 unsigned successorIndex;
104 BitVector nonLiveOperands;
105};
106
107struct RDVFinalCleanupList {
108 SmallVector<Operation *> operations;
109 SmallVector<Value> values;
110 SmallVector<FunctionToCleanUp> functions;
111 SmallVector<OperationToCleanup> operands;
112 SmallVector<OperationToCleanup> results;
113 SmallVector<BlockArgsToCleanup> blocks;
114 SmallVector<SuccessorOperandsToCleanup> successorOperands;
115};
116
117// Some helper functions...
118
119/// Return true iff at least one value in `values` is live, given the liveness
120/// information in `la`.
121static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
123 for (Value value : values) {
124 if (nonLiveSet.contains(value)) {
125 LDBG() << "Value " << value << " is already marked non-live (dead)";
126 continue;
127 }
128
129 const Liveness *liveness = la.getLiveness(value);
130 if (!liveness) {
131 LDBG() << "Value " << value
132 << " has no liveness info, conservatively considered live";
133 return true;
134 }
135 if (liveness->isLive) {
136 LDBG() << "Value " << value << " is live according to liveness analysis";
137 return true;
138 } else {
139 LDBG() << "Value " << value << " is dead according to liveness analysis";
140 }
141 }
142 return false;
143}
144
145/// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
146/// i-th value in `values` is live, given the liveness information in `la`.
147static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
149 BitVector lives(values.size(), true);
150
151 for (auto [index, value] : llvm::enumerate(values)) {
152 if (nonLiveSet.contains(value)) {
153 lives.reset(index);
154 LDBG() << "Value " << value
155 << " is already marked non-live (dead) at index " << index;
156 continue;
157 }
158
159 const Liveness *liveness = la.getLiveness(value);
160 // It is important to note that when `liveness` is null, we can't tell if
161 // `value` is live or not. So, the safe option is to consider it live. Also,
162 // the execution of this pass might create new SSA values when erasing some
163 // of the results of an op and we know that these new values are live
164 // (because they weren't erased) and also their liveness is null because
165 // liveness analysis ran before their creation.
166 if (!liveness) {
167 LDBG() << "Value " << value << " at index " << index
168 << " has no liveness info, conservatively considered live";
169 continue;
170 }
171 if (!liveness->isLive) {
172 lives.reset(index);
173 LDBG() << "Value " << value << " at index " << index
174 << " is dead according to liveness analysis";
175 } else {
176 LDBG() << "Value " << value << " at index " << index
177 << " is live according to liveness analysis";
178 }
179 }
180
181 return lives;
182}
183
184/// Collects values marked as "non-live" in the provided range and inserts them
185/// into the nonLiveSet. A value is considered "non-live" if the corresponding
186/// index in the `nonLive` bit vector is set.
187static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range,
188 const BitVector &nonLive) {
189 for (auto [index, result] : llvm::enumerate(range)) {
190 if (!nonLive[index])
191 continue;
192 nonLiveSet.insert(result);
193 LDBG() << "Marking value " << result << " as non-live (dead) at index "
194 << index;
195 }
196}
197
198/// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
199/// is 1.
200static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
201 assert(op->getNumResults() == toErase.size() &&
202 "expected the number of results in `op` and the size of `toErase` to "
203 "be the same");
204
205 std::vector<Type> newResultTypes;
206 for (OpResult result : op->getResults())
207 if (!toErase[result.getResultNumber()])
208 newResultTypes.push_back(result.getType());
209 OpBuilder builder(op);
210 builder.setInsertionPointAfter(op);
211 OperationState state(op->getLoc(), op->getName().getStringRef(),
212 op->getOperands(), newResultTypes, op->getAttrs());
213 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
214 state.addRegion();
215 Operation *newOp = builder.create(state);
216 for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
217 Region &newRegion = newOp->getRegion(index);
218 // Move all blocks of `region` into `newRegion`.
219 Block *temp = new Block();
220 newRegion.push_back(temp);
221 while (!region.empty())
222 region.front().moveBefore(temp);
223 temp->erase();
224 }
225
226 unsigned indexOfNextNewCallOpResultToReplace = 0;
227 for (auto [index, result] : llvm::enumerate(op->getResults())) {
228 assert(result && "expected result to be non-null");
229 if (toErase[index]) {
230 result.dropAllUses();
231 } else {
232 result.replaceAllUsesWith(
233 newOp->getResult(indexOfNextNewCallOpResultToReplace++));
234 }
235 }
236 op->erase();
237}
238
239/// Convert a list of `Operand`s to a list of `OpOperand`s.
241 OpOperand *values = operands.getBase();
242 SmallVector<OpOperand *> opOperands;
243 for (unsigned i = 0, e = operands.size(); i < e; i++)
244 opOperands.push_back(&values[i]);
245 return opOperands;
246}
247
248/// Process a simple operation `op` using the liveness analysis `la`.
249/// If the operation has no memory effects and none of its results are live:
250/// 1. Add the operation to a list for future removal, and
251/// 2. Mark all its results as non-live values
252///
253/// The operation `op` is assumed to be simple. A simple operation is one that
254/// is NOT:
255/// - Function-like
256/// - Call-like
257/// - A region branch operation
258/// - A branch operation
259/// - A region branch terminator
260/// - Return-like
261static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
262 DenseSet<Value> &nonLiveSet,
263 RDVFinalCleanupList &cl) {
264 // Operations that have dead operands can be erased regardless of their
265 // side effects. The liveness analysis would not have marked an SSA value as
266 // "dead" if it had a side-effecting user that is reachable.
267 bool hasDeadOperand =
268 markLives(op->getOperands(), nonLiveSet, la).flip().any();
269 if (hasDeadOperand) {
270 LDBG() << "Simple op has dead operands, so the op must be dead: "
271 << OpWithFlags(op, OpPrintingFlags().skipRegions());
272 assert(!hasLive(op->getResults(), nonLiveSet, la) &&
273 "expected the op to have no live results");
274 cl.operations.push_back(op);
275 collectNonLiveValues(nonLiveSet, op->getResults(),
276 BitVector(op->getNumResults(), true));
277 return;
278 }
279
280 if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
281 LDBG() << "Simple op is not memory effect free or has live results, "
282 "preserving it: "
283 << OpWithFlags(op, OpPrintingFlags().skipRegions());
284 return;
285 }
286
287 LDBG()
288 << "Simple op has all dead results and is memory effect free, scheduling "
289 "for removal: "
290 << OpWithFlags(op, OpPrintingFlags().skipRegions());
291 cl.operations.push_back(op);
292 collectNonLiveValues(nonLiveSet, op->getResults(),
293 BitVector(op->getNumResults(), true));
294}
295
296/// Process a function-like operation `funcOp` using the liveness analysis `la`
297/// and the IR in `module`. If it is not public or external:
298/// (1) Adding its non-live arguments to a list for future removal.
299/// (2) Marking their corresponding operands in its callers for removal.
300/// (3) Identifying and enqueueing unnecessary terminator operands
301/// (return values that are non-live across all callers) for removal.
302/// (4) Enqueueing the non-live arguments and return values for removal.
303/// (5) Collecting the uses of these return values in its callers for future
304/// removal.
305/// (6) Marking all its results as non-live values.
306static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
307 RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
308 RDVFinalCleanupList &cl) {
309 LDBG() << "Processing function op: "
310 << OpWithFlags(funcOp, OpPrintingFlags().skipRegions());
311 if (funcOp.isPublic() || funcOp.isExternal()) {
312 LDBG() << "Function is public or external, skipping: "
313 << funcOp.getOperation()->getName();
314 return;
315 }
316
317 // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
318 SmallVector<Value> arguments(funcOp.getArguments());
319 BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
320 nonLiveArgs = nonLiveArgs.flip();
321
322 // Do (1).
323 for (auto [index, arg] : llvm::enumerate(arguments))
324 if (arg && nonLiveArgs[index]) {
325 cl.values.push_back(arg);
326 nonLiveSet.insert(arg);
327 }
328
329 // Do (2). (Skip creating generic operand cleanup entries for call ops.
330 // Call arguments will be removed in the call-site specific segment-aware
331 // cleanup, avoiding generic eraseOperands bitvector mechanics.)
332 SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
333 for (SymbolTable::SymbolUse use : uses) {
334 Operation *callOp = use.getUser();
335 assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
336 // Push an empty operand cleanup entry so that call-site specific logic in
337 // cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is
338 // intentionally all false to avoid generic erasure.
339 // Store the funcOp as the callee to avoid expensive symbol lookup later.
340 cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false),
341 funcOp.getOperation()});
342 }
343
344 // Do (3).
345 // Get the list of unnecessary terminator operands (return values that are
346 // non-live across all callers) in `nonLiveRets`. There is a very important
347 // subtlety here. Unnecessary terminator operands are NOT the operands of the
348 // terminator that are non-live. Instead, these are the return values of the
349 // callers such that a given return value is non-live across all callers. Such
350 // corresponding operands in the terminator could be live. An example to
351 // demonstrate this:
352 // func.func private @f(%arg0: memref<i32>) -> (i32, i32) {
353 // %c0_i32 = arith.constant 0 : i32
354 // %0 = arith.addi %c0_i32, %c0_i32 : i32
355 // memref.store %0, %arg0[] : memref<i32>
356 // return %c0_i32, %0 : i32, i32
357 // }
358 // func.func @main(%arg0: i32, %arg1: memref<i32>) -> (i32) {
359 // %1:2 = call @f(%arg1) : (memref<i32>) -> i32
360 // return %1#0 : i32
361 // }
362 // Here, we can see that %1#1 is never used. It is non-live. Thus, @f doesn't
363 // need to return %0. But, %0 is live. And, still, we want to stop it from
364 // being returned, in order to optimize our IR. So, this demonstrates how we
365 // can make our optimization strong by even removing a live return value (%0),
366 // since it forwards only to non-live value(s) (%1#1).
367 size_t numReturns = funcOp.getNumResults();
368 BitVector nonLiveRets(numReturns, true);
369 for (SymbolTable::SymbolUse use : uses) {
370 Operation *callOp = use.getUser();
371 assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
372 BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
373 nonLiveRets &= liveCallRets.flip();
374 }
375
376 // Note that in the absence of control flow ops forcing the control to go from
377 // the entry (first) block to the other blocks, the control never reaches any
378 // block other than the entry block, because every block has a terminator.
379 for (Block &block : funcOp.getBlocks()) {
380 Operation *returnOp = block.getTerminator();
381 if (!returnOp->hasTrait<OpTrait::ReturnLike>())
382 continue;
383 if (returnOp && returnOp->getNumOperands() == numReturns)
384 cl.operands.push_back({returnOp, nonLiveRets});
385 }
386
387 // Do (4).
388 cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
389
390 // Do (5) and (6).
391 if (numReturns == 0)
392 return;
393 for (SymbolTable::SymbolUse use : uses) {
394 Operation *callOp = use.getUser();
395 assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
396 cl.results.push_back({callOp, nonLiveRets});
397 collectNonLiveValues(nonLiveSet, callOp->getResults(), nonLiveRets);
398 }
399}
400
401/// Process a region branch operation `regionBranchOp` using the liveness
402/// information in `la`. The processing involves two scenarios:
403///
404/// Scenario 1: If the operation has no memory effects and none of its results
405/// are live:
406/// (1') Enqueue all its uses for deletion.
407/// (2') Enqueue the branch itself for deletion.
408///
409/// Scenario 2: Otherwise:
410/// (1) Collect its unnecessary operands (operands forwarded to unnecessary
411/// results or arguments).
412/// (2) Process each of its regions.
413/// (3) Collect the uses of its unnecessary results (results forwarded from
414/// unnecessary operands
415/// or terminator operands).
416/// (4) Add these results to the deletion list.
417///
418/// Processing a region includes:
419/// (a) Collecting the uses of its unnecessary arguments (arguments forwarded
420/// from unnecessary operands
421/// or terminator operands).
422/// (b) Collecting these unnecessary arguments.
423/// (c) Collecting its unnecessary terminator operands (terminator operands
424/// forwarded to unnecessary results
425/// or arguments).
426///
427/// Value Flow Note: In this operation, values flow as follows:
428/// - From operands and terminator operands (successor operands)
429/// - To arguments and results (successor inputs).
430static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
432 DenseSet<Value> &nonLiveSet,
433 RDVFinalCleanupList &cl) {
434 LDBG() << "Processing region branch op: "
435 << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
436 // Mark live results of `regionBranchOp` in `liveResults`.
437 auto markLiveResults = [&](BitVector &liveResults) {
438 liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
439 };
440
441 // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
442 auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
443 for (Region &region : regionBranchOp->getRegions()) {
444 if (region.empty())
445 continue;
446 SmallVector<Value> arguments(region.front().getArguments());
447 BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
448 liveArgs[&region] = regionLiveArgs;
449 }
450 };
451
452 // Return the successors of `region` if the latter is not null. Else return
453 // the successors of `regionBranchOp`.
454 auto getSuccessors = [&](RegionBranchPoint point) {
456 regionBranchOp.getSuccessorRegions(point, successors);
457 return successors;
458 };
459
460 // Return the operands of `terminator` that are forwarded to `successor` if
461 // the former is not null. Else return the operands of `regionBranchOp`
462 // forwarded to `successor`.
463 auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
464 Operation *terminator = nullptr) {
465 OperandRange operands =
466 terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
467 .getSuccessorOperands(successor)
468 : regionBranchOp.getEntrySuccessorOperands(successor);
469 SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
470 return opOperands;
471 };
472
473 // Mark the non-forwarded operands of `regionBranchOp` in
474 // `nonForwardedOperands`.
475 auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
476 nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
477 for (const RegionSuccessor &successor :
478 getSuccessors(RegionBranchPoint::parent())) {
479 for (OpOperand *opOperand : getForwardedOpOperands(successor))
480 nonForwardedOperands.reset(opOperand->getOperandNumber());
481 }
482 };
483
484 // Mark the non-forwarded terminator operands of the various regions of
485 // `regionBranchOp` in `nonForwardedRets`.
486 auto markNonForwardedReturnValues =
487 [&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
488 for (Region &region : regionBranchOp->getRegions()) {
489 if (region.empty())
490 continue;
491 // TODO: this isn't correct in face of multiple terminators.
492 Operation *terminator = region.front().getTerminator();
493 nonForwardedRets[terminator] =
494 BitVector(terminator->getNumOperands(), true);
495 for (const RegionSuccessor &successor :
496 getSuccessors(RegionBranchPoint(
497 cast<RegionBranchTerminatorOpInterface>(terminator)))) {
498 for (OpOperand *opOperand :
499 getForwardedOpOperands(successor, terminator))
500 nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
501 }
502 }
503 };
504
505 // Update `valuesToKeep` (which is expected to correspond to operands or
506 // terminator operands) based on `resultsToKeep` and `argsToKeep`, given
507 // `region`. When `valuesToKeep` correspond to operands, `region` is null.
508 // Else, `region` is the parent region of the terminator.
509 auto updateOperandsOrTerminatorOperandsToKeep =
510 [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
511 DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
512 Operation *terminator =
513 region ? region->front().getTerminator() : nullptr;
514 RegionBranchPoint point =
515 terminator
517 cast<RegionBranchTerminatorOpInterface>(terminator))
518 : RegionBranchPoint::parent();
519
520 for (const RegionSuccessor &successor : getSuccessors(point)) {
521 Region *successorRegion = successor.getSuccessor();
522 for (auto [opOperand, input] :
523 llvm::zip(getForwardedOpOperands(successor, terminator),
524 successor.getSuccessorInputs())) {
525 size_t operandNum = opOperand->getOperandNumber();
526 bool updateBasedOn =
527 successorRegion
528 ? argsToKeep[successorRegion]
529 [cast<BlockArgument>(input).getArgNumber()]
530 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
531 valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
532 }
533 }
534 };
535
536 // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
537 // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a
538 // value is modified, else, false.
539 auto recomputeResultsAndArgsToKeep =
540 [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
541 BitVector &operandsToKeep,
542 DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
543 bool &resultsOrArgsToKeepChanged) {
544 resultsOrArgsToKeepChanged = false;
545
546 // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
547 for (const RegionSuccessor &successor :
548 getSuccessors(RegionBranchPoint::parent())) {
549 Region *successorRegion = successor.getSuccessor();
550 for (auto [opOperand, input] :
551 llvm::zip(getForwardedOpOperands(successor),
552 successor.getSuccessorInputs())) {
553 bool recomputeBasedOn =
554 operandsToKeep[opOperand->getOperandNumber()];
555 bool toRecompute =
556 successorRegion
557 ? argsToKeep[successorRegion]
558 [cast<BlockArgument>(input).getArgNumber()]
559 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
560 if (!toRecompute && recomputeBasedOn)
561 resultsOrArgsToKeepChanged = true;
562 if (successorRegion) {
563 argsToKeep[successorRegion][cast<BlockArgument>(input)
564 .getArgNumber()] =
565 argsToKeep[successorRegion]
566 [cast<BlockArgument>(input).getArgNumber()] |
567 recomputeBasedOn;
568 } else {
569 resultsToKeep[cast<OpResult>(input).getResultNumber()] =
570 resultsToKeep[cast<OpResult>(input).getResultNumber()] |
571 recomputeBasedOn;
572 }
573 }
574 }
575
576 // Recompute `resultsToKeep` and `argsToKeep` based on
577 // `terminatorOperandsToKeep`.
578 for (Region &region : regionBranchOp->getRegions()) {
579 if (region.empty())
580 continue;
581 Operation *terminator = region.front().getTerminator();
582 for (const RegionSuccessor &successor :
583 getSuccessors(RegionBranchPoint(
584 cast<RegionBranchTerminatorOpInterface>(terminator)))) {
585 Region *successorRegion = successor.getSuccessor();
586 for (auto [opOperand, input] :
587 llvm::zip(getForwardedOpOperands(successor, terminator),
588 successor.getSuccessorInputs())) {
589 bool recomputeBasedOn =
590 terminatorOperandsToKeep[region.back().getTerminator()]
591 [opOperand->getOperandNumber()];
592 bool toRecompute =
593 successorRegion
594 ? argsToKeep[successorRegion]
595 [cast<BlockArgument>(input).getArgNumber()]
596 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
597 if (!toRecompute && recomputeBasedOn)
598 resultsOrArgsToKeepChanged = true;
599 if (successorRegion) {
600 argsToKeep[successorRegion][cast<BlockArgument>(input)
601 .getArgNumber()] =
602 argsToKeep[successorRegion]
603 [cast<BlockArgument>(input).getArgNumber()] |
604 recomputeBasedOn;
605 } else {
606 resultsToKeep[cast<OpResult>(input).getResultNumber()] =
607 resultsToKeep[cast<OpResult>(input).getResultNumber()] |
608 recomputeBasedOn;
609 }
610 }
611 }
612 }
613 };
614
615 // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
616 // `operandsToKeep`, and `terminatorOperandsToKeep`.
617 auto markValuesToKeep =
618 [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
619 BitVector &operandsToKeep,
620 DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
621 bool resultsOrArgsToKeepChanged = true;
622 // We keep updating and recomputing the values until we reach a point
623 // where they stop changing.
624 while (resultsOrArgsToKeepChanged) {
625 // Update the operands that need to be kept.
626 updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
627 resultsToKeep, argsToKeep);
628
629 // Update the terminator operands that need to be kept.
630 for (Region &region : regionBranchOp->getRegions()) {
631 if (region.empty())
632 continue;
633 updateOperandsOrTerminatorOperandsToKeep(
634 terminatorOperandsToKeep[region.back().getTerminator()],
635 resultsToKeep, argsToKeep, &region);
636 }
637
638 // Recompute the results and arguments that need to be kept.
639 recomputeResultsAndArgsToKeep(
640 resultsToKeep, argsToKeep, operandsToKeep,
641 terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
642 }
643 };
644
645 // Scenario 1. This is the only case where the entire `regionBranchOp`
646 // is removed. It will not happen in any other scenario. Note that in this
647 // case, a non-forwarded operand of `regionBranchOp` could be live/non-live.
648 // It could never be live because of this op but its liveness could have been
649 // attributed to something else.
650 // Do (1') and (2').
651 if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
652 !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
653 cl.operations.push_back(regionBranchOp.getOperation());
654 return;
655 }
656
657 // Scenario 2.
658 // At this point, we know that every non-forwarded operand of `regionBranchOp`
659 // is live.
660
661 // Stores the results of `regionBranchOp` that we want to keep.
662 BitVector resultsToKeep;
663 // Stores the mapping from regions of `regionBranchOp` to their arguments that
664 // we want to keep.
666 // Stores the operands of `regionBranchOp` that we want to keep.
667 BitVector operandsToKeep;
668 // Stores the mapping from region terminators in `regionBranchOp` to their
669 // operands that we want to keep.
670 DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
671
672 // Initializing the above variables...
673
674 // The live results of `regionBranchOp` definitely need to be kept.
675 markLiveResults(resultsToKeep);
676 // Similarly, the live arguments of the regions in `regionBranchOp` definitely
677 // need to be kept.
678 markLiveArgs(argsToKeep);
679 // The non-forwarded operands of `regionBranchOp` definitely need to be kept.
680 // A live forwarded operand can be removed but no non-forwarded operand can be
681 // removed since it "controls" the flow of data in this control flow op.
682 markNonForwardedOperands(operandsToKeep);
683 // Similarly, the non-forwarded terminator operands of the regions in
684 // `regionBranchOp` definitely need to be kept.
685 markNonForwardedReturnValues(terminatorOperandsToKeep);
686
687 // Mark the values (results, arguments, operands, and terminator operands)
688 // that we want to keep.
689 markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
690 terminatorOperandsToKeep);
691
692 // Do (1).
693 cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
694
695 // Do (2.a) and (2.b).
696 for (Region &region : regionBranchOp->getRegions()) {
697 if (region.empty())
698 continue;
699 BitVector argsToRemove = argsToKeep[&region].flip();
700 cl.blocks.push_back({&region.front(), argsToRemove});
701 collectNonLiveValues(nonLiveSet, region.front().getArguments(),
702 argsToRemove);
703 }
704
705 // Do (2.c).
706 for (Region &region : regionBranchOp->getRegions()) {
707 if (region.empty())
708 continue;
709 Operation *terminator = region.front().getTerminator();
710 cl.operands.push_back(
711 {terminator, terminatorOperandsToKeep[terminator].flip()});
712 }
713
714 // Do (3) and (4).
715 BitVector resultsToRemove = resultsToKeep.flip();
716 collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
717 resultsToRemove);
718 cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
719}
720
721/// Steps to process a `BranchOpInterface` operation:
722///
723/// When a non-forwarded operand is dead (e.g., the condition value of a
724/// conditional branch op), the entire operation is dead.
725///
726/// Otherwise, iterate through each successor block of `branchOp`.
727/// (1) For each successor block, gather all operands from all successors.
728/// (2) Fetch their associated liveness analysis data and collect for future
729/// removal.
730/// (3) Identify and collect the dead operands from the successor block
731/// as well as their corresponding arguments.
732
733static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
734 DenseSet<Value> &nonLiveSet,
735 RDVFinalCleanupList &cl) {
736 LDBG() << "Processing branch op: " << *branchOp;
737
738 // Check for dead non-forwarded operands.
739 BitVector deadNonForwardedOperands =
740 markLives(branchOp->getOperands(), nonLiveSet, la).flip();
741 unsigned numSuccessors = branchOp->getNumSuccessors();
742 for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
743 SuccessorOperands successorOperands =
744 branchOp.getSuccessorOperands(succIdx);
745 // Remove all non-forwarded operands from the bit vector.
746 for (OpOperand &opOperand : successorOperands.getMutableForwardedOperands())
747 deadNonForwardedOperands[opOperand.getOperandNumber()] = false;
748 }
749 if (deadNonForwardedOperands.any()) {
750 cl.operations.push_back(branchOp.getOperation());
751 return;
752 }
753
754 for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
755 Block *successorBlock = branchOp->getSuccessor(succIdx);
756
757 // Do (1)
758 SuccessorOperands successorOperands =
759 branchOp.getSuccessorOperands(succIdx);
760 SmallVector<Value> operandValues;
761 for (unsigned operandIdx = 0; operandIdx < successorOperands.size();
762 ++operandIdx) {
763 operandValues.push_back(successorOperands[operandIdx]);
764 }
765
766 // Do (2)
767 BitVector successorNonLive =
768 markLives(operandValues, nonLiveSet, la).flip();
769 collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
770 successorNonLive);
771
772 // Do (3)
773 cl.blocks.push_back({successorBlock, successorNonLive});
774 cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
775 }
776}
777
778/// Removes dead values collected in RDVFinalCleanupList.
779/// To be run once when all dead values have been collected.
780static void cleanUpDeadVals(RDVFinalCleanupList &list) {
781 LDBG() << "Starting cleanup of dead values...";
782
783 // 1. Blocks, We must remove the block arguments and successor operands before
784 // deleting the operation, as they may reside in the region operation.
785 LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
786 for (auto &b : list.blocks) {
787 // blocks that are accessed via multiple codepaths processed once
788 if (b.b->getNumArguments() != b.nonLiveArgs.size())
789 continue;
790 LDBG() << "Erasing " << b.nonLiveArgs.count()
791 << " non-live arguments from block: " << b.b;
792 // it iterates backwards because erase invalidates all successor indexes
793 for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
794 if (!b.nonLiveArgs[i])
795 continue;
796 LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i);
797 b.b->getArgument(i).dropAllUses();
798 b.b->eraseArgument(i);
799 }
800 }
801
802 // 2. Successor Operands
803 LDBG() << "Cleaning up " << list.successorOperands.size()
804 << " successor operand lists";
805 for (auto &op : list.successorOperands) {
806 SuccessorOperands successorOperands =
807 op.branch.getSuccessorOperands(op.successorIndex);
808 // blocks that are accessed via multiple codepaths processed once
809 if (successorOperands.size() != op.nonLiveOperands.size())
810 continue;
811 LDBG() << "Erasing " << op.nonLiveOperands.count()
812 << " non-live successor operands from successor "
813 << op.successorIndex << " of branch: "
814 << OpWithFlags(op.branch, OpPrintingFlags().skipRegions());
815 // it iterates backwards because erase invalidates all successor indexes
816 for (int i = successorOperands.size() - 1; i >= 0; --i) {
817 if (!op.nonLiveOperands[i])
818 continue;
819 LDBG() << " Erasing successor operand " << i << ": "
820 << successorOperands[i];
821 successorOperands.erase(i);
822 }
823 }
824
825 // 3. Operations
826 LDBG() << "Cleaning up " << list.operations.size() << " operations";
827 for (Operation *op : list.operations) {
828 LDBG() << "Erasing operation: "
829 << OpWithFlags(op, OpPrintingFlags().skipRegions());
830 if (op->hasTrait<OpTrait::IsTerminator>()) {
831 // When erasing a terminator, insert an unreachable op in its place.
832 OpBuilder b(op);
833 ub::UnreachableOp::create(b, op->getLoc());
834 }
835 op->dropAllUses();
836 op->erase();
837 }
838
839 // 4. Values
840 LDBG() << "Cleaning up " << list.values.size() << " values";
841 for (auto &v : list.values) {
842 LDBG() << "Dropping all uses of value: " << v;
843 v.dropAllUses();
844 }
845
846 // 5. Functions
847 LDBG() << "Cleaning up " << list.functions.size() << " functions";
848 // Record which function arguments were erased so we can shrink call-site
849 // argument segments for CallOpInterface operations (e.g. ops using
850 // AttrSizedOperandSegments) in the next phase.
852 for (auto &f : list.functions) {
853 LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
854 LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments";
855 LDBG() << " Erasing " << f.nonLiveRets.count()
856 << " non-live return values";
857 // Some functions may not allow erasing arguments or results. These calls
858 // return failure in such cases without modifying the function, so it's okay
859 // to proceed.
860 if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
861 // Record only if we actually erased something.
862 if (f.nonLiveArgs.any())
863 erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs);
864 }
865 (void)f.funcOp.eraseResults(f.nonLiveRets);
866 }
867
868 // 6. Operands
869 LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
870 for (OperationToCleanup &o : list.operands) {
871 // Handle call-specific cleanup only when we have a cached callee reference.
872 // This avoids expensive symbol lookup and is defensive against future
873 // changes.
874 bool handledAsCall = false;
875 if (o.callee && isa<CallOpInterface>(o.op)) {
876 auto call = cast<CallOpInterface>(o.op);
877 auto it = erasedFuncArgs.find(o.callee);
878 if (it != erasedFuncArgs.end()) {
879 const BitVector &deadArgIdxs = it->second;
880 MutableOperandRange args = call.getArgOperandsMutable();
881 // First, erase the call arguments corresponding to erased callee
882 // args. We iterate backwards to preserve indices.
883 for (unsigned argIdx : llvm::reverse(deadArgIdxs.set_bits()))
884 args.erase(argIdx);
885 // If this operand cleanup entry also has a generic nonLive bitvector,
886 // clear bits for call arguments we already erased above to avoid
887 // double-erasing (which could impact other segments of ops with
888 // AttrSizedOperandSegments).
889 if (o.nonLive.any()) {
890 // Map the argument logical index to the operand number(s) recorded.
891 int operandOffset = call.getArgOperands().getBeginOperandIndex();
892 for (int argIdx : deadArgIdxs.set_bits()) {
893 int operandNumber = operandOffset + argIdx;
894 if (operandNumber < static_cast<int>(o.nonLive.size()))
895 o.nonLive.reset(operandNumber);
896 }
897 }
898 handledAsCall = true;
899 }
900 }
901 // Perform generic operand erasure for:
902 // - Non-call operations
903 // - Call operations without cached callee (where handledAsCall is false)
904 // But skip call operations that were already handled via segment-aware path
905 if (!handledAsCall && o.nonLive.any()) {
906 o.op->eraseOperands(o.nonLive);
907 }
908 }
909
910 // 7. Results
911 LDBG() << "Cleaning up " << list.results.size() << " result lists";
912 for (auto &r : list.results) {
913 LDBG() << "Erasing " << r.nonLive.count()
914 << " non-live results from operation: "
915 << OpWithFlags(r.op, OpPrintingFlags().skipRegions());
916 dropUsesAndEraseResults(r.op, r.nonLive);
917 }
918 LDBG() << "Finished cleanup of dead values";
919}
920
921struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
922 void runOnOperation() override;
923};
924} // namespace
925
926void RemoveDeadValues::runOnOperation() {
927 auto &la = getAnalysis<RunLivenessAnalysis>();
928 Operation *module = getOperation();
929
930 // Tracks values eligible for erasure - complements liveness analysis to
931 // identify "droppable" values.
932 DenseSet<Value> deadVals;
933
934 // Maintains a list of Ops, values, branches, etc., slated for cleanup at the
935 // end of this pass.
936 RDVFinalCleanupList finalCleanupList;
937
938 module->walk([&](Operation *op) {
939 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
940 processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
941 } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
942 processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
943 } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
944 processBranchOp(branchOp, la, deadVals, finalCleanupList);
945 } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
946 // Nothing to do here because this is a terminator op and it should be
947 // honored with respect to its parent
948 } else if (isa<CallOpInterface>(op)) {
949 // Nothing to do because this op is associated with a function op and gets
950 // cleaned when the latter is cleaned.
951 } else {
952 processSimpleOp(op, la, deadVals, finalCleanupList);
953 }
954 });
955
956 cleanUpDeadVals(finalCleanupList);
957}
958
959std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
960 return std::make_unique<RemoveDeadValues>();
961}
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
Block represents an ordered list of Operations.
Definition Block.h:33
void erase()
Unlink this Block from its parent region and delete it.
Definition Block.cpp:66
BlockArgListType getArguments()
Definition Block.h:87
Block * getSuccessor(unsigned i)
Definition Block.cpp:269
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
This class helps build Operations.
Definition Builders.h:207
This class represents an operand of an operation.
Definition Value.h:257
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition Value.h:457
This class provides the API for ops that are known to be terminators.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:834
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
Definition Operation.h:360
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
unsigned getNumOperands()
Definition Operation.h:346
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition Operation.cpp:67
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
ValueRange getSuccessorInputs() const
Return the inputs to the successor that are remapped by the exit values of the current region.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
void push_back(Block *block)
Definition Region.h:61
This class models how operands are forwarded to block arguments in control flow.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
MutableOperandRange getMutableForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents a specific symbol use.
This class implements a range of SymbolRef uses.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::unique_ptr< Pass > createRemoveDeadValuesPass()
Creates an optimization pass to remove dead values.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
This trait indicates that a terminator operation is "return-like".
This represents an operation in an abstracted form, suitable for use with the builder APIs.
This lattice represents, for a given value, whether or not it is "live".
Runs liveness analysis on the IR defined by op.
const Liveness * getLiveness(Value val)