53 #include "llvm/ADT/STLExtras.h"
54 #include "llvm/Support/Debug.h"
55 #include "llvm/Support/DebugLog.h"
62 #define DEBUG_TYPE "remove-dead-values"
65 #define GEN_PASS_DEF_REMOVEDEADVALUES
66 #include "mlir/Transforms/Passes.h.inc"
82 struct FunctionToCleanUp {
83 FunctionOpInterface funcOp;
84 BitVector nonLiveArgs;
85 BitVector nonLiveRets;
88 struct OperationToCleanup {
95 struct BlockArgsToCleanup {
97 BitVector nonLiveArgs;
100 struct SuccessorOperandsToCleanup {
101 BranchOpInterface branch;
102 unsigned successorIndex;
103 BitVector nonLiveOperands;
106 struct RDVFinalCleanupList {
122 for (
Value value : values) {
123 if (nonLiveSet.contains(value)) {
124 LDBG() <<
"Value " << value <<
" is already marked non-live (dead)";
130 LDBG() <<
"Value " << value
131 <<
" has no liveness info, conservatively considered live";
135 LDBG() <<
"Value " << value <<
" is live according to liveness analysis";
138 LDBG() <<
"Value " << value <<
" is dead according to liveness analysis";
148 BitVector lives(values.size(),
true);
151 if (nonLiveSet.contains(value)) {
153 LDBG() <<
"Value " << value
154 <<
" is already marked non-live (dead) at index " << index;
166 LDBG() <<
"Value " << value <<
" at index " << index
167 <<
" has no liveness info, conservatively considered live";
172 LDBG() <<
"Value " << value <<
" at index " << index
173 <<
" is dead according to liveness analysis";
175 LDBG() <<
"Value " << value <<
" at index " << index
176 <<
" is live according to liveness analysis";
187 const BitVector &nonLive) {
191 nonLiveSet.insert(result);
192 LDBG() <<
"Marking value " << result <<
" as non-live (dead) at index "
199 static void dropUsesAndEraseResults(
Operation *op, BitVector toErase) {
201 "expected the number of results in `op` and the size of `toErase` to "
204 std::vector<Type> newResultTypes;
206 if (!toErase[result.getResultNumber()])
207 newResultTypes.push_back(result.getType());
209 builder.setInsertionPointAfter(op);
220 while (!region.empty())
221 region.front().moveBefore(temp);
225 unsigned indexOfNextNewCallOpResultToReplace = 0;
227 assert(result &&
"expected result to be non-null");
228 if (toErase[index]) {
229 result.dropAllUses();
231 result.replaceAllUsesWith(
232 newOp->
getResult(indexOfNextNewCallOpResultToReplace++));
242 for (
unsigned i = 0, e = operands.size(); i < e; i++)
243 opOperands.push_back(&values[i]);
262 RDVFinalCleanupList &cl) {
264 LDBG() <<
"Simple op is not memory effect free or has live results, "
271 <<
"Simple op has all dead results and is memory effect free, scheduling "
274 cl.operations.push_back(op);
275 collectNonLiveValues(nonLiveSet, op->
getResults(),
289 static void processFuncOp(FunctionOpInterface funcOp,
Operation *module,
291 RDVFinalCleanupList &cl) {
292 LDBG() <<
"Processing function op: "
294 if (funcOp.isPublic() || funcOp.isExternal()) {
295 LDBG() <<
"Function is public or external, skipping: "
296 << funcOp.getOperation()->getName();
302 BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
303 nonLiveArgs = nonLiveArgs.flip();
307 if (arg && nonLiveArgs[index]) {
308 cl.values.push_back(arg);
309 nonLiveSet.insert(arg);
318 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
323 cl.operands.push_back({callOp, BitVector(callOp->
getNumOperands(),
false),
324 funcOp.getOperation()});
350 size_t numReturns = funcOp.getNumResults();
351 BitVector nonLiveRets(numReturns,
true);
354 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
355 BitVector liveCallRets = markLives(callOp->
getResults(), nonLiveSet, la);
356 nonLiveRets &= liveCallRets.flip();
362 for (
Block &block : funcOp.getBlocks()) {
363 Operation *returnOp = block.getTerminator();
365 cl.operands.push_back({returnOp, nonLiveRets});
369 cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
376 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
377 cl.results.push_back({callOp, nonLiveRets});
378 collectNonLiveValues(nonLiveSet, callOp->
getResults(), nonLiveRets);
411 static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
414 RDVFinalCleanupList &cl) {
415 LDBG() <<
"Processing region branch op: "
418 auto markLiveResults = [&](BitVector &liveResults) {
419 liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
424 for (
Region ®ion : regionBranchOp->getRegions()) {
428 BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
429 liveArgs[®ion] = regionLiveArgs;
437 regionBranchOp.getSuccessorRegions(point, successors);
447 terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
448 .getSuccessorOperands(successor)
449 : regionBranchOp.getEntrySuccessorOperands(successor);
456 auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
457 nonForwardedOperands.resize(regionBranchOp->getNumOperands(),
true);
460 for (
OpOperand *opOperand : getForwardedOpOperands(successor))
461 nonForwardedOperands.reset(opOperand->getOperandNumber());
467 auto markNonForwardedReturnValues =
469 for (
Region ®ion : regionBranchOp->getRegions()) {
473 Operation *terminator = region.front().getTerminator();
474 nonForwardedRets[terminator] =
478 cast<RegionBranchTerminatorOpInterface>(terminator)))) {
480 getForwardedOpOperands(successor, terminator))
481 nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
490 auto updateOperandsOrTerminatorOperandsToKeep =
491 [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
494 region ? region->front().getTerminator() :
nullptr;
498 cast<RegionBranchTerminatorOpInterface>(terminator))
503 for (
auto [opOperand, input] :
504 llvm::zip(getForwardedOpOperands(successor, terminator),
506 size_t operandNum = opOperand->getOperandNumber();
509 ? argsToKeep[successorRegion]
510 [cast<BlockArgument>(input).getArgNumber()]
511 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
512 valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
520 auto recomputeResultsAndArgsToKeep =
522 BitVector &operandsToKeep,
524 bool &resultsOrArgsToKeepChanged) {
525 resultsOrArgsToKeepChanged =
false;
531 for (
auto [opOperand, input] :
532 llvm::zip(getForwardedOpOperands(successor),
534 bool recomputeBasedOn =
535 operandsToKeep[opOperand->getOperandNumber()];
538 ? argsToKeep[successorRegion]
539 [cast<BlockArgument>(input).getArgNumber()]
540 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
541 if (!toRecompute && recomputeBasedOn)
542 resultsOrArgsToKeepChanged =
true;
543 if (successorRegion) {
544 argsToKeep[successorRegion][cast<BlockArgument>(input)
546 argsToKeep[successorRegion]
547 [cast<BlockArgument>(input).getArgNumber()] |
550 resultsToKeep[cast<OpResult>(input).getResultNumber()] =
551 resultsToKeep[cast<OpResult>(input).getResultNumber()] |
559 for (
Region ®ion : regionBranchOp->getRegions()) {
562 Operation *terminator = region.front().getTerminator();
565 cast<RegionBranchTerminatorOpInterface>(terminator)))) {
567 for (
auto [opOperand, input] :
568 llvm::zip(getForwardedOpOperands(successor, terminator),
570 bool recomputeBasedOn =
571 terminatorOperandsToKeep[region.back().getTerminator()]
572 [opOperand->getOperandNumber()];
575 ? argsToKeep[successorRegion]
576 [cast<BlockArgument>(input).getArgNumber()]
577 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
578 if (!toRecompute && recomputeBasedOn)
579 resultsOrArgsToKeepChanged =
true;
580 if (successorRegion) {
581 argsToKeep[successorRegion][cast<BlockArgument>(input)
583 argsToKeep[successorRegion]
584 [cast<BlockArgument>(input).getArgNumber()] |
587 resultsToKeep[cast<OpResult>(input).getResultNumber()] =
588 resultsToKeep[cast<OpResult>(input).getResultNumber()] |
598 auto markValuesToKeep =
600 BitVector &operandsToKeep,
602 bool resultsOrArgsToKeepChanged =
true;
605 while (resultsOrArgsToKeepChanged) {
607 updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
608 resultsToKeep, argsToKeep);
611 for (
Region ®ion : regionBranchOp->getRegions()) {
614 updateOperandsOrTerminatorOperandsToKeep(
615 terminatorOperandsToKeep[region.back().getTerminator()],
616 resultsToKeep, argsToKeep, ®ion);
620 recomputeResultsAndArgsToKeep(
621 resultsToKeep, argsToKeep, operandsToKeep,
622 terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
633 !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
634 cl.operations.push_back(regionBranchOp.getOperation());
643 BitVector resultsToKeep;
648 BitVector operandsToKeep;
656 markLiveResults(resultsToKeep);
659 markLiveArgs(argsToKeep);
663 markNonForwardedOperands(operandsToKeep);
666 markNonForwardedReturnValues(terminatorOperandsToKeep);
670 markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
671 terminatorOperandsToKeep);
674 cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
677 for (
Region ®ion : regionBranchOp->getRegions()) {
680 BitVector argsToRemove = argsToKeep[®ion].flip();
681 cl.blocks.push_back({®ion.front(), argsToRemove});
682 collectNonLiveValues(nonLiveSet, region.front().getArguments(),
687 for (
Region ®ion : regionBranchOp->getRegions()) {
690 Operation *terminator = region.front().getTerminator();
691 cl.operands.push_back(
692 {terminator, terminatorOperandsToKeep[terminator].flip()});
696 BitVector resultsToRemove = resultsToKeep.flip();
697 collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
699 cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
712 RDVFinalCleanupList &cl) {
713 LDBG() <<
"Processing branch op: " << *branchOp;
714 unsigned numSuccessors = branchOp->getNumSuccessors();
716 for (
unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
721 branchOp.getSuccessorOperands(succIdx);
723 for (
unsigned operandIdx = 0; operandIdx < successorOperands.
size();
725 operandValues.push_back(successorOperands[operandIdx]);
729 BitVector successorNonLive =
730 markLives(operandValues, nonLiveSet, la).flip();
731 collectNonLiveValues(nonLiveSet, successorBlock->
getArguments(),
735 cl.blocks.push_back({successorBlock, successorNonLive});
736 cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
742 static void cleanUpDeadVals(RDVFinalCleanupList &list) {
743 LDBG() <<
"Starting cleanup of dead values...";
746 LDBG() <<
"Cleaning up " << list.operations.size() <<
" operations";
747 for (
auto &op : list.operations) {
748 LDBG() <<
"Erasing operation: "
755 LDBG() <<
"Cleaning up " << list.values.size() <<
" values";
756 for (
auto &v : list.values) {
757 LDBG() <<
"Dropping all uses of value: " << v;
762 LDBG() <<
"Cleaning up " << list.functions.size() <<
" functions";
767 for (
auto &f : list.functions) {
768 LDBG() <<
"Cleaning up function: " << f.funcOp.getOperation()->getName();
769 LDBG() <<
" Erasing " << f.nonLiveArgs.count() <<
" non-live arguments";
770 LDBG() <<
" Erasing " << f.nonLiveRets.count()
771 <<
" non-live return values";
775 if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
777 if (f.nonLiveArgs.any())
778 erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs);
780 (void)f.funcOp.eraseResults(f.nonLiveRets);
784 LDBG() <<
"Cleaning up " << list.operands.size() <<
" operand lists";
785 for (OperationToCleanup &o : list.operands) {
789 bool handledAsCall =
false;
790 if (o.callee && isa<CallOpInterface>(o.op)) {
791 auto call = cast<CallOpInterface>(o.op);
792 auto it = erasedFuncArgs.find(o.callee);
793 if (it != erasedFuncArgs.end()) {
794 const BitVector &deadArgIdxs = it->second;
798 for (
unsigned argIdx : llvm::reverse(deadArgIdxs.set_bits()))
804 if (o.nonLive.any()) {
806 int operandOffset = call.getArgOperands().getBeginOperandIndex();
807 for (
int argIdx : deadArgIdxs.set_bits()) {
808 int operandNumber = operandOffset + argIdx;
809 if (operandNumber <
static_cast<int>(o.nonLive.size()))
810 o.nonLive.reset(operandNumber);
813 handledAsCall =
true;
820 if (!handledAsCall && o.nonLive.any()) {
821 o.op->eraseOperands(o.nonLive);
826 LDBG() <<
"Cleaning up " << list.results.size() <<
" result lists";
827 for (
auto &r : list.results) {
828 LDBG() <<
"Erasing " << r.nonLive.count()
829 <<
" non-live results from operation: "
831 dropUsesAndEraseResults(r.op, r.nonLive);
835 LDBG() <<
"Cleaning up " << list.blocks.size() <<
" block argument lists";
836 for (
auto &b : list.blocks) {
838 if (b.b->getNumArguments() != b.nonLiveArgs.size())
840 LDBG() <<
"Erasing " << b.nonLiveArgs.count()
841 <<
" non-live arguments from block: " << b.b;
843 for (
int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
844 if (!b.nonLiveArgs[i])
846 LDBG() <<
" Erasing block argument " << i <<
": " << b.b->getArgument(i);
847 b.b->getArgument(i).dropAllUses();
848 b.b->eraseArgument(i);
853 LDBG() <<
"Cleaning up " << list.successorOperands.size()
854 <<
" successor operand lists";
855 for (
auto &op : list.successorOperands) {
857 op.branch.getSuccessorOperands(op.successorIndex);
859 if (successorOperands.
size() != op.nonLiveOperands.size())
861 LDBG() <<
"Erasing " << op.nonLiveOperands.count()
862 <<
" non-live successor operands from successor "
863 << op.successorIndex <<
" of branch: "
866 for (
int i = successorOperands.
size() - 1; i >= 0; --i) {
867 if (!op.nonLiveOperands[i])
869 LDBG() <<
" Erasing successor operand " << i <<
": "
870 << successorOperands[i];
871 successorOperands.
erase(i);
875 LDBG() <<
"Finished cleanup of dead values";
878 struct RemoveDeadValues :
public impl::RemoveDeadValuesBase<RemoveDeadValues> {
879 void runOnOperation()
override;
883 void RemoveDeadValues::runOnOperation() {
884 auto &la = getAnalysis<RunLivenessAnalysis>();
893 RDVFinalCleanupList finalCleanupList;
896 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
897 processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
898 }
else if (
auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
899 processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
900 }
else if (
auto branchOp = dyn_cast<BranchOpInterface>(op)) {
901 processBranchOp(branchOp, la, deadVals, finalCleanupList);
905 }
else if (isa<CallOpInterface>(op)) {
909 processSimpleOp(op, la, deadVals, finalCleanupList);
913 cleanUpDeadVals(finalCleanupList);
917 return std::make_unique<RemoveDeadValues>();
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
Block represents an ordered list of Operations.
void erase()
Unlink this Block from its parent region and delete it.
BlockArgListType getArguments()
Block * getSuccessor(unsigned i)
This class provides a mutable adaptor for a range of operands.
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
This class helps build Operations.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be terminators.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
ValueRange getSuccessorInputs() const
Return the inputs to the successor that are remapped by the exit values of the current region.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
This class models how operands are forwarded to block arguments in control flow.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents a specific symbol use.
This class implements a range of SymbolRef uses.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::unique_ptr< Pass > createRemoveDeadValuesPass()
Creates an optimization pass to remove dead values.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
This lattice represents, for a given value, whether or not it is "live".
Runs liveness analysis on the IR defined by op.
const Liveness * getLiveness(Value val)