MLIR  16.0.0git
SymbolDCE.cpp
Go to the documentation of this file.
1 //===- SymbolDCE.cpp - Pass to delete dead symbols ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an algorithm for eliminating symbol operations that are
10 // known to be dead.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/IR/SymbolTable.h"
16 #include "mlir/Transforms/Passes.h"
17 
18 using namespace mlir;
19 
20 namespace {
21 struct SymbolDCE : public SymbolDCEBase<SymbolDCE> {
22  void runOnOperation() override;
23 
24  /// Compute the liveness of the symbols within the given symbol table.
25  /// `symbolTableIsHidden` is true if this symbol table is known to be
26  /// unaccessible from operations in its parent regions.
27  LogicalResult computeLiveness(Operation *symbolTableOp,
28  SymbolTableCollection &symbolTable,
29  bool symbolTableIsHidden,
30  DenseSet<Operation *> &liveSymbols);
31 };
32 } // namespace
33 
34 void SymbolDCE::runOnOperation() {
35  Operation *symbolTableOp = getOperation();
36 
37  // SymbolDCE should only be run on operations that define a symbol table.
38  if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
39  symbolTableOp->emitOpError()
40  << " was scheduled to run under SymbolDCE, but does not define a "
41  "symbol table";
42  return signalPassFailure();
43  }
44 
45  // A flag that signals if the top level symbol table is hidden, i.e. not
46  // accessible from parent scopes.
47  bool symbolTableIsHidden = true;
48  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(symbolTableOp);
49  if (symbolTableOp->getParentOp() && symbol)
50  symbolTableIsHidden = symbol.isPrivate();
51 
52  // Compute the set of live symbols within the symbol table.
53  DenseSet<Operation *> liveSymbols;
54  SymbolTableCollection symbolTable;
55  if (failed(computeLiveness(symbolTableOp, symbolTable, symbolTableIsHidden,
56  liveSymbols)))
57  return signalPassFailure();
58 
59  // After computing the liveness, delete all of the symbols that were found to
60  // be dead.
61  symbolTableOp->walk([&](Operation *nestedSymbolTable) {
62  if (!nestedSymbolTable->hasTrait<OpTrait::SymbolTable>())
63  return;
64  for (auto &block : nestedSymbolTable->getRegion(0)) {
65  for (Operation &op : llvm::make_early_inc_range(block)) {
66  if (isa<SymbolOpInterface>(&op) && !liveSymbols.count(&op)) {
67  op.erase();
68  ++numDCE;
69  }
70  }
71  }
72  });
73 }
74 
75 /// Compute the liveness of the symbols within the given symbol table.
76 /// `symbolTableIsHidden` is true if this symbol table is known to be
77 /// unaccessible from operations in its parent regions.
78 LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
79  SymbolTableCollection &symbolTable,
80  bool symbolTableIsHidden,
81  DenseSet<Operation *> &liveSymbols) {
82  // A worklist of live operations to propagate uses from.
84 
85  // Walk the symbols within the current symbol table, marking the symbols that
86  // are known to be live.
87  for (auto &block : symbolTableOp->getRegion(0)) {
88  // Add all non-symbols or symbols that can't be discarded.
89  for (Operation &op : block) {
90  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
91  if (!symbol) {
92  worklist.push_back(&op);
93  continue;
94  }
95  bool isDiscardable = (symbolTableIsHidden || symbol.isPrivate()) &&
96  symbol.canDiscardOnUseEmpty();
97  if (!isDiscardable && liveSymbols.insert(&op).second)
98  worklist.push_back(&op);
99  }
100  }
101 
102  // Process the set of symbols that were known to be live, adding new symbols
103  // that are referenced within.
104  while (!worklist.empty()) {
105  Operation *op = worklist.pop_back_val();
106 
107  // If this is a symbol table, recursively compute its liveness.
108  if (op->hasTrait<OpTrait::SymbolTable>()) {
109  // The internal symbol table is hidden if the parent is, if its not a
110  // symbol, or if it is a private symbol.
111  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
112  bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
113  if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols)))
114  return failure();
115  }
116 
117  // Collect the uses held by this operation.
119  if (!uses) {
120  return op->emitError()
121  << "operation contains potentially unknown symbol table, "
122  "meaning that we can't reliable compute symbol uses";
123  }
124 
125  SmallVector<Operation *, 4> resolvedSymbols;
126  for (const SymbolTable::SymbolUse &use : *uses) {
127  // Lookup the symbols referenced by this use.
128  resolvedSymbols.clear();
129  if (failed(symbolTable.lookupSymbolIn(
130  op->getParentOp(), use.getSymbolRef(), resolvedSymbols)))
131  // Ignore references to unknown symbols.
132  continue;
133 
134  // Mark each of the resolved symbols as live.
135  for (Operation *resolvedSymbol : resolvedSymbols)
136  if (liveSymbols.insert(resolvedSymbol).second)
137  worklist.push_back(resolvedSymbol);
138  }
139  }
140 
141  return success();
142 }
143 
144 std::unique_ptr<Pass> mlir::createSymbolDCEPass() {
145  return std::make_unique<SymbolDCE>();
146 }
Include the generated interface declarations.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
static Optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::enable_if< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT >::type walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one)...
Definition: Operation.h:574
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:242
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:528
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:165
std::unique_ptr< Pass > createSymbolDCEPass()
Creates a pass which deletes symbol operations that are unreachable.
Definition: SymbolDCE.cpp:144
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:338
Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation, returning null if no such name exists.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers...
Definition: Operation.cpp:508
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:221
This class represents a specific symbol use.
Definition: SymbolTable.h:144
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:486