49#include "llvm/ADT/STLExtras.h"
50#include "llvm/Support/Debug.h"
51#include "llvm/Support/DebugLog.h"
58#define DEBUG_TYPE "remove-dead-values"
61#define GEN_PASS_DEF_REMOVEDEADVALUESPASS
62#include "mlir/Transforms/Passes.h.inc"
// A function whose non-live arguments and results are erased during the final
// cleanup phase (via `eraseArguments` / `eraseResults`).
78struct FunctionToCleanUp {
79 FunctionOpInterface funcOp;
// Bit i set => argument i is non-live; its uses are dropped and it is erased.
80 BitVector nonLiveArgs;
// Bit i set => result i is non-live and is erased.
81 BitVector nonLiveRets;
// Results of an operation to erase. Cleanup reads `op` and a `nonLive`
// bitvector (bit i set => result i is dropped) and passes them to
// dropUsesAndEraseResults. The field declarations are not visible in this excerpt.
84struct ResultsToCleanup {
// Operands of an operation that are non-live. Cleanup reads `op` and a
// `nonLive` bitvector indexed by operand number.
89struct OperandsToCleanup {
// For call operations: the callee function. Cleanup uses it to look up which
// callee arguments were actually erased, so the matching call operands can be
// removed as well.
93 Operation *callee =
nullptr;
// When true, each dead operand is replaced by a `ub.poison` value of the same
// type instead of being removed from the operand list.
96 bool replaceWithPoison =
false;
// A block whose non-live arguments are erased during cleanup. The block
// pointer field is read as `b` by cleanUpDeadVals.
99struct BlockArgsToCleanup {
// Bit i set => block argument i has its uses dropped and is erased.
101 BitVector nonLiveArgs;
// Forwarded operands of one successor of a branch that must be removed.
// They are removed in step with the erased arguments of the successor block.
104struct SuccessorOperandsToCleanup {
105 BranchOpInterface branch;
// Index of the successor whose forwarded operands are trimmed.
106 unsigned successorIndex;
// Bit i set => the i-th forwarded operand of this successor is erased.
107 BitVector nonLiveOperands;
// All IR modifications collected while walking the module. The process*
// functions fill it during the analysis walk, and cleanUpDeadVals applies it.
// Cleanup applies the lists in this order: blocks, successorOperands,
// functions, operands, results, operations.
110struct RDVFinalCleanupList {
// Operations scheduled for erasure.
111 SmallVector<Operation *> operations;
112 SmallVector<FunctionToCleanUp> functions;
113 SmallVector<OperandsToCleanup> operands;
114 SmallVector<ResultsToCleanup> results;
115 SmallVector<BlockArgsToCleanup> blocks;
116 SmallVector<SuccessorOperandsToCleanup> successorOperands;
// Presumably the body of `hasLive` (its signature is not visible here). It
// scans `values` and classifies each one:
//  - already in `nonLiveSet`      -> dead;
//  - no liveness info available   -> conservatively treated as live;
//  - otherwise                    -> the liveness analysis result decides.
125 for (
Value value : values) {
// Values recorded as dead earlier in the walk take precedence over the analysis.
126 if (nonLiveSet.contains(value)) {
127 LDBG() <<
"Value " << value <<
" is already marked non-live (dead)";
133 LDBG() <<
"Value " << value
134 <<
" has no liveness info, conservatively considered live";
138 LDBG() <<
"Value " << value <<
" is live according to liveness analysis";
141 LDBG() <<
"Value " << value <<
" is dead according to liveness analysis";
// Presumably the body of `markLives`. It builds a bitvector with one bit per
// value, where a set bit means "live". Callers call `.flip()` on the result to
// get a dead mask.
// Every bit starts live so that values without liveness info stay live.
150 BitVector lives(values.size(),
true);
152 for (
auto [
index, value] : llvm::enumerate(values)) {
// A value already recorded in `nonLiveSet` is dead regardless of the analysis.
153 if (nonLiveSet.contains(value)) {
155 LDBG() <<
"Value " << value
156 <<
" is already marked non-live (dead) at index " <<
index;
168 LDBG() <<
"Value " << value <<
" at index " <<
index
169 <<
" has no liveness info, conservatively considered live";
174 LDBG() <<
"Value " << value <<
" at index " <<
index
175 <<
" is dead according to liveness analysis";
177 LDBG() <<
"Value " << value <<
" at index " <<
index
178 <<
" is live according to liveness analysis";
// Tail of the `collectNonLiveValues` signature. It records values from `range`
// in `nonLiveSet`. The lines that presumably filter on `nonLive[index]` are
// not visible in this excerpt.
189 const BitVector &nonLive) {
190 for (
auto [
index,
result] : llvm::enumerate(range)) {
193 nonLiveSet.insert(
result);
194 LDBG() <<
"Marking value " <<
result <<
" as non-live (dead) at index "
// Fragment, presumably of `dropUsesAndEraseResults`. It asserts that the
// result count of `op` matches `toErase.size()`, then visits each result
// index scheduled for erasure.
204 "expected the number of results in `op` and the size of `toErase` to "
206 for (
auto idx : toErase.set_bits())
// Tail of the `processSimpleOp` signature. It handles ops that are not
// functions, branches, region branches, terminators, or calls.
226 RDVFinalCleanupList &cl) {
// An op that consumes a dead operand is itself dead: its results cannot be live.
230 bool hasDeadOperand =
231 markLives(op->
getOperands(), nonLiveSet, la).flip().any();
232 if (hasDeadOperand) {
233 LDBG() <<
"Simple op has dead operands, so the op must be dead: "
236 assert(!hasLive(op->
getResults(), nonLiveSet, la) &&
237 "expected the op to have no live results");
// Erase the op and record its results as dead for later queries.
238 cl.operations.push_back(op);
239 collectNonLiveValues(nonLiveSet, op->
getResults(),
// The op is kept when it has memory effects or any live result.
245 LDBG() <<
"Simple op is not memory effect free or has live results, "
253 <<
"Simple op has all dead results and is memory effect free, scheduling "
// All results are dead and the op is memory-effect free, so it is erased.
256 cl.operations.push_back(op);
257 collectNonLiveValues(nonLiveSet, op->
getResults(),
// Schedules removal of the non-live arguments and results of a function, and
// the matching call-site operands and results. Public and external functions
// are skipped, because their signatures may be relied on by unseen callers.
271static void processFuncOp(FunctionOpInterface funcOp,
Operation *module,
273 RDVFinalCleanupList &cl) {
274 LDBG() <<
"Processing function op: "
277 if (funcOp.isPublic() || funcOp.isExternal()) {
278 LDBG() <<
"Function is public or external, skipping: "
279 << funcOp.getOperation()->getName();
// Predicate over symbol uses: true for a use whose user is not call-like.
284 return !isa<CallOpInterface>(use.
getUser());
// Dead-argument mask: markLives returns a live mask, so flip it.
// NOTE(review): BitVector::flip() mutates in place and returns *this, so the
// assignment on the next line is a redundant self-assignment.
293 BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
294 nonLiveArgs = nonLiveArgs.flip();
// Record the dead arguments so later liveness queries treat them as dead.
297 for (
auto [
index, arg] : llvm::enumerate(arguments))
298 if (arg && nonLiveArgs[
index])
299 nonLiveSet.insert(arg);
// Call sites get an all-false operand mask plus the callee. The operands to
// drop are derived at cleanup time from the arguments actually erased.
310 cl.operands.push_back({callOp, BitVector(callOp->
getNumOperands(),
false),
311 funcOp.getOperation()});
// Every result starts as non-live. Each call user's live results are then
// removed from the mask (presumably one iteration per call site), so a result
// stays non-live only if it is dead at every call.
337 size_t numReturns = funcOp.getNumResults();
338 BitVector nonLiveRets(numReturns,
true);
341 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
342 BitVector liveCallRets = markLives(callOp->
getResults(), nonLiveSet, la);
343 nonLiveRets &= liveCallRets.flip();
// Drop the returned operands that correspond to non-live results from each
// block's terminator.
349 for (
Block &block : funcOp.getBlocks()) {
350 Operation *returnOp = block.getTerminator();
354 cl.operands.push_back({returnOp, nonLiveRets});
358 cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
// Erase the same results at each call user and record them as dead.
365 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
366 cl.results.push_back({callOp, nonLiveRets});
367 collectNonLiveValues(nonLiveSet, callOp->
getResults(), nonLiveRets);
// Handles ops with region control flow. The op is erased outright when (among
// conditions not visible here) none of its results are live. Otherwise its
// dead successor operands are replaced with poison values.
390static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
393 RDVFinalCleanupList &cl) {
394 LDBG() <<
"Processing region branch op: "
404 !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
405 cl.operations.push_back(regionBranchOp.getOperation());
// Map each successor operand to the successor inputs it may flow into.
429 regionBranchOp.getSuccessorOperandInputMapping(operandToSuccessorInputs);
432 for (
auto [opOperand, successorInputs] : operandToSuccessorInputs) {
// Structured bindings cannot be captured directly by a lambda before C++20,
// so `opOperand` is re-bound through an init-capture.
435 auto markOperandDead = [&opOperand = opOperand, &deadOperandsPerOp]() {
// One lazily created, all-false mask per owning op; set the operand's bit.
440 BitVector &deadOperands =
442 .try_emplace(opOperand->getOwner(),
443 opOperand->getOwner()->getNumOperands(),
false)
445 deadOperands.set(opOperand->getOperandNumber());
// Either the operand value itself is dead, or every input it feeds is dead.
449 if (!hasLive(opOperand->get(), nonLiveSet, la)) {
456 if (!hasLive(successorInputs, nonLiveSet, la))
// No callee (nullptr). replaceWithPoison = true keeps operand counts intact,
// which the region-branch op's successor mapping may depend on.
460 for (
auto [op, deadOperands] : deadOperandsPerOp) {
461 cl.operands.push_back(
462 {op, deadOperands,
nullptr,
true});
// Tail of the `processBranchOp` signature. It handles ops implementing
// BranchOpInterface.
480 RDVFinalCleanupList &cl) {
481 LDBG() <<
"Processing branch op: " << *branchOp;
// Start from the dead-operand mask, then clear every operand that is forwarded
// to a successor. What remains marks dead non-forwarded operands (e.g. a
// branch condition).
484 BitVector deadNonForwardedOperands =
485 markLives(branchOp->getOperands(), nonLiveSet, la).flip();
486 unsigned numSuccessors = branchOp->getNumSuccessors();
487 for (
unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
489 branchOp.getSuccessorOperands(succIdx);
492 deadNonForwardedOperands[opOperand.getOperandNumber()] =
false;
// A dead non-forwarded operand schedules the whole branch for erasure.
494 if (deadNonForwardedOperands.any()) {
495 cl.operations.push_back(branchOp.getOperation());
// Otherwise trim each successor: the same dead mask drives both the successor
// block's arguments and the operands forwarded to them, keeping them in sync.
499 for (
unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
504 branchOp.getSuccessorOperands(succIdx);
506 for (
unsigned operandIdx = 0; operandIdx < successorOperands.
size();
508 operandValues.push_back(successorOperands[operandIdx]);
512 BitVector successorNonLive =
513 markLives(operandValues, nonLiveSet, la).flip();
514 collectNonLiveValues(nonLiveSet, successorBlock->
getArguments(),
518 cl.blocks.push_back({successorBlock, successorNonLive});
519 cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
// Fragment, presumably of `createPoisonedValues`. It creates one `ub.poison`
// op per input value, with the same location and type, and returns the poison
// results in input order.
527 return llvm::map_to_vector(values, [&](
Value value) {
530 return ub::PoisonOp::create(
b, value.
getLoc(), value.
getType()).getResult();
// Listener hooks, presumably of `TrackingListener`. They keep `poisonOps` equal
// to the set of live `ub.poison` ops created during cleanup, so unused ones
// can be removed at the end.
// Forget erased poison ops so the set never holds dangling entries.
537 void notifyOperationErased(Operation *op)
override {
538 if (
auto poisonOp = dyn_cast<ub::PoisonOp>(op))
539 poisonOps.erase(poisonOp);
// Remember every poison op inserted through the rewriter.
541 void notifyOperationInserted(Operation *op,
542 OpBuilder::InsertPoint previous)
override {
543 if (
auto poisonOp = dyn_cast<ub::PoisonOp>(op))
544 poisonOps.insert(poisonOp);
// Applies every modification recorded in `list` to the IR. The order is:
// block args, successor operands, functions, operands, results, operations.
// Finally, poison ops created along the way that ended up unused are processed.
552static void cleanUpDeadVals(
MLIRContext *ctx, RDVFinalCleanupList &list) {
553 LDBG() <<
"Starting cleanup of dead values...";
// Tracks poison ops created or erased by the rewriter during cleanup.
558 TrackingListener listener;
// 1. Block arguments.
563 LDBG() <<
"Cleaning up " << list.blocks.size() <<
" block argument lists";
564 for (
auto &
b : list.blocks) {
// Guard against a mask that no longer matches the block's argument count.
566 if (
b.b->getNumArguments() !=
b.nonLiveArgs.size())
569 os <<
"Erasing non-live arguments [";
570 llvm::interleaveComma(
b.nonLiveArgs.set_bits(), os);
571 os <<
"] from block #" <<
b.b->computeBlockNumber() <<
" in region #"
572 <<
b.b->getParent()->getRegionNumber() <<
" of operation "
// Iterate from the back so erasing an argument does not shift the indices
// still to be visited.
578 for (
int i =
b.nonLiveArgs.size() - 1; i >= 0; --i) {
579 if (!
b.nonLiveArgs[i])
581 b.b->getArgument(i).dropAllUses();
582 b.b->eraseArgument(i);
// 2. Successor operands, kept in sync with the erased block arguments.
587 LDBG() <<
"Cleaning up " << list.successorOperands.size()
588 <<
" successor operand lists";
589 for (
auto &op : list.successorOperands) {
591 op.branch.getSuccessorOperands(op.successorIndex);
593 if (successorOperands.
size() != op.nonLiveOperands.size())
596 os <<
"Erasing non-live successor operands [";
597 llvm::interleaveComma(op.nonLiveOperands.set_bits(), os);
598 os <<
"] from successor " << op.successorIndex <<
" of branch: "
// Reverse iteration, for the same reason as with block arguments.
603 for (
int i = successorOperands.
size() - 1; i >= 0; --i) {
604 if (!op.nonLiveOperands[i])
606 successorOperands.
erase(i);
// 3. Function signatures.
611 LDBG() <<
"Cleaning up " << list.functions.size() <<
" functions";
616 for (
auto &f : list.functions) {
617 LDBG() <<
"Cleaning up function: " << f.funcOp.getName() <<
" ("
618 << f.funcOp.getOperation() <<
")";
620 os <<
" Erasing non-live arguments [";
621 llvm::interleaveComma(f.nonLiveArgs.set_bits(), os);
623 os <<
" Erasing non-live return values [";
624 llvm::interleaveComma(f.nonLiveRets.set_bits(), os);
// NOTE(review): uses of dead arguments are dropped before `eraseArguments`
// runs. If erasure fails below, those users are left without the operands.
// Confirm that the users are always erased later.
628 for (
auto deadIdx : f.nonLiveArgs.set_bits())
629 f.funcOp.getArgument(deadIdx).dropAllUses();
// Record the erased argument indices only when erasure succeeded, so call
// sites are trimmed to match the callee's actual signature.
633 if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
635 if (f.nonLiveArgs.any())
636 erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs);
638 LDBG() <<
"Failed to erase arguments for function: "
639 << f.funcOp.getName();
// Result erasure is attempted regardless; its LogicalResult is ignored.
641 (
void)f.funcOp.eraseResults(f.nonLiveRets);
// 4. Operand lists.
645 LDBG() <<
"Cleaning up " << list.operands.size() <<
" operand lists";
646 for (OperandsToCleanup &o : list.operands) {
// For calls, drop the operands matching the callee arguments that were
// actually erased. This happens only if the callee appears in erasedFuncArgs.
650 bool handledAsCall =
false;
651 if (o.callee && isa<CallOpInterface>(o.op)) {
652 auto call = cast<CallOpInterface>(o.op);
653 auto it = erasedFuncArgs.find(o.callee);
654 if (it != erasedFuncArgs.end()) {
655 const BitVector &deadArgIdxs = it->second;
// Reverse order keeps the remaining argument indices valid while erasing.
659 for (
unsigned argIdx : llvm::reverse(deadArgIdxs.set_bits()))
// Clear bits in `o.nonLive` that referred to argument operands already
// removed above, translating argument indices to operand numbers.
665 if (o.nonLive.any()) {
667 int operandOffset = call.getArgOperands().getBeginOperandIndex();
668 for (
int argIdx : deadArgIdxs.set_bits()) {
669 int operandNumber = operandOffset + argIdx;
670 if (operandNumber <
static_cast<int>(o.nonLive.size()))
671 o.nonLive.reset(operandNumber);
674 handledAsCall =
true;
681 if (!handledAsCall && o.nonLive.any()) {
683 os <<
"Erasing non-live operands [";
684 llvm::interleaveComma(o.nonLive.set_bits(), os);
685 os <<
"] from operation: "
// Replace each dead operand with a fresh poison value instead of removing it.
689 if (o.replaceWithPoison) {
691 for (
auto deadIdx : o.nonLive.set_bits()) {
693 deadIdx, createPoisonedValues(rewriter, o.op->
getOperand(deadIdx))
// 5. Result lists.
703 LDBG() <<
"Cleaning up " << list.results.size() <<
" result lists";
704 for (
auto &r : list.results) {
706 os <<
"Erasing non-live results [";
707 llvm::interleaveComma(r.nonLive.set_bits(), os);
708 os <<
"] from operation: "
712 dropUsesAndEraseResults(rewriter, r.op, r.nonLive);
// 6. Whole operations.
716 LDBG() <<
"Cleaning up " << list.operations.size() <<
" operations";
718 LDBG() <<
"Erasing operation: "
// Creates `ub.unreachable` at the erased op's location. The guarding condition
// is not visible here; presumably it applies to erased terminators.
724 ub::UnreachableOp::create(rewriter, op->
getLoc());
// Any remaining uses of the erased op's results are redirected to poison.
737 for (
Value opResult : opResults) {
740 if (opResult.use_empty())
744 Value poisonedValue = createPoisonedValues(rewriter, opResult).front();
// Poison ops created during cleanup that ended up unused (handling not visible
// here; presumably erased).
753 for (ub::PoisonOp poisonOp : listener.poisonOps) {
754 if (poisonOp.use_empty())
758 LDBG() <<
"Finished cleanup of dead values";
// The remove-dead-values pass. It is built on the generated
// RemoveDeadValuesPassBase and inherits its constructors.
761struct RemoveDeadValues
763 using impl::RemoveDeadValuesPassBase<
764 RemoveDeadValues>::RemoveDeadValuesPassBase;
765 void runOnOperation()
override;
// Pass driver. It runs liveness analysis and walks the IR to collect cleanup
// actions per op kind. It then applies the actions and canonicalizes the
// remaining region-branch ops.
769void RemoveDeadValues::runOnOperation() {
770 auto &la = getAnalysis<RunLivenessAnalysis>();
771 Operation *module = getOperation();
779 RDVFinalCleanupList finalCleanupList;
// Phase 1: analysis only. The IR is not mutated while walking.
781 module->walk([&](Operation *op) {
782 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
783 processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
784 }
else if (
auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
785 processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
786 }
else if (
auto branchOp = dyn_cast<BranchOpInterface>(op)) {
787 processBranchOp(branchOp, la, deadVals, finalCleanupList);
788 }
else if (op->
hasTrait<::mlir::OpTrait::IsTerminator>()) {
// Branch body not visible; presumably terminators are handled via their
// parent ops.
791 }
else if (isa<CallOpInterface>(op)) {
// Branch body not visible; call ops are updated by processFuncOp on their
// callee.
795 processSimpleOp(op, la, deadVals, finalCleanupList);
// Phase 2: apply all collected modifications.
799 MLIRContext *context =
module->getContext();
800 cleanUpDeadVals(context, finalCleanupList);
// Phase 3: canonicalize region-branch ops. Patterns are collected once per
// registered op kind (deduplicated via populatedPatterns).
806 SmallVector<Operation *> opsToCanonicalize;
807 module->walk([&](RegionBranchOpInterface regionBranchOp) {
808 opsToCanonicalize.push_back(regionBranchOp.getOperation());
811 RewritePatternSet owningPatterns(context);
813 for (Operation *op : opsToCanonicalize)
815 if (populatedPatterns.insert(*info).second)
816 info->getCanonicalizationPatterns(owningPatterns, context);
818 std::move(owningPatterns)))) {
819 module->emitError("greedy pattern rewrite failed to converge");
Block represents an ordered list of Operations.
BlockArgListType getArguments()
Block * getSuccessor(unsigned i)
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be terminators.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setOperand(unsigned idx, Value value)
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
std::optional< RegisteredOperationName > getRegisteredInfo()
If this operation has a registered operation description, return it.
unsigned getNumOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * eraseOpResults(Operation *op, const BitVector &eraseIndices)
Erase the specified results of the given operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class models how operands are forwarded to block arguments in control flow.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
MutableOperandRange getMutableForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents a specific symbol use.
Operation * getUser() const
Return the operation user of this symbol reference.
This class implements a range of SymbolRef uses.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
void dropAllUses()
Drop all uses of this object from their respective owners.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Include the generated interface declarations.
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
This trait indicates that a terminator operation is "return-like".
This lattice represents, for a given value, whether or not it is "live".
Runs liveness analysis on the IR defined by op.
const Liveness * getLiveness(Value val)