54#include "llvm/ADT/STLExtras.h"
55#include "llvm/Support/Debug.h"
56#include "llvm/Support/DebugLog.h"
63#define DEBUG_TYPE "remove-dead-values"
66#define GEN_PASS_DEF_REMOVEDEADVALUES
67#include "mlir/Transforms/Passes.h.inc"
83struct FunctionToCleanUp {
84 FunctionOpInterface funcOp;
85 BitVector nonLiveArgs;
86 BitVector nonLiveRets;
89struct ResultsToCleanup {
94struct OperandsToCleanup {
101struct BlockArgsToCleanup {
103 BitVector nonLiveArgs;
106struct SuccessorOperandsToCleanup {
107 BranchOpInterface branch;
108 unsigned successorIndex;
109 BitVector nonLiveOperands;
112struct RDVFinalCleanupList {
113 SmallVector<Operation *> operations;
114 SmallVector<FunctionToCleanUp> functions;
115 SmallVector<OperandsToCleanup> operands;
116 SmallVector<ResultsToCleanup> results;
117 SmallVector<BlockArgsToCleanup> blocks;
118 SmallVector<SuccessorOperandsToCleanup> successorOperands;
127 for (
Value value : values) {
128 if (nonLiveSet.contains(value)) {
129 LDBG() <<
"Value " << value <<
" is already marked non-live (dead)";
135 LDBG() <<
"Value " << value
136 <<
" has no liveness info, conservatively considered live";
140 LDBG() <<
"Value " << value <<
" is live according to liveness analysis";
143 LDBG() <<
"Value " << value <<
" is dead according to liveness analysis";
153 BitVector lives(values.size(),
true);
155 for (
auto [
index, value] : llvm::enumerate(values)) {
156 if (nonLiveSet.contains(value)) {
158 LDBG() <<
"Value " << value
159 <<
" is already marked non-live (dead) at index " <<
index;
171 LDBG() <<
"Value " << value <<
" at index " <<
index
172 <<
" has no liveness info, conservatively considered live";
177 LDBG() <<
"Value " << value <<
" at index " <<
index
178 <<
" is dead according to liveness analysis";
180 LDBG() <<
"Value " << value <<
" at index " <<
index
181 <<
" is live according to liveness analysis";
192 const BitVector &nonLive) {
193 for (
auto [
index,
result] : llvm::enumerate(range)) {
196 nonLiveSet.insert(
result);
197 LDBG() <<
"Marking value " <<
result <<
" as non-live (dead) at index "
204static void dropUsesAndEraseResults(
Operation *op, BitVector toErase) {
206 "expected the number of results in `op` and the size of `toErase` to "
208 for (
auto idx : toErase.set_bits())
211 rewriter.eraseOpResults(op, toErase);
218 for (
unsigned i = 0, e = operands.size(); i < e; i++)
219 opOperands.push_back(&values[i]);
238 RDVFinalCleanupList &cl) {
242 bool hasDeadOperand =
243 markLives(op->
getOperands(), nonLiveSet, la).flip().any();
244 if (hasDeadOperand) {
245 LDBG() <<
"Simple op has dead operands, so the op must be dead: "
248 assert(!hasLive(op->
getResults(), nonLiveSet, la) &&
249 "expected the op to have no live results");
250 cl.operations.push_back(op);
251 collectNonLiveValues(nonLiveSet, op->
getResults(),
257 LDBG() <<
"Simple op is not memory effect free or has live results, "
265 <<
"Simple op has all dead results and is memory effect free, scheduling "
268 cl.operations.push_back(op);
269 collectNonLiveValues(nonLiveSet, op->
getResults(),
283static void processFuncOp(FunctionOpInterface funcOp,
Operation *module,
285 RDVFinalCleanupList &cl) {
286 LDBG() <<
"Processing function op: "
289 if (funcOp.isPublic() || funcOp.isExternal()) {
290 LDBG() <<
"Function is public or external, skipping: "
291 << funcOp.getOperation()->getName();
297 BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
298 nonLiveArgs = nonLiveArgs.flip();
301 for (
auto [
index, arg] : llvm::enumerate(arguments))
302 if (arg && nonLiveArgs[
index])
303 nonLiveSet.insert(arg);
311 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
316 cl.operands.push_back({callOp, BitVector(callOp->
getNumOperands(),
false),
317 funcOp.getOperation()});
343 size_t numReturns = funcOp.getNumResults();
344 BitVector nonLiveRets(numReturns,
true);
347 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
348 BitVector liveCallRets = markLives(callOp->
getResults(), nonLiveSet, la);
349 nonLiveRets &= liveCallRets.flip();
355 for (
Block &block : funcOp.getBlocks()) {
356 Operation *returnOp = block.getTerminator();
360 cl.operands.push_back({returnOp, nonLiveRets});
364 cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
371 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
372 cl.results.push_back({callOp, nonLiveRets});
373 collectNonLiveValues(nonLiveSet, callOp->
getResults(), nonLiveRets);
406static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
409 RDVFinalCleanupList &cl) {
410 LDBG() <<
"Processing region branch op: "
421 !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
422 cl.operations.push_back(regionBranchOp.getOperation());
427 auto markLiveResults = [&](BitVector &liveResults) {
428 liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
433 for (
Region ®ion : regionBranchOp->getRegions()) {
437 BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
438 liveArgs[®ion] = regionLiveArgs;
446 regionBranchOp.getSuccessorRegions(point, successors);
456 regionBranchOp.getSuccessorOperands(src, successor));
462 auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
463 nonForwardedOperands.resize(regionBranchOp->getNumOperands(),
true);
468 nonForwardedOperands.reset(opOperand->getOperandNumber());
474 auto markNonForwardedReturnValues =
476 for (
Region ®ion : regionBranchOp->getRegions()) {
480 auto terminator = cast<RegionBranchTerminatorOpInterface>(
481 region.front().getTerminator());
482 nonForwardedRets[terminator] =
483 BitVector(terminator->getNumOperands(),
true);
485 for (
OpOperand *opOperand : getForwardedOpOperands(
487 nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
496 auto updateOperandsOrTerminatorOperandsToKeep =
497 [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
500 region ? region->front().getTerminator() :
nullptr;
504 cast<RegionBranchTerminatorOpInterface>(terminator))
508 Region *successorRegion = successor.getSuccessor();
509 for (
auto [opOperand, input] :
510 llvm::zip(getForwardedOpOperands(point, successor),
511 successor.getSuccessorInputs())) {
512 size_t operandNum = opOperand->getOperandNumber();
515 ? argsToKeep[successorRegion]
516 [cast<BlockArgument>(input).getArgNumber()]
517 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
518 valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
526 auto recomputeResultsAndArgsToKeep =
528 BitVector &operandsToKeep,
530 bool &resultsOrArgsToKeepChanged) {
531 resultsOrArgsToKeepChanged =
false;
536 Region *successorRegion = successor.getSuccessor();
537 for (
auto [opOperand, input] :
540 successor.getSuccessorInputs())) {
541 bool recomputeBasedOn =
542 operandsToKeep[opOperand->getOperandNumber()];
545 ? argsToKeep[successorRegion]
546 [cast<BlockArgument>(input).getArgNumber()]
547 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
548 if (!toRecompute && recomputeBasedOn)
549 resultsOrArgsToKeepChanged =
true;
550 if (successorRegion) {
551 argsToKeep[successorRegion][cast<BlockArgument>(input)
553 argsToKeep[successorRegion]
554 [cast<BlockArgument>(input).getArgNumber()] |
557 resultsToKeep[cast<OpResult>(input).getResultNumber()] =
558 resultsToKeep[cast<OpResult>(input).getResultNumber()] |
566 for (
Region ®ion : regionBranchOp->getRegions()) {
569 auto terminator = cast<RegionBranchTerminatorOpInterface>(
570 region.front().getTerminator());
572 Region *successorRegion = successor.getSuccessor();
573 for (
auto [opOperand, input] :
576 successor.getSuccessorInputs())) {
577 bool recomputeBasedOn =
578 terminatorOperandsToKeep[region.back().getTerminator()]
579 [opOperand->getOperandNumber()];
582 ? argsToKeep[successorRegion]
583 [cast<BlockArgument>(input).getArgNumber()]
584 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
585 if (!toRecompute && recomputeBasedOn)
586 resultsOrArgsToKeepChanged =
true;
587 if (successorRegion) {
588 argsToKeep[successorRegion][cast<BlockArgument>(input)
590 argsToKeep[successorRegion]
591 [cast<BlockArgument>(input).getArgNumber()] |
594 resultsToKeep[cast<OpResult>(input).getResultNumber()] =
595 resultsToKeep[cast<OpResult>(input).getResultNumber()] |
605 auto markValuesToKeep =
607 BitVector &operandsToKeep,
609 bool resultsOrArgsToKeepChanged =
true;
612 while (resultsOrArgsToKeepChanged) {
614 updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
615 resultsToKeep, argsToKeep);
618 for (
Region ®ion : regionBranchOp->getRegions()) {
621 updateOperandsOrTerminatorOperandsToKeep(
622 terminatorOperandsToKeep[region.back().getTerminator()],
623 resultsToKeep, argsToKeep, ®ion);
627 recomputeResultsAndArgsToKeep(
628 resultsToKeep, argsToKeep, operandsToKeep,
629 terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
638 BitVector resultsToKeep;
643 BitVector operandsToKeep;
651 markLiveResults(resultsToKeep);
654 markLiveArgs(argsToKeep);
658 markNonForwardedOperands(operandsToKeep);
661 markNonForwardedReturnValues(terminatorOperandsToKeep);
665 markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
666 terminatorOperandsToKeep);
669 cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
672 for (
Region ®ion : regionBranchOp->getRegions()) {
675 BitVector argsToRemove = argsToKeep[®ion].flip();
676 cl.blocks.push_back({®ion.front(), argsToRemove});
677 collectNonLiveValues(nonLiveSet, region.front().getArguments(),
682 for (
Region ®ion : regionBranchOp->getRegions()) {
685 Operation *terminator = region.front().getTerminator();
686 cl.operands.push_back(
687 {terminator, terminatorOperandsToKeep[terminator].flip()});
691 BitVector resultsToRemove = resultsToKeep.flip();
692 collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
694 cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
711 RDVFinalCleanupList &cl) {
712 LDBG() <<
"Processing branch op: " << *branchOp;
715 BitVector deadNonForwardedOperands =
716 markLives(branchOp->getOperands(), nonLiveSet, la).flip();
717 unsigned numSuccessors = branchOp->getNumSuccessors();
718 for (
unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
720 branchOp.getSuccessorOperands(succIdx);
723 deadNonForwardedOperands[opOperand.getOperandNumber()] =
false;
725 if (deadNonForwardedOperands.any()) {
726 cl.operations.push_back(branchOp.getOperation());
730 for (
unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
735 branchOp.getSuccessorOperands(succIdx);
737 for (
unsigned operandIdx = 0; operandIdx < successorOperands.
size();
739 operandValues.push_back(successorOperands[operandIdx]);
743 BitVector successorNonLive =
744 markLives(operandValues, nonLiveSet, la).flip();
745 collectNonLiveValues(nonLiveSet, successorBlock->
getArguments(),
749 cl.blocks.push_back({successorBlock, successorNonLive});
750 cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
756static void cleanUpDeadVals(RDVFinalCleanupList &list) {
757 LDBG() <<
"Starting cleanup of dead values...";
761 LDBG() <<
"Cleaning up " << list.blocks.size() <<
" block argument lists";
762 for (
auto &
b : list.blocks) {
764 if (
b.b->getNumArguments() !=
b.nonLiveArgs.size())
767 os <<
"Erasing non-live arguments [";
768 llvm::interleaveComma(
b.nonLiveArgs.set_bits(), os);
769 os <<
"] from block #" <<
b.b->computeBlockNumber() <<
" in region #"
770 <<
b.b->getParent()->getRegionNumber() <<
" of operation "
776 for (
int i =
b.nonLiveArgs.size() - 1; i >= 0; --i) {
777 if (!
b.nonLiveArgs[i])
779 b.b->getArgument(i).dropAllUses();
780 b.b->eraseArgument(i);
785 LDBG() <<
"Cleaning up " << list.successorOperands.size()
786 <<
" successor operand lists";
787 for (
auto &op : list.successorOperands) {
789 op.branch.getSuccessorOperands(op.successorIndex);
791 if (successorOperands.
size() != op.nonLiveOperands.size())
794 os <<
"Erasing non-live successor operands [";
795 llvm::interleaveComma(op.nonLiveOperands.set_bits(), os);
796 os <<
"] from successor " << op.successorIndex <<
" of branch: "
801 for (
int i = successorOperands.
size() - 1; i >= 0; --i) {
802 if (!op.nonLiveOperands[i])
804 successorOperands.
erase(i);
809 LDBG() <<
"Cleaning up " << list.operations.size() <<
" operations";
811 LDBG() <<
"Erasing operation: "
817 ub::UnreachableOp::create(
b, op->
getLoc());
824 LDBG() <<
"Cleaning up " << list.functions.size() <<
" functions";
829 for (
auto &f : list.functions) {
830 LDBG() <<
"Cleaning up function: " << f.funcOp.getOperation()->getName()
831 <<
" (" << f.funcOp.getOperation() <<
")";
833 os <<
" Erasing non-live arguments [";
834 llvm::interleaveComma(f.nonLiveArgs.set_bits(), os);
836 os <<
" Erasing non-live return values [";
837 llvm::interleaveComma(f.nonLiveRets.set_bits(), os);
841 for (
auto deadIdx : f.nonLiveArgs.set_bits())
842 f.funcOp.getArgument(deadIdx).dropAllUses();
846 if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
848 if (f.nonLiveArgs.any())
849 erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs);
851 (
void)f.funcOp.eraseResults(f.nonLiveRets);
855 LDBG() <<
"Cleaning up " << list.operands.size() <<
" operand lists";
856 for (OperandsToCleanup &o : list.operands) {
860 bool handledAsCall =
false;
861 if (o.callee && isa<CallOpInterface>(o.op)) {
862 auto call = cast<CallOpInterface>(o.op);
863 auto it = erasedFuncArgs.find(o.callee);
864 if (it != erasedFuncArgs.end()) {
865 const BitVector &deadArgIdxs = it->second;
869 for (
unsigned argIdx : llvm::reverse(deadArgIdxs.set_bits()))
875 if (o.nonLive.any()) {
877 int operandOffset = call.getArgOperands().getBeginOperandIndex();
878 for (
int argIdx : deadArgIdxs.set_bits()) {
879 int operandNumber = operandOffset + argIdx;
880 if (operandNumber <
static_cast<int>(o.nonLive.size()))
881 o.nonLive.reset(operandNumber);
884 handledAsCall =
true;
891 if (!handledAsCall && o.nonLive.any()) {
893 os <<
"Erasing non-live operands [";
894 llvm::interleaveComma(o.nonLive.set_bits(), os);
895 os <<
"] from operation: "
904 LDBG() <<
"Cleaning up " << list.results.size() <<
" result lists";
905 for (
auto &r : list.results) {
907 os <<
"Erasing non-live results [";
908 llvm::interleaveComma(r.nonLive.set_bits(), os);
909 os <<
"] from operation: "
913 dropUsesAndEraseResults(r.op, r.nonLive);
915 LDBG() <<
"Finished cleanup of dead values";
919 void runOnOperation()
override;
923void RemoveDeadValues::runOnOperation() {
924 auto &la = getAnalysis<RunLivenessAnalysis>();
925 Operation *module = getOperation();
933 RDVFinalCleanupList finalCleanupList;
935 module->walk([&](Operation *op) {
936 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
937 processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
938 }
else if (
auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
939 processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
940 }
else if (
auto branchOp = dyn_cast<BranchOpInterface>(op)) {
941 processBranchOp(branchOp, la, deadVals, finalCleanupList);
942 }
else if (op->
hasTrait<::mlir::OpTrait::IsTerminator>()) {
945 }
else if (isa<CallOpInterface>(op)) {
949 processSimpleOp(op, la, deadVals, finalCleanupList);
953 cleanUpDeadVals(finalCleanupList);
957 return std::make_unique<RemoveDeadValues>();
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
Block represents an ordered list of Operations.
BlockArgListType getArguments()
Block * getSuccessor(unsigned i)
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class provides a mutable adaptor for a range of operands.
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
This class helps build Operations.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be terminators.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class models how operands are forwarded to block arguments in control flow.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
MutableOperandRange getMutableForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents a specific symbol use.
This class implements a range of SymbolRef uses.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void dropAllUses()
Drop all uses of this object from their respective owners.
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::unique_ptr< Pass > createRemoveDeadValuesPass()
Creates an optimization pass to remove dead values.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
This trait indicates that a terminator operation is "return-like".
This lattice represents, for a given value, whether or not it is "live".
Runs liveness analysis on the IR defined by op.
const Liveness * getLiveness(Value val)