18#include "llvm/Support/Debug.h"
19#include "llvm/Support/DebugLog.h"
20#include "llvm/Support/InterleavedRange.h"
23#define GEN_PASS_DEF_SYMBOLDCE
24#include "mlir/Transforms/Passes.h.inc"
29#define DEBUG_TYPE "symbol-dce"
33 void runOnOperation()
override;
38 LogicalResult computeLiveness(
Operation *symbolTableOp,
40 bool symbolTableIsHidden,
45void SymbolDCE::runOnOperation() {
46 Operation *symbolTableOp = getOperation();
49 if (!symbolTableOp->
hasTrait<OpTrait::SymbolTable>()) {
51 <<
" was scheduled to run under SymbolDCE, but does not define a "
53 return signalPassFailure();
58 bool symbolTableIsHidden =
true;
59 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(symbolTableOp);
61 symbolTableIsHidden = symbol.isPrivate();
65 SymbolTableCollection symbolTable;
66 if (
failed(computeLiveness(symbolTableOp, symbolTable, symbolTableIsHidden,
68 return signalPassFailure();
72 symbolTableOp->
walk([&](Operation *nestedSymbolTable) {
73 if (!nestedSymbolTable->
hasTrait<OpTrait::SymbolTable>())
75 for (
auto &block : nestedSymbolTable->
getRegion(0)) {
76 for (Operation &op : llvm::make_early_inc_range(block)) {
77 if (isa<SymbolOpInterface>(&op) && !liveSymbols.count(&op)) {
89LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
90 SymbolTableCollection &symbolTable,
91 bool symbolTableIsHidden,
93 LDBG() <<
"computeLiveness: "
94 << OpWithFlags(symbolTableOp, OpPrintingFlags().skipRegions());
96 SmallVector<Operation *, 16> worklist;
100 for (
auto &block : symbolTableOp->
getRegion(0)) {
102 for (Operation &op : block) {
103 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
105 worklist.push_back(&op);
108 bool isDiscardable = (symbolTableIsHidden || symbol.isPrivate()) &&
109 symbol.canDiscardOnUseEmpty();
110 if (!isDiscardable && liveSymbols.insert(&op).second)
111 worklist.push_back(&op);
120 while (!worklist.empty()) {
121 Operation *op = worklist.pop_back_val();
122 LDBG() <<
"processing: "
123 << OpWithFlags(op, OpPrintingFlags().skipRegions());
126 if (op->
hasTrait<OpTrait::SymbolTable>()) {
129 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
130 bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
131 LDBG() <<
"\tsymbol table: "
132 << OpWithFlags(op, OpPrintingFlags().skipRegions())
133 <<
" is hidden: " << symIsHidden;
134 if (
failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols)))
137 LDBG() <<
"\tnon-symbol table: "
138 << OpWithFlags(op, OpPrintingFlags().skipRegions());
146 for (
auto &block : region.getBlocks())
147 for (Operation &op : block)
149 worklist.push_back(&op);
156 if (!parentOp->
hasTrait<OpTrait::SymbolTable>())
163 <<
"operation contains potentially unknown symbol table, meaning "
164 <<
"that we can't reliable compute symbol uses";
167 SmallVector<Operation *, 4> resolvedSymbols;
168 LDBG() <<
"uses of " << OpWithFlags(op, OpPrintingFlags().skipRegions());
169 for (
const SymbolTable::SymbolUse &use : *uses) {
170 LDBG() <<
"\tuse: " << use.getUser();
172 resolvedSymbols.clear();
177 LDBG() <<
"\t\tresolved symbols: "
178 << llvm::interleaved(resolvedSymbols,
", ");
181 for (Operation *resolvedSymbol : resolvedSymbols)
182 if (liveSymbols.insert(resolvedSymbol).second)
183 worklist.push_back(resolvedSymbol);
191 return std::make_unique<SymbolDCE>();
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OpTrait::SymbolTable>().
unsigned getNumRegions()
Returns the number of regions held by this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class represents a collection of SymbolTables.
virtual Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation, returning null if no such symbol exists.
static std::optional< UseRange > getSymbolUses(Operation *from)
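Get an iterator range for all of the uses, for any symbol, that are nested within the given operation 'from'. Returns std::nullopt if any nested operation is unknown and may potentially be a symbol table, in which case the uses cannot be computed reliably.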
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
std::unique_ptr< Pass > createSymbolDCEPass()
Creates a pass which deletes symbol operations that are unreachable.