MLIR  22.0.0git
SymbolDCE.cpp
Go to the documentation of this file.
1 //===- SymbolDCE.cpp - Pass to delete dead symbols ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an algorithm for eliminating symbol operations that are
10 // known to be dead.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Transforms/Passes.h"
15 
16 #include "mlir/IR/Operation.h"
17 #include "mlir/IR/SymbolTable.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/DebugLog.h"
20 #include "llvm/Support/InterleavedRange.h"
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_SYMBOLDCE
24 #include "mlir/Transforms/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 #define DEBUG_TYPE "symbol-dce"
30 
31 namespace {
32 struct SymbolDCE : public impl::SymbolDCEBase<SymbolDCE> {
33  void runOnOperation() override;
34 
35  /// Compute the liveness of the symbols within the given symbol table.
36  /// `symbolTableIsHidden` is true if this symbol table is known to be
37  /// unaccessible from operations in its parent regions.
38  LogicalResult computeLiveness(Operation *symbolTableOp,
39  SymbolTableCollection &symbolTable,
40  bool symbolTableIsHidden,
41  DenseSet<Operation *> &liveSymbols);
42 };
43 } // namespace
44 
45 void SymbolDCE::runOnOperation() {
46  Operation *symbolTableOp = getOperation();
47 
48  // SymbolDCE should only be run on operations that define a symbol table.
49  if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
50  symbolTableOp->emitOpError()
51  << " was scheduled to run under SymbolDCE, but does not define a "
52  "symbol table";
53  return signalPassFailure();
54  }
55 
56  // A flag that signals if the top level symbol table is hidden, i.e. not
57  // accessible from parent scopes.
58  bool symbolTableIsHidden = true;
59  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(symbolTableOp);
60  if (symbolTableOp->getParentOp() && symbol)
61  symbolTableIsHidden = symbol.isPrivate();
62 
63  // Compute the set of live symbols within the symbol table.
64  DenseSet<Operation *> liveSymbols;
65  SymbolTableCollection symbolTable;
66  if (failed(computeLiveness(symbolTableOp, symbolTable, symbolTableIsHidden,
67  liveSymbols)))
68  return signalPassFailure();
69 
70  // After computing the liveness, delete all of the symbols that were found to
71  // be dead.
72  symbolTableOp->walk([&](Operation *nestedSymbolTable) {
73  if (!nestedSymbolTable->hasTrait<OpTrait::SymbolTable>())
74  return;
75  for (auto &block : nestedSymbolTable->getRegion(0)) {
76  for (Operation &op : llvm::make_early_inc_range(block)) {
77  if (isa<SymbolOpInterface>(&op) && !liveSymbols.count(&op)) {
78  op.erase();
79  ++numDCE;
80  }
81  }
82  }
83  });
84 }
85 
86 /// Compute the liveness of the symbols within the given symbol table.
87 /// `symbolTableIsHidden` is true if this symbol table is known to be
88 /// unaccessible from operations in its parent regions.
89 LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
90  SymbolTableCollection &symbolTable,
91  bool symbolTableIsHidden,
92  DenseSet<Operation *> &liveSymbols) {
93  LDBG() << "computeLiveness: "
94  << OpWithFlags(symbolTableOp, OpPrintingFlags().skipRegions());
95  // A worklist of live operations to propagate uses from.
97 
98  // Walk the symbols within the current symbol table, marking the symbols that
99  // are known to be live.
100  for (auto &block : symbolTableOp->getRegion(0)) {
101  // Add all non-symbols or symbols that can't be discarded.
102  for (Operation &op : block) {
103  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
104  if (!symbol) {
105  worklist.push_back(&op);
106  continue;
107  }
108  bool isDiscardable = (symbolTableIsHidden || symbol.isPrivate()) &&
109  symbol.canDiscardOnUseEmpty();
110  if (!isDiscardable && liveSymbols.insert(&op).second)
111  worklist.push_back(&op);
112  }
113  }
114 
115  // Process the set of symbols that were known to be live, adding new symbols
116  // that are referenced within. For operations that are not symbol tables, it
117  // considers the liveness with respect to the op itself rather than scope of
118  // nested symbol tables by enqueuing all the top level operations for
119  // consideration.
120  while (!worklist.empty()) {
121  Operation *op = worklist.pop_back_val();
122  LDBG() << "processing: "
123  << OpWithFlags(op, OpPrintingFlags().skipRegions());
124 
125  // If this is a symbol table, recursively compute its liveness.
126  if (op->hasTrait<OpTrait::SymbolTable>()) {
127  // The internal symbol table is hidden if the parent is, if its not a
128  // symbol, or if it is a private symbol.
129  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
130  bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
131  LDBG() << "\tsymbol table: "
132  << OpWithFlags(op, OpPrintingFlags().skipRegions())
133  << " is hidden: " << symIsHidden;
134  if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols)))
135  return failure();
136  } else {
137  LDBG() << "\tnon-symbol table: "
138  << OpWithFlags(op, OpPrintingFlags().skipRegions());
139  // If the op is not a symbol table, then, unless op itself is dead which
140  // would be handled by DCE, we need to check all the regions and blocks
141  // within the op to find the uses (e.g., consider visibility within op as
142  // if top level rather than relying on pure symbol table visibility). This
143  // is more conservative than SymbolTable::walkSymbolTables in the case
144  // where there is again SymbolTable information to take advantage of.
145  for (auto &region : op->getRegions())
146  for (auto &block : region.getBlocks())
147  for (Operation &op : block)
148  if (op.getNumRegions())
149  worklist.push_back(&op);
150  }
151 
152  // Get the first parent symbol table op. Note: due to enqueueing of
153  // top-level ops, we may not have a symbol table parent here, but if we do
154  // not, then we also don't have a symbol.
155  Operation *parentOp = op->getParentOp();
156  if (!parentOp->hasTrait<OpTrait::SymbolTable>())
157  continue;
158 
159  // Collect the uses held by this operation.
160  std::optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(op);
161  if (!uses) {
162  return op->emitError()
163  << "operation contains potentially unknown symbol table, meaning "
164  << "that we can't reliable compute symbol uses";
165  }
166 
167  SmallVector<Operation *, 4> resolvedSymbols;
168  LDBG() << "uses of " << OpWithFlags(op, OpPrintingFlags().skipRegions());
169  for (const SymbolTable::SymbolUse &use : *uses) {
170  LDBG() << "\tuse: " << use.getUser();
171  // Lookup the symbols referenced by this use.
172  resolvedSymbols.clear();
173  if (failed(symbolTable.lookupSymbolIn(parentOp, use.getSymbolRef(),
174  resolvedSymbols)))
175  // Ignore references to unknown symbols.
176  continue;
177  LDBG() << "\t\tresolved symbols: "
178  << llvm::interleaved(resolvedSymbols, ", ");
179 
180  // Mark each of the resolved symbols as live.
181  for (Operation *resolvedSymbol : resolvedSymbols)
182  if (liveSymbols.insert(resolvedSymbol).second)
183  worklist.push_back(resolvedSymbol);
184  }
185  }
186 
187  return success();
188 }
189 
190 std::unique_ptr<Pass> mlir::createSymbolDCEPass() {
191  return std::make_unique<SymbolDCE>();
192 }
Set of flags used to control the behavior of the various IR print methods (e.g.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition: Operation.h:1111
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
This class represents a specific symbol use.
Definition: SymbolTable.h:183
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
std::unique_ptr< Pass > createSymbolDCEPass()
Creates a pass which delete symbol operations that are unreachable.
Definition: SymbolDCE.cpp:190