54#include "llvm/ADT/STLExtras.h"
55#include "llvm/Support/Debug.h"
56#include "llvm/Support/DebugLog.h"
63#define DEBUG_TYPE "remove-dead-values"
66#define GEN_PASS_DEF_REMOVEDEADVALUESPASS
67#include "mlir/Transforms/Passes.h.inc"
83struct FunctionToCleanUp {
84 FunctionOpInterface funcOp;
85 BitVector nonLiveArgs;
86 BitVector nonLiveRets;
89struct ResultsToCleanup {
94struct OperandsToCleanup {
98 Operation *callee =
nullptr;
101 bool replaceWithPoison =
false;
104struct BlockArgsToCleanup {
106 BitVector nonLiveArgs;
109struct SuccessorOperandsToCleanup {
110 BranchOpInterface branch;
111 unsigned successorIndex;
112 BitVector nonLiveOperands;
115struct RDVFinalCleanupList {
116 SmallVector<Operation *> operations;
117 SmallVector<FunctionToCleanUp> functions;
118 SmallVector<OperandsToCleanup> operands;
119 SmallVector<ResultsToCleanup> results;
120 SmallVector<BlockArgsToCleanup> blocks;
121 SmallVector<SuccessorOperandsToCleanup> successorOperands;
130 for (
Value value : values) {
131 if (nonLiveSet.contains(value)) {
132 LDBG() <<
"Value " << value <<
" is already marked non-live (dead)";
138 LDBG() <<
"Value " << value
139 <<
" has no liveness info, conservatively considered live";
143 LDBG() <<
"Value " << value <<
" is live according to liveness analysis";
146 LDBG() <<
"Value " << value <<
" is dead according to liveness analysis";
156 BitVector lives(values.size(),
true);
158 for (
auto [
index, value] : llvm::enumerate(values)) {
159 if (nonLiveSet.contains(value)) {
161 LDBG() <<
"Value " << value
162 <<
" is already marked non-live (dead) at index " <<
index;
174 LDBG() <<
"Value " << value <<
" at index " <<
index
175 <<
" has no liveness info, conservatively considered live";
180 LDBG() <<
"Value " << value <<
" at index " <<
index
181 <<
" is dead according to liveness analysis";
183 LDBG() <<
"Value " << value <<
" at index " <<
index
184 <<
" is live according to liveness analysis";
195 const BitVector &nonLive) {
196 for (
auto [
index,
result] : llvm::enumerate(range)) {
199 nonLiveSet.insert(
result);
200 LDBG() <<
"Marking value " <<
result <<
" as non-live (dead) at index "
210 "expected the number of results in `op` and the size of `toErase` to "
212 for (
auto idx : toErase.set_bits())
232 RDVFinalCleanupList &cl) {
236 bool hasDeadOperand =
237 markLives(op->
getOperands(), nonLiveSet, la).flip().any();
238 if (hasDeadOperand) {
239 LDBG() <<
"Simple op has dead operands, so the op must be dead: "
242 assert(!hasLive(op->
getResults(), nonLiveSet, la) &&
243 "expected the op to have no live results");
244 cl.operations.push_back(op);
245 collectNonLiveValues(nonLiveSet, op->
getResults(),
251 LDBG() <<
"Simple op is not memory effect free or has live results, "
259 <<
"Simple op has all dead results and is memory effect free, scheduling "
262 cl.operations.push_back(op);
263 collectNonLiveValues(nonLiveSet, op->
getResults(),
277static void processFuncOp(FunctionOpInterface funcOp,
Operation *module,
279 RDVFinalCleanupList &cl) {
280 LDBG() <<
"Processing function op: "
283 if (funcOp.isPublic() || funcOp.isExternal()) {
284 LDBG() <<
"Function is public or external, skipping: "
285 << funcOp.getOperation()->getName();
290 return !isa<CallOpInterface>(use.
getUser());
299 BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
300 nonLiveArgs = nonLiveArgs.flip();
303 for (
auto [
index, arg] : llvm::enumerate(arguments))
304 if (arg && nonLiveArgs[
index])
305 nonLiveSet.insert(arg);
316 cl.operands.push_back({callOp, BitVector(callOp->
getNumOperands(),
false),
317 funcOp.getOperation()});
343 size_t numReturns = funcOp.getNumResults();
344 BitVector nonLiveRets(numReturns,
true);
347 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
348 BitVector liveCallRets = markLives(callOp->
getResults(), nonLiveSet, la);
349 nonLiveRets &= liveCallRets.flip();
355 for (
Block &block : funcOp.getBlocks()) {
356 Operation *returnOp = block.getTerminator();
360 cl.operands.push_back({returnOp, nonLiveRets});
364 cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
371 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
372 cl.results.push_back({callOp, nonLiveRets});
373 collectNonLiveValues(nonLiveSet, callOp->
getResults(), nonLiveRets);
396static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
399 RDVFinalCleanupList &cl) {
400 LDBG() <<
"Processing region branch op: "
410 !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
411 cl.operations.push_back(regionBranchOp.getOperation());
435 regionBranchOp.getSuccessorOperandInputMapping(operandToSuccessorInputs);
438 for (
auto [opOperand, successorInputs] : operandToSuccessorInputs) {
441 auto markOperandDead = [&opOperand = opOperand, &deadOperandsPerOp]() {
446 BitVector &deadOperands =
448 .try_emplace(opOperand->getOwner(),
449 opOperand->getOwner()->getNumOperands(),
false)
451 deadOperands.set(opOperand->getOperandNumber());
455 if (!hasLive(opOperand->get(), nonLiveSet, la)) {
462 if (!hasLive(successorInputs, nonLiveSet, la))
466 for (
auto [op, deadOperands] : deadOperandsPerOp) {
467 cl.operands.push_back(
468 {op, deadOperands,
nullptr,
true});
486 RDVFinalCleanupList &cl) {
487 LDBG() <<
"Processing branch op: " << *branchOp;
490 BitVector deadNonForwardedOperands =
491 markLives(branchOp->getOperands(), nonLiveSet, la).flip();
492 unsigned numSuccessors = branchOp->getNumSuccessors();
493 for (
unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
495 branchOp.getSuccessorOperands(succIdx);
498 deadNonForwardedOperands[opOperand.getOperandNumber()] =
false;
500 if (deadNonForwardedOperands.any()) {
501 cl.operations.push_back(branchOp.getOperation());
505 for (
unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
510 branchOp.getSuccessorOperands(succIdx);
512 for (
unsigned operandIdx = 0; operandIdx < successorOperands.
size();
514 operandValues.push_back(successorOperands[operandIdx]);
518 BitVector successorNonLive =
519 markLives(operandValues, nonLiveSet, la).flip();
520 collectNonLiveValues(nonLiveSet, successorBlock->
getArguments(),
524 cl.blocks.push_back({successorBlock, successorNonLive});
525 cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
533 return llvm::map_to_vector(values, [&](
Value value) {
536 return ub::PoisonOp::create(
b, value.
getLoc(), value.
getType()).getResult();
543 void notifyOperationErased(Operation *op)
override {
544 if (
auto poisonOp = dyn_cast<ub::PoisonOp>(op))
545 poisonOps.erase(poisonOp);
547 void notifyOperationInserted(Operation *op,
548 OpBuilder::InsertPoint previous)
override {
549 if (
auto poisonOp = dyn_cast<ub::PoisonOp>(op))
550 poisonOps.insert(poisonOp);
558static void cleanUpDeadVals(
MLIRContext *ctx, RDVFinalCleanupList &list) {
559 LDBG() <<
"Starting cleanup of dead values...";
564 TrackingListener listener;
569 LDBG() <<
"Cleaning up " << list.blocks.size() <<
" block argument lists";
570 for (
auto &
b : list.blocks) {
572 if (
b.b->getNumArguments() !=
b.nonLiveArgs.size())
575 os <<
"Erasing non-live arguments [";
576 llvm::interleaveComma(
b.nonLiveArgs.set_bits(), os);
577 os <<
"] from block #" <<
b.b->computeBlockNumber() <<
" in region #"
578 <<
b.b->getParent()->getRegionNumber() <<
" of operation "
584 for (
int i =
b.nonLiveArgs.size() - 1; i >= 0; --i) {
585 if (!
b.nonLiveArgs[i])
587 b.b->getArgument(i).dropAllUses();
588 b.b->eraseArgument(i);
593 LDBG() <<
"Cleaning up " << list.successorOperands.size()
594 <<
" successor operand lists";
595 for (
auto &op : list.successorOperands) {
597 op.branch.getSuccessorOperands(op.successorIndex);
599 if (successorOperands.
size() != op.nonLiveOperands.size())
602 os <<
"Erasing non-live successor operands [";
603 llvm::interleaveComma(op.nonLiveOperands.set_bits(), os);
604 os <<
"] from successor " << op.successorIndex <<
" of branch: "
609 for (
int i = successorOperands.
size() - 1; i >= 0; --i) {
610 if (!op.nonLiveOperands[i])
612 successorOperands.
erase(i);
617 LDBG() <<
"Cleaning up " << list.functions.size() <<
" functions";
622 for (
auto &f : list.functions) {
623 LDBG() <<
"Cleaning up function: " << f.funcOp.getName() <<
" ("
624 << f.funcOp.getOperation() <<
")";
626 os <<
" Erasing non-live arguments [";
627 llvm::interleaveComma(f.nonLiveArgs.set_bits(), os);
629 os <<
" Erasing non-live return values [";
630 llvm::interleaveComma(f.nonLiveRets.set_bits(), os);
634 for (
auto deadIdx : f.nonLiveArgs.set_bits())
635 f.funcOp.getArgument(deadIdx).dropAllUses();
639 if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
641 if (f.nonLiveArgs.any())
642 erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs);
644 LDBG() <<
"Failed to erase arguments for function: "
645 << f.funcOp.getName();
647 (
void)f.funcOp.eraseResults(f.nonLiveRets);
651 LDBG() <<
"Cleaning up " << list.operands.size() <<
" operand lists";
652 for (OperandsToCleanup &o : list.operands) {
656 bool handledAsCall =
false;
657 if (o.callee && isa<CallOpInterface>(o.op)) {
658 auto call = cast<CallOpInterface>(o.op);
659 auto it = erasedFuncArgs.find(o.callee);
660 if (it != erasedFuncArgs.end()) {
661 const BitVector &deadArgIdxs = it->second;
665 for (
unsigned argIdx : llvm::reverse(deadArgIdxs.set_bits()))
671 if (o.nonLive.any()) {
673 int operandOffset = call.getArgOperands().getBeginOperandIndex();
674 for (
int argIdx : deadArgIdxs.set_bits()) {
675 int operandNumber = operandOffset + argIdx;
676 if (operandNumber <
static_cast<int>(o.nonLive.size()))
677 o.nonLive.reset(operandNumber);
680 handledAsCall =
true;
687 if (!handledAsCall && o.nonLive.any()) {
689 os <<
"Erasing non-live operands [";
690 llvm::interleaveComma(o.nonLive.set_bits(), os);
691 os <<
"] from operation: "
695 if (o.replaceWithPoison) {
697 for (
auto deadIdx : o.nonLive.set_bits()) {
699 deadIdx, createPoisonedValues(rewriter, o.op->
getOperand(deadIdx))
709 LDBG() <<
"Cleaning up " << list.results.size() <<
" result lists";
710 for (
auto &r : list.results) {
712 os <<
"Erasing non-live results [";
713 llvm::interleaveComma(r.nonLive.set_bits(), os);
714 os <<
"] from operation: "
718 dropUsesAndEraseResults(rewriter, r.op, r.nonLive);
722 LDBG() <<
"Cleaning up " << list.operations.size() <<
" operations";
724 LDBG() <<
"Erasing operation: "
730 ub::UnreachableOp::create(rewriter, op->
getLoc());
743 for (
Value opResult : opResults) {
746 if (opResult.use_empty())
750 Value poisonedValue = createPoisonedValues(rewriter, opResult).front();
759 for (ub::PoisonOp poisonOp : listener.poisonOps) {
760 if (poisonOp.use_empty())
764 LDBG() <<
"Finished cleanup of dead values";
767struct RemoveDeadValues
769 using impl::RemoveDeadValuesPassBase<
770 RemoveDeadValues>::RemoveDeadValuesPassBase;
771 void runOnOperation()
override;
775void RemoveDeadValues::runOnOperation() {
776 auto &la = getAnalysis<RunLivenessAnalysis>();
777 Operation *module = getOperation();
785 RDVFinalCleanupList finalCleanupList;
787 module->walk([&](Operation *op) {
788 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
789 processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
790 }
else if (
auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
791 processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
792 }
else if (
auto branchOp = dyn_cast<BranchOpInterface>(op)) {
793 processBranchOp(branchOp, la, deadVals, finalCleanupList);
794 }
else if (op->
hasTrait<::mlir::OpTrait::IsTerminator>()) {
797 }
else if (isa<CallOpInterface>(op)) {
801 processSimpleOp(op, la, deadVals, finalCleanupList);
805 MLIRContext *context =
module->getContext();
806 cleanUpDeadVals(context, finalCleanupList);
812 SmallVector<Operation *> opsToCanonicalize;
813 module->walk([&](RegionBranchOpInterface regionBranchOp) {
814 opsToCanonicalize.push_back(regionBranchOp.getOperation());
817 RewritePatternSet owningPatterns(context);
819 for (Operation *op : opsToCanonicalize)
821 if (populatedPatterns.insert(*info).second)
822 info->getCanonicalizationPatterns(owningPatterns, context);
824 std::move(owningPatterns)))) {
825 module->emitError("greedy pattern rewrite failed to converge");
Block represents an ordered list of Operations.
BlockArgListType getArguments()
Block * getSuccessor(unsigned i)
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be terminators.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setOperand(unsigned idx, Value value)
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
std::optional< RegisteredOperationName > getRegisteredInfo()
If this operation has a registered operation description, return it.
unsigned getNumOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * eraseOpResults(Operation *op, const BitVector &eraseIndices)
Erase the specified results of the given operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class models how operands are forwarded to block arguments in control flow.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
MutableOperandRange getMutableForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents a specific symbol use.
Operation * getUser() const
Return the operation user of this symbol reference.
This class implements a range of SymbolRef uses.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
void dropAllUses()
Drop all uses of this object from their respective owners.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Include the generated interface declarations.
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
This trait indicates that a terminator operation is "return-like".
This lattice represents, for a given value, whether or not it is "live".
Runs liveness analysis on the IR defined by op.
const Liveness * getLiveness(Value val)