54 #include "llvm/ADT/STLExtras.h"
62 #define GEN_PASS_DEF_REMOVEDEADVALUES
63 #include "mlir/Transforms/Passes.h.inc"
80 for (
Value value : values) {
87 if (!liveness || liveness->
isLive)
96 BitVector lives(values.size(),
true);
111 if (liveness && !liveness->
isLive)
120 static void dropUsesAndEraseResults(
Operation *op, BitVector toErase) {
122 "expected the number of results in `op` and the size of `toErase` to "
125 std::vector<Type> newResultTypes;
127 if (!toErase[result.getResultNumber()])
128 newResultTypes.push_back(result.getType());
130 builder.setInsertionPointAfter(op);
141 while (!region.empty())
142 region.front().moveBefore(temp);
146 unsigned indexOfNextNewCallOpResultToReplace = 0;
148 assert(result &&
"expected result to be non-null");
149 if (toErase[index]) {
150 result.dropAllUses();
152 result.replaceAllUsesWith(
153 newOp->
getResult(indexOfNextNewCallOpResultToReplace++));
163 for (
unsigned i = 0, e = operands.size(); i < e; i++)
164 opOperands.push_back(&values[i]);
195 static void cleanFuncOp(FunctionOpInterface funcOp,
Operation *module,
197 if (funcOp.isPublic())
202 BitVector nonLiveArgs = markLives(arguments, la);
203 nonLiveArgs = nonLiveArgs.flip();
207 if (arg && nonLiveArgs[index])
211 funcOp.eraseArguments(nonLiveArgs);
217 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
223 for (
int index : nonLiveArgs.set_bits())
224 nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
250 Operation *lastReturnOp = funcOp.back().getTerminator();
252 BitVector nonLiveRets(numReturns,
true);
255 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
256 BitVector liveCallRets = markLives(callOp->
getResults(), la);
257 nonLiveRets &= liveCallRets.flip();
264 for (
Block &block : funcOp.getBlocks()) {
265 Operation *returnOp = block.getTerminator();
269 funcOp.eraseResults(nonLiveRets);
274 assert(isa<CallOpInterface>(callOp) &&
"expected a call-like user");
275 dropUsesAndEraseResults(callOp, nonLiveRets);
300 static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
303 auto markLiveResults = [&](BitVector &liveResults) {
304 liveResults = markLives(regionBranchOp->getResults(), la);
309 for (
Region ®ion : regionBranchOp->getRegions()) {
311 BitVector regionLiveArgs = markLives(arguments, la);
312 liveArgs[®ion] = regionLiveArgs;
318 auto getSuccessors = [&](
Region *region =
nullptr) {
323 regionBranchOp.getSuccessorRegions(point, successors);
333 terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
334 .getSuccessorOperands(successor)
335 : regionBranchOp.getEntrySuccessorOperands(successor);
342 auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
343 nonForwardedOperands.resize(regionBranchOp->getNumOperands(),
true);
345 for (
OpOperand *opOperand : getForwardedOpOperands(successor))
346 nonForwardedOperands.reset(opOperand->getOperandNumber());
352 auto markNonForwardedReturnValues =
354 for (
Region ®ion : regionBranchOp->getRegions()) {
356 nonForwardedRets[terminator] =
360 getForwardedOpOperands(successor, terminator))
361 nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
370 auto updateOperandsOrTerminatorOperandsToKeep =
371 [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
374 region ? region->front().getTerminator() :
nullptr;
378 for (
auto [opOperand, input] :
379 llvm::zip(getForwardedOpOperands(successor, terminator),
381 size_t operandNum = opOperand->getOperandNumber();
384 ? argsToKeep[successorRegion]
385 [cast<BlockArgument>(input).getArgNumber()]
386 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
387 valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
395 auto recomputeResultsAndArgsToKeep =
397 BitVector &operandsToKeep,
399 bool &resultsOrArgsToKeepChanged) {
400 resultsOrArgsToKeepChanged =
false;
405 for (
auto [opOperand, input] :
406 llvm::zip(getForwardedOpOperands(successor),
408 bool recomputeBasedOn =
409 operandsToKeep[opOperand->getOperandNumber()];
412 ? argsToKeep[successorRegion]
413 [cast<BlockArgument>(input).getArgNumber()]
414 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
415 if (!toRecompute && recomputeBasedOn)
416 resultsOrArgsToKeepChanged =
true;
417 if (successorRegion) {
418 argsToKeep[successorRegion][cast<BlockArgument>(input)
420 argsToKeep[successorRegion]
421 [cast<BlockArgument>(input).getArgNumber()] |
424 resultsToKeep[cast<OpResult>(input).getResultNumber()] =
425 resultsToKeep[cast<OpResult>(input).getResultNumber()] |
433 for (
Region ®ion : regionBranchOp->getRegions()) {
434 Operation *terminator = region.front().getTerminator();
437 for (
auto [opOperand, input] :
438 llvm::zip(getForwardedOpOperands(successor, terminator),
440 bool recomputeBasedOn =
441 terminatorOperandsToKeep[region.back().getTerminator()]
442 [opOperand->getOperandNumber()];
445 ? argsToKeep[successorRegion]
446 [cast<BlockArgument>(input).getArgNumber()]
447 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
448 if (!toRecompute && recomputeBasedOn)
449 resultsOrArgsToKeepChanged =
true;
450 if (successorRegion) {
451 argsToKeep[successorRegion][cast<BlockArgument>(input)
453 argsToKeep[successorRegion]
454 [cast<BlockArgument>(input).getArgNumber()] |
457 resultsToKeep[cast<OpResult>(input).getResultNumber()] =
458 resultsToKeep[cast<OpResult>(input).getResultNumber()] |
468 auto markValuesToKeep =
470 BitVector &operandsToKeep,
472 bool resultsOrArgsToKeepChanged =
true;
475 while (resultsOrArgsToKeepChanged) {
477 updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
478 resultsToKeep, argsToKeep);
481 for (
Region ®ion : regionBranchOp->getRegions()) {
482 updateOperandsOrTerminatorOperandsToKeep(
483 terminatorOperandsToKeep[region.back().getTerminator()],
484 resultsToKeep, argsToKeep, ®ion);
488 recomputeResultsAndArgsToKeep(
489 resultsToKeep, argsToKeep, operandsToKeep,
490 terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
500 !hasLive(regionBranchOp->getResults(), la)) {
501 regionBranchOp->dropAllUses();
502 regionBranchOp->erase();
510 BitVector resultsToKeep;
515 BitVector operandsToKeep;
523 markLiveResults(resultsToKeep);
526 markLiveArgs(argsToKeep);
530 markNonForwardedOperands(operandsToKeep);
533 markNonForwardedReturnValues(terminatorOperandsToKeep);
537 markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
538 terminatorOperandsToKeep);
541 regionBranchOp->eraseOperands(operandsToKeep.flip());
544 for (
Region ®ion : regionBranchOp->getRegions()) {
545 assert(!region.empty() &&
"expected a non-empty region in an op "
546 "implementing `RegionBranchOpInterface`");
547 for (
auto [index, arg] :
llvm::enumerate(region.front().getArguments())) {
548 if (argsToKeep[®ion][index])
553 region.front().eraseArguments(argsToKeep[®ion].flip());
557 for (
Region ®ion : regionBranchOp->getRegions()) {
558 Operation *terminator = region.front().getTerminator();
559 terminator->
eraseOperands(terminatorOperandsToKeep[terminator].flip());
563 dropUsesAndEraseResults(regionBranchOp.getOperation(), resultsToKeep.flip());
566 struct RemoveDeadValues :
public impl::RemoveDeadValuesBase<RemoveDeadValues> {
567 void runOnOperation()
override;
571 void RemoveDeadValues::runOnOperation() {
572 auto &la = getAnalysis<RunLivenessAnalysis>();
579 if (isa<BranchOpInterface>(op) ||
580 (isa<SymbolOpInterface>(op) && !isa<FunctionOpInterface>(op)) ||
581 (isa<SymbolUserOpInterface>(op) && !isa<CallOpInterface>(op))) {
582 op->
emitError() <<
"cannot optimize an IR with non-function symbol ops, "
583 "non-call symbol user ops or branch ops\n";
593 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
594 cleanFuncOp(funcOp, module, la);
595 }
else if (
auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
596 cleanRegionBranchOp(regionBranchOp, la);
600 }
else if (isa<CallOpInterface>(op)) {
604 cleanSimpleOp(op, la);
610 return std::make_unique<RemoveDeadValues>();
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
Block represents an ordered list of Operations.
void erase()
Unlink this Block from its parent region and delete it.
Operation * getTerminator()
Get the terminator operation of this block.
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be terminators.
This class implements the operand iterators for the Operation class.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void eraseOperands(unsigned idx, unsigned length=1)
Erase the operands starting at position idx and ending at position 'idx'+'length'.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
ValueRange getSuccessorInputs() const
Return the inputs to the successor that are remapped by the exit values of the current region.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
This class represents a specific symbol use.
This class implements a range of SymbolRef uses.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::unique_ptr< Pass > createRemoveDeadValuesPass()
Creates an optimization pass to remove dead values.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
This lattice represents, for a given value, whether or not it is "live".
Runs liveness analysis on the IR defined by op.
const Liveness * getLiveness(Value val)