54 #include "llvm/ADT/STLExtras.h"
62 #define GEN_PASS_DEF_REMOVEDEADVALUES
63 #include "mlir/Transforms/Passes.h.inc"
79 struct FunctionToCleanUp {
80 FunctionOpInterface funcOp;
81 BitVector nonLiveArgs;
82 BitVector nonLiveRets;
85 struct OperationToCleanup {
90 struct BlockArgsToCleanup {
92 BitVector nonLiveArgs;
95 struct SuccessorOperandsToCleanup {
96 BranchOpInterface branch;
97 unsigned successorIndex;
98 BitVector nonLiveOperands;
101 struct RDVFinalCleanupList {
117 for (
Value value : values) {
118 if (nonLiveSet.contains(value))
122 if (!liveness || liveness->
isLive)
132 BitVector lives(values.size(),
true);
135 if (nonLiveSet.contains(value)) {
147 if (liveness && !liveness->
isLive)
158 const BitVector &nonLive) {
162 nonLiveSet.insert(result);
168 static void dropUsesAndEraseResults(
Operation *op, BitVector toErase) {
170 "expected the number of results in `op` and the size of `toErase` to "
173 std::vector<Type> newResultTypes;
175 if (!toErase[result.getResultNumber()])
176 newResultTypes.push_back(result.getType());
178 builder.setInsertionPointAfter(op);
189 while (!region.empty())
190 region.front().moveBefore(temp);
194 unsigned indexOfNextNewCallOpResultToReplace = 0;
196 assert(result &&
"expected result to be non-null");
197 if (toErase[index]) {
198 result.dropAllUses();
200 result.replaceAllUsesWith(
201 newOp->
getResult(indexOfNextNewCallOpResultToReplace++));
211 for (
unsigned i = 0, e = operands.size(); i < e; i++)
212 opOperands.push_back(&values[i]);
231 RDVFinalCleanupList &cl) {
235 cl.operations.push_back(op);
236 collectNonLiveValues(nonLiveSet, op->
getResults(),
250 static void processFuncOp(FunctionOpInterface funcOp,
Operation *module,
252 RDVFinalCleanupList &cl) {
253 if (funcOp.isPublic() || funcOp.isExternal())
258 BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
259 nonLiveArgs = nonLiveArgs.flip();
263 if (arg && nonLiveArgs[index]) {
264 cl.values.push_back(arg);
265 nonLiveSet.insert(arg);
272 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
278 for (
int index : nonLiveArgs.set_bits())
279 nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
280 cl.operands.push_back({callOp, nonLiveCallOperands});
306 Operation *lastReturnOp = funcOp.back().getTerminator();
310 BitVector nonLiveRets(numReturns,
true);
313 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
314 BitVector liveCallRets = markLives(callOp->
getResults(), nonLiveSet, la);
315 nonLiveRets &= liveCallRets.flip();
321 for (
Block &block : funcOp.getBlocks()) {
322 Operation *returnOp = block.getTerminator();
324 cl.operands.push_back({returnOp, nonLiveRets});
328 cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
333 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
334 cl.results.push_back({callOp, nonLiveRets});
335 collectNonLiveValues(nonLiveSet, callOp->
getResults(), nonLiveRets);
368 static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
371 RDVFinalCleanupList &cl) {
373 auto markLiveResults = [&](BitVector &liveResults) {
374 liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
379 for (
Region ®ion : regionBranchOp->getRegions()) {
383 BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
384 liveArgs[®ion] = regionLiveArgs;
390 auto getSuccessors = [&](
Region *region =
nullptr) {
395 regionBranchOp.getSuccessorRegions(point, successors);
405 terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
406 .getSuccessorOperands(successor)
407 : regionBranchOp.getEntrySuccessorOperands(successor);
414 auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
415 nonForwardedOperands.resize(regionBranchOp->getNumOperands(),
true);
417 for (
OpOperand *opOperand : getForwardedOpOperands(successor))
418 nonForwardedOperands.reset(opOperand->getOperandNumber());
424 auto markNonForwardedReturnValues =
426 for (
Region ®ion : regionBranchOp->getRegions()) {
430 nonForwardedRets[terminator] =
434 getForwardedOpOperands(successor, terminator))
435 nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
444 auto updateOperandsOrTerminatorOperandsToKeep =
445 [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
448 region ? region->front().getTerminator() :
nullptr;
452 for (
auto [opOperand, input] :
453 llvm::zip(getForwardedOpOperands(successor, terminator),
455 size_t operandNum = opOperand->getOperandNumber();
458 ? argsToKeep[successorRegion]
459 [cast<BlockArgument>(input).getArgNumber()]
460 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
461 valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
469 auto recomputeResultsAndArgsToKeep =
471 BitVector &operandsToKeep,
473 bool &resultsOrArgsToKeepChanged) {
474 resultsOrArgsToKeepChanged =
false;
479 for (
auto [opOperand, input] :
480 llvm::zip(getForwardedOpOperands(successor),
482 bool recomputeBasedOn =
483 operandsToKeep[opOperand->getOperandNumber()];
486 ? argsToKeep[successorRegion]
487 [cast<BlockArgument>(input).getArgNumber()]
488 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
489 if (!toRecompute && recomputeBasedOn)
490 resultsOrArgsToKeepChanged =
true;
491 if (successorRegion) {
492 argsToKeep[successorRegion][cast<BlockArgument>(input)
494 argsToKeep[successorRegion]
495 [cast<BlockArgument>(input).getArgNumber()] |
498 resultsToKeep[cast<OpResult>(input).getResultNumber()] =
499 resultsToKeep[cast<OpResult>(input).getResultNumber()] |
507 for (
Region ®ion : regionBranchOp->getRegions()) {
510 Operation *terminator = region.front().getTerminator();
513 for (
auto [opOperand, input] :
514 llvm::zip(getForwardedOpOperands(successor, terminator),
516 bool recomputeBasedOn =
517 terminatorOperandsToKeep[region.back().getTerminator()]
518 [opOperand->getOperandNumber()];
521 ? argsToKeep[successorRegion]
522 [cast<BlockArgument>(input).getArgNumber()]
523 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
524 if (!toRecompute && recomputeBasedOn)
525 resultsOrArgsToKeepChanged =
true;
526 if (successorRegion) {
527 argsToKeep[successorRegion][cast<BlockArgument>(input)
529 argsToKeep[successorRegion]
530 [cast<BlockArgument>(input).getArgNumber()] |
533 resultsToKeep[cast<OpResult>(input).getResultNumber()] =
534 resultsToKeep[cast<OpResult>(input).getResultNumber()] |
544 auto markValuesToKeep =
546 BitVector &operandsToKeep,
548 bool resultsOrArgsToKeepChanged =
true;
551 while (resultsOrArgsToKeepChanged) {
553 updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
554 resultsToKeep, argsToKeep);
557 for (
Region ®ion : regionBranchOp->getRegions()) {
560 updateOperandsOrTerminatorOperandsToKeep(
561 terminatorOperandsToKeep[region.back().getTerminator()],
562 resultsToKeep, argsToKeep, ®ion);
566 recomputeResultsAndArgsToKeep(
567 resultsToKeep, argsToKeep, operandsToKeep,
568 terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
579 !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
580 cl.operations.push_back(regionBranchOp.getOperation());
589 BitVector resultsToKeep;
594 BitVector operandsToKeep;
602 markLiveResults(resultsToKeep);
605 markLiveArgs(argsToKeep);
609 markNonForwardedOperands(operandsToKeep);
612 markNonForwardedReturnValues(terminatorOperandsToKeep);
616 markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
617 terminatorOperandsToKeep);
620 cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
623 for (
Region ®ion : regionBranchOp->getRegions()) {
626 BitVector argsToRemove = argsToKeep[®ion].flip();
627 cl.blocks.push_back({®ion.front(), argsToRemove});
628 collectNonLiveValues(nonLiveSet, region.front().getArguments(),
633 for (
Region ®ion : regionBranchOp->getRegions()) {
636 Operation *terminator = region.front().getTerminator();
637 cl.operands.push_back(
638 {terminator, terminatorOperandsToKeep[terminator].flip()});
642 BitVector resultsToRemove = resultsToKeep.flip();
643 collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
645 cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
658 RDVFinalCleanupList &cl) {
659 unsigned numSuccessors = branchOp->getNumSuccessors();
661 for (
unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
666 branchOp.getSuccessorOperands(succIdx);
668 for (
unsigned operandIdx = 0; operandIdx < successorOperands.
size();
670 operandValues.push_back(successorOperands[operandIdx]);
674 BitVector successorNonLive =
675 markLives(operandValues, nonLiveSet, la).flip();
676 collectNonLiveValues(nonLiveSet, successorBlock->
getArguments(),
680 cl.blocks.push_back({successorBlock, successorNonLive});
681 cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
687 static void cleanUpDeadVals(RDVFinalCleanupList &list) {
689 for (
auto &op : list.operations) {
695 for (
auto &v : list.values) {
700 for (
auto &f : list.functions) {
701 f.funcOp.eraseArguments(f.nonLiveArgs);
702 f.funcOp.eraseResults(f.nonLiveRets);
706 for (
auto &o : list.operands) {
707 o.op->eraseOperands(o.nonLive);
711 for (
auto &r : list.results) {
712 dropUsesAndEraseResults(r.op, r.nonLive);
716 for (
auto &b : list.blocks) {
718 if (b.b->getNumArguments() != b.nonLiveArgs.size())
721 for (
int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
722 if (!b.nonLiveArgs[i])
724 b.b->getArgument(i).dropAllUses();
725 b.b->eraseArgument(i);
730 for (
auto &op : list.successorOperands) {
732 op.branch.getSuccessorOperands(op.successorIndex);
734 if (successorOperands.
size() != op.nonLiveOperands.size())
737 for (
int i = successorOperands.
size() - 1; i >= 0; --i) {
738 if (!op.nonLiveOperands[i])
740 successorOperands.
erase(i);
745 struct RemoveDeadValues :
public impl::RemoveDeadValuesBase<RemoveDeadValues> {
746 void runOnOperation()
override;
750 void RemoveDeadValues::runOnOperation() {
751 auto &la = getAnalysis<RunLivenessAnalysis>();
760 RDVFinalCleanupList finalCleanupList;
763 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
764 processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
765 }
else if (
auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
766 processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
767 }
else if (
auto branchOp = dyn_cast<BranchOpInterface>(op)) {
768 processBranchOp(branchOp, la, deadVals, finalCleanupList);
772 }
else if (isa<CallOpInterface>(op)) {
776 processSimpleOp(op, la, deadVals, finalCleanupList);
780 cleanUpDeadVals(finalCleanupList);
784 return std::make_unique<RemoveDeadValues>();
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
Block represents an ordered list of Operations.
void erase()
Unlink this Block from its parent region and delete it.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
Block * getSuccessor(unsigned i)
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be terminators.
This class implements the operand iterators for the Operation class.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
ValueRange getSuccessorInputs() const
Return the inputs to the successor that are remapped by the exit values of the current region.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
This class models how operands are forwarded to block arguments in control flow.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
unsigned size() const
Returns the amount of operands passed to the successor.
This class represents a specific symbol use.
This class implements a range of SymbolRef uses.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::unique_ptr< Pass > createRemoveDeadValuesPass()
Creates an optimization pass to remove dead values.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
This lattice represents, for a given value, whether or not it is "live".
Runs liveness analysis on the IR defined by op.
const Liveness * getLiveness(Value val)