54#include "llvm/ADT/STLExtras.h"
55#include "llvm/Support/Debug.h"
56#include "llvm/Support/DebugLog.h"
63#define DEBUG_TYPE "remove-dead-values"
66#define GEN_PASS_DEF_REMOVEDEADVALUESPASS
67#include "mlir/Transforms/Passes.h.inc"
// Per-function cleanup record: each set bit in `nonLiveArgs` / `nonLiveRets`
// is the index of a function argument / result that cleanUpDeadVals erases.
83struct FunctionToCleanUp {
84 FunctionOpInterface funcOp;
85 BitVector nonLiveArgs;
86 BitVector nonLiveRets;
// Results to drop from an operation. The fields are outside this excerpt;
// cleanUpDeadVals reads them as `op` and `nonLive` (a result mask).
89struct ResultsToCleanup {
// Operands to drop from an operation. cleanUpDeadVals reads the fields as
// `op` and `nonLive` (an operand mask, set bit = operand number).
94struct OperandsToCleanup {
// When non-null and `op` is a call, cleanUpDeadVals fixes up the call using
// the argument indices that were actually erased from this callee.
98 Operation *callee =
nullptr;
// When true, the dead operands are replaced with poison values (see
// createPoisonedValues) instead of going through the other removal path,
// which is outside this excerpt -- NOTE(review): confirm.
101 bool replaceWithPoison =
false;
// Block arguments to erase: set bit i in `nonLiveArgs` erases argument i of
// the block (held in field `b`, declared outside this excerpt).
104struct BlockArgsToCleanup {
106 BitVector nonLiveArgs;
// Operands that `branch` forwards to successor number `successorIndex` and
// that should be erased; set bit i erases forwarded operand i.
109struct SuccessorOperandsToCleanup {
110 BranchOpInterface branch;
111 unsigned successorIndex;
112 BitVector nonLiveOperands;
// Every IR change collected by the analysis phase (processFuncOp,
// processRegionBranchOp, ...). The changes are applied afterwards, in one go,
// by cleanUpDeadVals.
115struct RDVFinalCleanupList {
// Whole operations to erase.
116 SmallVector<Operation *> operations;
117 SmallVector<FunctionToCleanUp> functions;
118 SmallVector<OperandsToCleanup> operands;
119 SmallVector<ResultsToCleanup> results;
120 SmallVector<BlockArgsToCleanup> blocks;
121 SmallVector<SuccessorOperandsToCleanup> successorOperands;
130 for (
Value value : values) {
131 if (nonLiveSet.contains(value)) {
132 LDBG() <<
"Value " << value <<
" is already marked non-live (dead)";
138 LDBG() <<
"Value " << value
139 <<
" has no liveness info, conservatively considered live";
143 LDBG() <<
"Value " << value <<
" is live according to liveness analysis";
146 LDBG() <<
"Value " << value <<
" is dead according to liveness analysis";
155 BitVector lives(values.size(),
true);
157 for (
auto [
index, value] : llvm::enumerate(values)) {
158 if (nonLiveSet.contains(value)) {
160 LDBG() <<
"Value " << value
161 <<
" is already marked non-live (dead) at index " <<
index;
173 LDBG() <<
"Value " << value <<
" at index " <<
index
174 <<
" has no liveness info, conservatively considered live";
179 LDBG() <<
"Value " << value <<
" at index " <<
index
180 <<
" is dead according to liveness analysis";
182 LDBG() <<
"Value " << value <<
" at index " <<
index
183 <<
" is live according to liveness analysis";
194 const BitVector &nonLive) {
195 for (
auto [
index,
result] : llvm::enumerate(range)) {
198 nonLiveSet.insert(
result);
199 LDBG() <<
"Marking value " <<
result <<
" as non-live (dead) at index "
209 "expected the number of results in `op` and the size of `toErase` to "
211 for (
auto idx : toErase.set_bits())
231 RDVFinalCleanupList &cl) {
235 bool hasDeadOperand =
236 markLives(op->
getOperands(), nonLiveSet, la).flip().any();
237 if (hasDeadOperand) {
238 LDBG() <<
"Simple op has dead operands, so the op must be dead: "
241 assert(!hasLive(op->
getResults(), nonLiveSet, la) &&
242 "expected the op to have no live results");
243 cl.operations.push_back(op);
244 collectNonLiveValues(nonLiveSet, op->
getResults(),
250 LDBG() <<
"Simple op is not memory effect free or has live results, "
258 <<
"Simple op has all dead results and is memory effect free, scheduling "
261 cl.operations.push_back(op);
262 collectNonLiveValues(nonLiveSet, op->
getResults(),
// Records the dead-value cleanups for the function-like op `funcOp` into `cl`.
//
// Public or external functions are skipped (per the log message below).
// For other functions:
// - arguments that liveness reports as non-live are added to `nonLiveSet`;
// - call sites, return operands, call results, and the function's own
//   argument/result lists are queued in `cl`.
// Parts of the body are outside this excerpt.
276static void processFuncOp(FunctionOpInterface funcOp,
Operation *module,
278 RDVFinalCleanupList &cl) {
279 LDBG() <<
"Processing function op: "
282 if (funcOp.isPublic() || funcOp.isExternal()) {
283 LDBG() <<
"Function is public or external, skipping: "
284 << funcOp.getOperation()->getName();
289 return !isa<CallOpInterface>(use.
getUser());
// markLives yields a "live" mask; flipping it gives the non-live arguments.
298 BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
299 nonLiveArgs = nonLiveArgs.flip();
302 for (
auto [
index, arg] : llvm::enumerate(arguments))
303 if (arg && nonLiveArgs[
index])
304 nonLiveSet.insert(arg);
// Each call site is queued with an all-false operand mask and `funcOp` as its
// callee. The operands to drop are derived later, in cleanUpDeadVals, from the
// arguments that were actually erased from the callee.
315 cl.operands.push_back({callOp, BitVector(callOp->
getNumOperands(),
false),
316 funcOp.getOperation()});
// Start from "every result is dead" and intersect with the non-live results
// of each call user (presumably once per call user -- the loop header is
// outside this excerpt), so a result stays marked only if no caller uses it.
342 size_t numReturns = funcOp.getNumResults();
343 BitVector nonLiveRets(numReturns,
true);
346 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
347 BitVector liveCallRets = markLives(callOp->
getResults(), nonLiveSet, la);
348 nonLiveRets &= liveCallRets.flip();
// Queue the dead return positions on the terminators of the function's blocks
// (any filtering of terminators is outside this excerpt).
354 for (
Block &block : funcOp.getBlocks()) {
355 Operation *returnOp = block.getTerminator();
359 cl.operands.push_back({returnOp, nonLiveRets});
363 cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
370 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
371 cl.results.push_back({callOp, nonLiveRets});
372 collectNonLiveValues(nonLiveSet, callOp->
getResults(), nonLiveRets);
// Records the dead-value cleanups for a region-branch op into `cl`.
//
// The op itself is queued for erasure when none of its results is live, among
// other conditions that are outside this excerpt. Separately, the operands the
// op forwards to region successors are examined. Those judged dead are
// gathered into one operand mask per owning op and queued with no callee and
// `replaceWithPoison` set.
395static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
398 RDVFinalCleanupList &cl) {
399 LDBG() <<
"Processing region branch op: "
409 !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
410 cl.operations.push_back(regionBranchOp.getOperation());
434 regionBranchOp.getSuccessorOperandInputMapping(operandToSuccessorInputs);
437 for (
auto [opOperand, successorInputs] : operandToSuccessorInputs) {
// Lambdas cannot capture structured bindings before C++20, hence the
// init-capture `&opOperand = opOperand`.
440 auto markOperandDead = [&opOperand = opOperand, &deadOperandsPerOp]() {
// One operand mask per owning op, created on first use with all bits clear.
445 BitVector &deadOperands =
447 .try_emplace(opOperand->getOwner(),
448 opOperand->getOwner()->getNumOperands(),
false)
450 deadOperands.set(opOperand->getOperandNumber());
// An operand is examined both for its own liveness and for the liveness of
// the successor inputs it feeds. The exact actions taken are outside this
// excerpt.
454 if (!hasLive(opOperand->get(), nonLiveSet, la)) {
461 if (!hasLive(successorInputs, nonLiveSet, la))
// Queue each owner's dead operands: callee = nullptr, replaceWithPoison = true.
465 for (
auto [op, deadOperands] : deadOperandsPerOp) {
466 cl.operands.push_back(
467 {op, deadOperands,
nullptr,
true});
485 RDVFinalCleanupList &cl) {
486 LDBG() <<
"Processing branch op: " << *branchOp;
489 BitVector deadNonForwardedOperands =
490 markLives(branchOp->getOperands(), nonLiveSet, la).flip();
491 unsigned numSuccessors = branchOp->getNumSuccessors();
492 for (
unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
494 branchOp.getSuccessorOperands(succIdx);
497 deadNonForwardedOperands[opOperand.getOperandNumber()] =
false;
499 if (deadNonForwardedOperands.any()) {
500 cl.operations.push_back(branchOp.getOperation());
504 for (
unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
509 branchOp.getSuccessorOperands(succIdx);
511 for (
unsigned operandIdx = 0; operandIdx < successorOperands.
size();
513 operandValues.push_back(successorOperands[operandIdx]);
517 BitVector successorNonLive =
518 markLives(operandValues, nonLiveSet, la).flip();
519 collectNonLiveValues(nonLiveSet, successorBlock->
getArguments(),
523 cl.blocks.push_back({successorBlock, successorNonLive});
524 cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
532 return llvm::map_to_vector(values, [&](
Value value) {
535 return ub::PoisonOp::create(
b, value.
getLoc(), value.
getType()).getResult();
542 void notifyOperationErased(Operation *op)
override {
543 if (
auto poisonOp = dyn_cast<ub::PoisonOp>(op))
544 poisonOps.erase(poisonOp);
546 void notifyOperationInserted(Operation *op,
547 OpBuilder::InsertPoint previous)
override {
548 if (
auto poisonOp = dyn_cast<ub::PoisonOp>(op))
549 poisonOps.insert(poisonOp);
// Applies every change recorded in `list` to the IR, in this order:
//   1. block arguments,
//   2. successor operands,
//   3. function signatures,
//   4. operands,
//   5. results,
//   6. whole operations,
//   7. the poison ops created along the way.
// `listener` (a TrackingListener) keeps the set of ub::PoisonOp ops that are
// inserted or erased during the cleanup.
557static void cleanUpDeadVals(
MLIRContext *ctx, RDVFinalCleanupList &list) {
558 LDBG() <<
"Starting cleanup of dead values...";
563 TrackingListener listener;
// 1. Block arguments. The block's argument count is compared with the mask
// size first (the handling of a mismatch is outside this excerpt). Indices are
// then visited from last to first, so erasing argument i never shifts an index
// still to be visited. Uses are dropped before each argument is erased.
568 LDBG() <<
"Cleaning up " << list.blocks.size() <<
" block argument lists";
569 for (
auto &
b : list.blocks) {
571 if (
b.b->getNumArguments() !=
b.nonLiveArgs.size())
574 os <<
"Erasing non-live arguments [";
575 llvm::interleaveComma(
b.nonLiveArgs.set_bits(), os);
576 os <<
"] from block #" <<
b.b->computeBlockNumber() <<
" in region #"
577 <<
b.b->getParent()->getRegionNumber() <<
" of operation "
583 for (
int i =
b.nonLiveArgs.size() - 1; i >= 0; --i) {
584 if (!
b.nonLiveArgs[i])
586 b.b->getArgument(i).dropAllUses();
587 b.b->eraseArgument(i);
// 2. Successor operands: same size check and last-to-first erasure as above,
// applied to the operands forwarded to one successor of a branch.
592 LDBG() <<
"Cleaning up " << list.successorOperands.size()
593 <<
" successor operand lists";
594 for (
auto &op : list.successorOperands) {
596 op.branch.getSuccessorOperands(op.successorIndex);
598 if (successorOperands.
size() != op.nonLiveOperands.size())
601 os <<
"Erasing non-live successor operands [";
602 llvm::interleaveComma(op.nonLiveOperands.set_bits(), os);
603 os <<
"] from successor " << op.successorIndex <<
" of branch: "
608 for (
int i = successorOperands.
size() - 1; i >= 0; --i) {
609 if (!op.nonLiveOperands[i])
611 successorOperands.
erase(i);
// 3. Function signatures. Uses of dead arguments are dropped, then the
// arguments are erased. On success, a non-empty erased mask is remembered per
// function in `erasedFuncArgs` so step 4 can fix up call sites. A failure is
// logged. The status of result erasure is deliberately discarded.
616 LDBG() <<
"Cleaning up " << list.functions.size() <<
" functions";
621 for (
auto &f : list.functions) {
622 LDBG() <<
"Cleaning up function: " << f.funcOp.getName() <<
" ("
623 << f.funcOp.getOperation() <<
")";
625 os <<
" Erasing non-live arguments [";
626 llvm::interleaveComma(f.nonLiveArgs.set_bits(), os);
628 os <<
" Erasing non-live return values [";
629 llvm::interleaveComma(f.nonLiveRets.set_bits(), os);
633 for (
auto deadIdx : f.nonLiveArgs.set_bits())
634 f.funcOp.getArgument(deadIdx).dropAllUses();
638 if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
640 if (f.nonLiveArgs.any())
641 erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs);
643 LDBG() <<
"Failed to erase arguments for function: "
644 << f.funcOp.getName();
646 (
void)f.funcOp.eraseResults(f.nonLiveRets);
// 4. Operands. A call whose callee is in `erasedFuncArgs` is fixed up from the
// callee's erased-argument indices. Those indices are visited in reverse; the
// loop body is outside this excerpt. The matching bits are then cleared from
// `o.nonLive`, converting each argument index to an operand number via the
// offset of the call's argument operands. Such calls skip the generic path.
// On the generic path, dead operands are replaced with poison when
// `replaceWithPoison` is set; the other branch is outside this excerpt.
650 LDBG() <<
"Cleaning up " << list.operands.size() <<
" operand lists";
651 for (OperandsToCleanup &o : list.operands) {
655 bool handledAsCall =
false;
656 if (o.callee && isa<CallOpInterface>(o.op)) {
657 auto call = cast<CallOpInterface>(o.op);
658 auto it = erasedFuncArgs.find(o.callee);
659 if (it != erasedFuncArgs.end()) {
660 const BitVector &deadArgIdxs = it->second;
664 for (
unsigned argIdx : llvm::reverse(deadArgIdxs.set_bits()))
670 if (o.nonLive.any()) {
672 int operandOffset = call.getArgOperands().getBeginOperandIndex();
673 for (
int argIdx : deadArgIdxs.set_bits()) {
674 int operandNumber = operandOffset + argIdx;
675 if (operandNumber <
static_cast<int>(o.nonLive.size()))
676 o.nonLive.reset(operandNumber);
679 handledAsCall =
true;
686 if (!handledAsCall && o.nonLive.any()) {
688 os <<
"Erasing non-live operands [";
689 llvm::interleaveComma(o.nonLive.set_bits(), os);
690 os <<
"] from operation: "
694 if (o.replaceWithPoison) {
696 for (
auto deadIdx : o.nonLive.set_bits()) {
698 deadIdx, createPoisonedValues(rewriter, o.op->
getOperand(deadIdx))
// 5. Results: drop the uses of the masked results and erase them.
708 LDBG() <<
"Cleaning up " << list.results.size() <<
" result lists";
709 for (
auto &r : list.results) {
711 os <<
"Erasing non-live results [";
712 llvm::interleaveComma(r.nonLive.set_bits(), os);
713 os <<
"] from operation: "
717 dropUsesAndEraseResults(rewriter, r.op, r.nonLive);
// 6. Whole operations. In some cases a ub.unreachable op is created at the
// erased op's location (the condition is outside this excerpt).
721 LDBG() <<
"Cleaning up " << list.operations.size() <<
" operations";
723 LDBG() <<
"Erasing operation: "
729 ub::UnreachableOp::create(rewriter, op->
getLoc());
// Results that still have uses receive a poison replacement value.
742 for (
Value opResult : opResults) {
745 if (opResult.use_empty())
749 Value poisonedValue = createPoisonedValues(rewriter, opResult).front();
// 7. Poison ops tracked by the listener that ended up with no uses are
// handled here (presumably erased -- the action is outside this excerpt).
758 for (ub::PoisonOp poisonOp : listener.poisonOps) {
759 if (poisonOp.use_empty())
763 LDBG() <<
"Finished cleanup of dead values";
// The remove-dead-values pass. Its base class is generated by TableGen from
// Passes.h.inc under GEN_PASS_DEF_REMOVEDEADVALUESPASS.
766struct RemoveDeadValues
// Inherit the constructors of the generated pass base class.
768 using impl::RemoveDeadValuesPassBase<
769 RemoveDeadValues>::RemoveDeadValuesPassBase;
770 void runOnOperation()
override;
// Pass entry point. Using the liveness analysis, one walk over the module
// collects every cleanup into `finalCleanupList`. The cleanups are then
// applied, and the canonicalization patterns of the remaining region-branch
// ops are run. A failure of the greedy rewrite is reported as an error on
// the module.
774void RemoveDeadValues::runOnOperation() {
775 auto &la = getAnalysis<RunLivenessAnalysis>();
776 Operation *module = getOperation();
784 RDVFinalCleanupList finalCleanupList;
// Each op is handled by the first matching case of the chain below.
// Terminators and calls take branches whose bodies are outside this excerpt;
// other ops reach processSimpleOp. Cleanups are only recorded during the walk
// and applied after it, so the IR is not mutated while it is traversed.
786 module->walk([&](Operation *op) {
787 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
788 processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
789 }
else if (
auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
790 processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
791 }
else if (
auto branchOp = dyn_cast<BranchOpInterface>(op)) {
792 processBranchOp(branchOp, la, deadVals, finalCleanupList);
793 }
else if (op->
hasTrait<::mlir::OpTrait::IsTerminator>()) {
796 }
else if (isa<CallOpInterface>(op)) {
800 processSimpleOp(op, la, deadVals, finalCleanupList);
804 MLIRContext *context =
module->getContext();
805 cleanUpDeadVals(context, finalCleanupList);
// Gather the region-branch ops present after cleanup. The canonicalization
// patterns of each distinct registered op kind are added only once, as
// tracked by `populatedPatterns`.
811 SmallVector<Operation *> opsToCanonicalize;
812 module->walk([&](RegionBranchOpInterface regionBranchOp) {
813 opsToCanonicalize.push_back(regionBranchOp.getOperation());
816 RewritePatternSet owningPatterns(context);
818 for (Operation *op : opsToCanonicalize)
820 if (populatedPatterns.insert(*info).second)
821 info->getCanonicalizationPatterns(owningPatterns, context);
823 std::move(owningPatterns)))) {
824 module->emitError("greedy pattern rewrite failed to converge");
Block represents an ordered list of Operations.
BlockArgListType getArguments()
Block * getSuccessor(unsigned i)
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
This class provides the API for ops that are known to be terminators.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream modifier" to customize printing an operation with a stream using the operator<< overload.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
void dropAllUses()
Drop all uses of results of this operation.
void setOperand(unsigned idx, Value value)
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
std::optional< RegisteredOperationName > getRegisteredInfo()
If this operation has a registered operation description, return it.
unsigned getNumOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * eraseOpResults(Operation *op, const BitVector &eraseIndices)
Erase the specified results of the given operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class models how operands are forwarded to block arguments in control flow.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
MutableOperandRange getMutableForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents a specific symbol use.
Operation * getUser() const
Return the operation user of this symbol reference.
This class implements a range of SymbolRef uses.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
bool use_empty() const
Returns true if this value has no uses.
void dropAllUses()
Drop all uses of this object from their respective owners.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Include the generated interface declarations.
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
This trait indicates that a terminator operation is "return-like".
This lattice represents, for a given value, whether or not it is "live".
Runs liveness analysis on the IR defined by op.
const Liveness * getLiveness(Value val)