54 #include "llvm/ADT/STLExtras.h"
62 #define GEN_PASS_DEF_REMOVEDEADVALUES
63 #include "mlir/Transforms/Passes.h.inc"
79 struct FunctionToCleanUp {
80 FunctionOpInterface funcOp;
81 BitVector nonLiveArgs;
82 BitVector nonLiveRets;
85 struct OperationToCleanup {
90 struct BlockArgsToCleanup {
92 BitVector nonLiveArgs;
95 struct SuccessorOperandsToCleanup {
96 BranchOpInterface branch;
97 unsigned successorIndex;
98 BitVector nonLiveOperands;
101 struct RDVFinalCleanupList {
117 for (
Value value : values) {
118 if (nonLiveSet.contains(value))
122 if (!liveness || liveness->
isLive)
132 BitVector lives(values.size(),
true);
135 if (nonLiveSet.contains(value)) {
147 if (liveness && !liveness->
isLive)
158 const BitVector &nonLive) {
162 nonLiveSet.insert(result);
168 static void dropUsesAndEraseResults(
Operation *op, BitVector toErase) {
170 "expected the number of results in `op` and the size of `toErase` to "
173 std::vector<Type> newResultTypes;
175 if (!toErase[result.getResultNumber()])
176 newResultTypes.push_back(result.getType());
178 builder.setInsertionPointAfter(op);
189 while (!region.empty())
190 region.front().moveBefore(temp);
194 unsigned indexOfNextNewCallOpResultToReplace = 0;
196 assert(result &&
"expected result to be non-null");
197 if (toErase[index]) {
198 result.dropAllUses();
200 result.replaceAllUsesWith(
201 newOp->
getResult(indexOfNextNewCallOpResultToReplace++));
211 for (
unsigned i = 0, e = operands.size(); i < e; i++)
212 opOperands.push_back(&values[i]);
231 RDVFinalCleanupList &cl) {
235 cl.operations.push_back(op);
236 collectNonLiveValues(nonLiveSet, op->
getResults(),
250 static void processFuncOp(FunctionOpInterface funcOp,
Operation *module,
252 RDVFinalCleanupList &cl) {
253 if (funcOp.isPublic() || funcOp.isExternal())
258 BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
259 nonLiveArgs = nonLiveArgs.flip();
263 if (arg && nonLiveArgs[index]) {
264 cl.values.push_back(arg);
265 nonLiveSet.insert(arg);
272 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
278 for (
int index : nonLiveArgs.set_bits())
279 nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
280 cl.operands.push_back({callOp, nonLiveCallOperands});
306 Operation *lastReturnOp = funcOp.back().getTerminator();
308 BitVector nonLiveRets(numReturns,
true);
311 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
312 BitVector liveCallRets = markLives(callOp->
getResults(), nonLiveSet, la);
313 nonLiveRets &= liveCallRets.flip();
319 for (
Block &block : funcOp.getBlocks()) {
320 Operation *returnOp = block.getTerminator();
322 cl.operands.push_back({returnOp, nonLiveRets});
326 cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
331 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
332 cl.results.push_back({callOp, nonLiveRets});
333 collectNonLiveValues(nonLiveSet, callOp->
getResults(), nonLiveRets);
366 static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
369 RDVFinalCleanupList &cl) {
371 auto markLiveResults = [&](BitVector &liveResults) {
372 liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
377 for (
Region &region : regionBranchOp->getRegions()) {
381 BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
382 liveArgs[&region] = regionLiveArgs;
388 auto getSuccessors = [&](
Region *region =
nullptr) {
393 regionBranchOp.getSuccessorRegions(point, successors);
403 terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
404 .getSuccessorOperands(successor)
405 : regionBranchOp.getEntrySuccessorOperands(successor);
412 auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
413 nonForwardedOperands.resize(regionBranchOp->getNumOperands(),
true);
415 for (
OpOperand *opOperand : getForwardedOpOperands(successor))
416 nonForwardedOperands.reset(opOperand->getOperandNumber());
422 auto markNonForwardedReturnValues =
424 for (
Region &region : regionBranchOp->getRegions()) {
428 nonForwardedRets[terminator] =
432 getForwardedOpOperands(successor, terminator))
433 nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
442 auto updateOperandsOrTerminatorOperandsToKeep =
443 [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
446 region ? region->front().getTerminator() :
nullptr;
450 for (
auto [opOperand, input] :
451 llvm::zip(getForwardedOpOperands(successor, terminator),
453 size_t operandNum = opOperand->getOperandNumber();
456 ? argsToKeep[successorRegion]
457 [cast<BlockArgument>(input).getArgNumber()]
458 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
459 valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
467 auto recomputeResultsAndArgsToKeep =
469 BitVector &operandsToKeep,
471 bool &resultsOrArgsToKeepChanged) {
472 resultsOrArgsToKeepChanged =
false;
477 for (
auto [opOperand, input] :
478 llvm::zip(getForwardedOpOperands(successor),
480 bool recomputeBasedOn =
481 operandsToKeep[opOperand->getOperandNumber()];
484 ? argsToKeep[successorRegion]
485 [cast<BlockArgument>(input).getArgNumber()]
486 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
487 if (!toRecompute && recomputeBasedOn)
488 resultsOrArgsToKeepChanged =
true;
489 if (successorRegion) {
490 argsToKeep[successorRegion][cast<BlockArgument>(input)
492 argsToKeep[successorRegion]
493 [cast<BlockArgument>(input).getArgNumber()] |
496 resultsToKeep[cast<OpResult>(input).getResultNumber()] =
497 resultsToKeep[cast<OpResult>(input).getResultNumber()] |
505 for (
Region &region : regionBranchOp->getRegions()) {
508 Operation *terminator = region.front().getTerminator();
511 for (
auto [opOperand, input] :
512 llvm::zip(getForwardedOpOperands(successor, terminator),
514 bool recomputeBasedOn =
515 terminatorOperandsToKeep[region.back().getTerminator()]
516 [opOperand->getOperandNumber()];
519 ? argsToKeep[successorRegion]
520 [cast<BlockArgument>(input).getArgNumber()]
521 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
522 if (!toRecompute && recomputeBasedOn)
523 resultsOrArgsToKeepChanged =
true;
524 if (successorRegion) {
525 argsToKeep[successorRegion][cast<BlockArgument>(input)
527 argsToKeep[successorRegion]
528 [cast<BlockArgument>(input).getArgNumber()] |
531 resultsToKeep[cast<OpResult>(input).getResultNumber()] =
532 resultsToKeep[cast<OpResult>(input).getResultNumber()] |
542 auto markValuesToKeep =
544 BitVector &operandsToKeep,
546 bool resultsOrArgsToKeepChanged =
true;
549 while (resultsOrArgsToKeepChanged) {
551 updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
552 resultsToKeep, argsToKeep);
555 for (
Region &region : regionBranchOp->getRegions()) {
558 updateOperandsOrTerminatorOperandsToKeep(
559 terminatorOperandsToKeep[region.back().getTerminator()],
560 resultsToKeep, argsToKeep, &region);
564 recomputeResultsAndArgsToKeep(
565 resultsToKeep, argsToKeep, operandsToKeep,
566 terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
577 !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
578 cl.operations.push_back(regionBranchOp.getOperation());
587 BitVector resultsToKeep;
592 BitVector operandsToKeep;
600 markLiveResults(resultsToKeep);
603 markLiveArgs(argsToKeep);
607 markNonForwardedOperands(operandsToKeep);
610 markNonForwardedReturnValues(terminatorOperandsToKeep);
614 markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
615 terminatorOperandsToKeep);
618 cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
621 for (
Region &region : regionBranchOp->getRegions()) {
624 BitVector argsToRemove = argsToKeep[&region].flip();
625 cl.blocks.push_back({&region.front(), argsToRemove});
626 collectNonLiveValues(nonLiveSet, region.front().getArguments(),
631 for (
Region &region : regionBranchOp->getRegions()) {
634 Operation *terminator = region.front().getTerminator();
635 cl.operands.push_back(
636 {terminator, terminatorOperandsToKeep[terminator].flip()});
640 BitVector resultsToRemove = resultsToKeep.flip();
641 collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
643 cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
656 RDVFinalCleanupList &cl) {
657 unsigned numSuccessors = branchOp->getNumSuccessors();
659 for (
unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
664 branchOp.getSuccessorOperands(succIdx);
666 for (
unsigned operandIdx = 0; operandIdx < successorOperands.
size();
668 operandValues.push_back(successorOperands[operandIdx]);
672 BitVector successorNonLive =
673 markLives(operandValues, nonLiveSet, la).flip();
674 collectNonLiveValues(nonLiveSet, successorBlock->
getArguments(),
678 cl.blocks.push_back({successorBlock, successorNonLive});
679 cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
685 static void cleanUpDeadVals(RDVFinalCleanupList &list) {
687 for (
auto &op : list.operations) {
693 for (
auto &v : list.values) {
698 for (
auto &f : list.functions) {
699 f.funcOp.eraseArguments(f.nonLiveArgs);
700 f.funcOp.eraseResults(f.nonLiveRets);
704 for (
auto &o : list.operands) {
705 o.op->eraseOperands(o.nonLive);
709 for (
auto &r : list.results) {
710 dropUsesAndEraseResults(r.op, r.nonLive);
714 for (
auto &b : list.blocks) {
716 if (b.b->getNumArguments() != b.nonLiveArgs.size())
719 for (
int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
720 if (!b.nonLiveArgs[i])
722 b.b->getArgument(i).dropAllUses();
723 b.b->eraseArgument(i);
728 for (
auto &op : list.successorOperands) {
730 op.branch.getSuccessorOperands(op.successorIndex);
732 if (successorOperands.
size() != op.nonLiveOperands.size())
735 for (
int i = successorOperands.
size() - 1; i >= 0; --i) {
736 if (!op.nonLiveOperands[i])
738 successorOperands.
erase(i);
743 struct RemoveDeadValues :
public impl::RemoveDeadValuesBase<RemoveDeadValues> {
744 void runOnOperation()
override;
748 void RemoveDeadValues::runOnOperation() {
749 auto &la = getAnalysis<RunLivenessAnalysis>();
758 RDVFinalCleanupList finalCleanupList;
761 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
762 processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
763 }
else if (
auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
764 processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
765 }
else if (
auto branchOp = dyn_cast<BranchOpInterface>(op)) {
766 processBranchOp(branchOp, la, deadVals, finalCleanupList);
770 }
else if (isa<CallOpInterface>(op)) {
774 processSimpleOp(op, la, deadVals, finalCleanupList);
778 cleanUpDeadVals(finalCleanupList);
782 return std::make_unique<RemoveDeadValues>();
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
Block represents an ordered list of Operations.
void erase()
Unlink this Block from its parent region and delete it.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
Block * getSuccessor(unsigned i)
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be terminators.
This class implements the operand iterators for the Operation class.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
ValueRange getSuccessorInputs() const
Return the inputs to the successor that are remapped by the exit values of the current region.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
This class models how operands are forwarded to block arguments in control flow.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents a specific symbol use.
This class implements a range of SymbolRef uses.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::unique_ptr< Pass > createRemoveDeadValuesPass()
Creates an optimization pass to remove dead values.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
This lattice represents, for a given value, whether or not it is "live".
Runs liveness analysis on the IR defined by op.
const Liveness * getLiveness(Value val)