MLIR  19.0.0git
SymbolDCE.cpp
Go to the documentation of this file.
1 //===- SymbolDCE.cpp - Pass to delete dead symbols ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an algorithm for eliminating symbol operations that are
10 // known to be dead.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Transforms/Passes.h"
15 
16 #include "mlir/IR/SymbolTable.h"
17 
18 namespace mlir {
19 #define GEN_PASS_DEF_SYMBOLDCE
20 #include "mlir/Transforms/Passes.h.inc"
21 } // namespace mlir
22 
23 using namespace mlir;
24 
25 namespace {
26 struct SymbolDCE : public impl::SymbolDCEBase<SymbolDCE> {
27  void runOnOperation() override;
28 
29  /// Compute the liveness of the symbols within the given symbol table.
30  /// `symbolTableIsHidden` is true if this symbol table is known to be
31  /// unaccessible from operations in its parent regions.
32  LogicalResult computeLiveness(Operation *symbolTableOp,
33  SymbolTableCollection &symbolTable,
34  bool symbolTableIsHidden,
35  DenseSet<Operation *> &liveSymbols);
36 };
37 } // namespace
38 
39 void SymbolDCE::runOnOperation() {
40  Operation *symbolTableOp = getOperation();
41 
42  // SymbolDCE should only be run on operations that define a symbol table.
43  if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
44  symbolTableOp->emitOpError()
45  << " was scheduled to run under SymbolDCE, but does not define a "
46  "symbol table";
47  return signalPassFailure();
48  }
49 
50  // A flag that signals if the top level symbol table is hidden, i.e. not
51  // accessible from parent scopes.
52  bool symbolTableIsHidden = true;
53  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(symbolTableOp);
54  if (symbolTableOp->getParentOp() && symbol)
55  symbolTableIsHidden = symbol.isPrivate();
56 
57  // Compute the set of live symbols within the symbol table.
58  DenseSet<Operation *> liveSymbols;
59  SymbolTableCollection symbolTable;
60  if (failed(computeLiveness(symbolTableOp, symbolTable, symbolTableIsHidden,
61  liveSymbols)))
62  return signalPassFailure();
63 
64  // After computing the liveness, delete all of the symbols that were found to
65  // be dead.
66  symbolTableOp->walk([&](Operation *nestedSymbolTable) {
67  if (!nestedSymbolTable->hasTrait<OpTrait::SymbolTable>())
68  return;
69  for (auto &block : nestedSymbolTable->getRegion(0)) {
70  for (Operation &op : llvm::make_early_inc_range(block)) {
71  if (isa<SymbolOpInterface>(&op) && !liveSymbols.count(&op)) {
72  op.erase();
73  ++numDCE;
74  }
75  }
76  }
77  });
78 }
79 
80 /// Compute the liveness of the symbols within the given symbol table.
81 /// `symbolTableIsHidden` is true if this symbol table is known to be
82 /// unaccessible from operations in its parent regions.
83 LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
84  SymbolTableCollection &symbolTable,
85  bool symbolTableIsHidden,
86  DenseSet<Operation *> &liveSymbols) {
87  // A worklist of live operations to propagate uses from.
89 
90  // Walk the symbols within the current symbol table, marking the symbols that
91  // are known to be live.
92  for (auto &block : symbolTableOp->getRegion(0)) {
93  // Add all non-symbols or symbols that can't be discarded.
94  for (Operation &op : block) {
95  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
96  if (!symbol) {
97  worklist.push_back(&op);
98  continue;
99  }
100  bool isDiscardable = (symbolTableIsHidden || symbol.isPrivate()) &&
101  symbol.canDiscardOnUseEmpty();
102  if (!isDiscardable && liveSymbols.insert(&op).second)
103  worklist.push_back(&op);
104  }
105  }
106 
107  // Process the set of symbols that were known to be live, adding new symbols
108  // that are referenced within.
109  while (!worklist.empty()) {
110  Operation *op = worklist.pop_back_val();
111 
112  // If this is a symbol table, recursively compute its liveness.
113  if (op->hasTrait<OpTrait::SymbolTable>()) {
114  // The internal symbol table is hidden if the parent is, if its not a
115  // symbol, or if it is a private symbol.
116  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
117  bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
118  if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols)))
119  return failure();
120  }
121 
122  // Collect the uses held by this operation.
123  std::optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(op);
124  if (!uses) {
125  return op->emitError()
126  << "operation contains potentially unknown symbol table, "
127  "meaning that we can't reliable compute symbol uses";
128  }
129 
130  SmallVector<Operation *, 4> resolvedSymbols;
131  for (const SymbolTable::SymbolUse &use : *uses) {
132  // Lookup the symbols referenced by this use.
133  resolvedSymbols.clear();
134  if (failed(symbolTable.lookupSymbolIn(
135  op->getParentOp(), use.getSymbolRef(), resolvedSymbols)))
136  // Ignore references to unknown symbols.
137  continue;
138 
139  // Mark each of the resolved symbols as live.
140  for (Operation *resolvedSymbol : resolvedSymbols)
141  if (liveSymbols.insert(resolvedSymbol).second)
142  worklist.push_back(resolvedSymbol);
143  }
144  }
145 
146  return success();
147 }
148 
149 std::unique_ptr<Pass> mlir::createSymbolDCEPass() {
150  return std::make_unique<SymbolDCE>();
151 }
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
This class represents a specific symbol use.
Definition: SymbolTable.h:183
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::unique_ptr< Pass > createSymbolDCEPass()
Creates a pass which deletes symbol operations that are unreachable.
Definition: SymbolDCE.cpp:149
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26