MLIR  20.0.0git
PipelineGlobalOps.cpp
Go to the documentation of this file.
1 //===- PipelineGlobalOps.cpp - Pipeline Global Ops Pass -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/Pass/Pass.h"
16 
17 namespace mlir {
18 namespace ml_program {
19 #define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS
20 #include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
21 
22 namespace {
23 
24 class MLProgramPipelineGlobals
25  : public impl::MLProgramPipelineGlobalsBase<MLProgramPipelineGlobals> {
26 public:
27  void runOnOperation() override;
28 
29 private:
30  LogicalResult buildGlobalMap(ModuleOp op);
31 
32  void processBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
33  llvm::DenseSet<SymbolRefAttr> &symbolStore);
34 
37 };
38 
39 // Traverses upwards searchign for the operation mapped by the symbol.
40 static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
41  for (auto *op = baseOp; op; op = op->getParentOp()) {
42  auto *lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol);
43  if (lookup)
44  return lookup;
45  }
46  return nullptr;
47 }
48 
49 // Builds map from a symbol to MLProgram global symbols loaded or stored
50 // during processing.
51 LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
53  auto res = module->walk([&](Operation *op) {
54  if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
55  auto callable = caller.getCallableForCallee();
56  // For now we do not know how to handle Value based tracing, so fail.
57  if (mlir::isa<Value>(callable)) {
58  return WalkResult::interrupt();
59  }
60 
61  auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
62  auto *func = getFromSymbol(op, symbol);
63  callableMap[symbol] = func;
64  }
65  return WalkResult::advance();
66  });
67 
68  if (res.wasInterrupted()) {
69  return failure();
70  }
71 
72  // First grab all symbols loaded or stored by each function. This
73  // will not handle calls initially.
76  for (auto callable : callableMap) {
78  llvm::DenseSet<SymbolRefAttr> storeSymbols;
79 
80  callable.getSecond()->walk(
81  [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
82 
83  callable.getSecond()->walk(
84  [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
85 
86  opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
87  opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
88  }
89 
90  // For each callable function we find each global loaded/stored within the
91  // function or a nested called function. This includes recursion checking to
92  // avoid infinitely recursing.
93  for (auto callable : callableMap) {
94  SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
95  llvm::SmallVector<SymbolRefAttr> work = {thisSymbol};
96  llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol};
98  llvm::DenseSet<SymbolRefAttr> storeSymbols;
99 
100  for (size_t i = 0; i < work.size(); ++i) {
101  callableMap[work[i]]->walk([&](CallOpInterface call) {
102  auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
103  if (visited.insert(symbol).second)
104  work.push_back(symbol);
105  });
106 
107  for (auto load : opLoadSymbols[work[i]])
108  loadSymbols.insert(load);
109 
110  for (auto store : opStoreSymbols[work[i]])
111  storeSymbols.insert(store);
112  }
113 
114  loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
115  storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
116  }
117 
118  return success();
119 }
120 
121 // Process each operation in the block deleting unneeded loads / stores,
122 // recursing on subblocks and checking function calls.
123 void MLProgramPipelineGlobals::processBlock(
124  Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
125  llvm::DenseSet<SymbolRefAttr> &symbolStore) {
126 
130  for (auto &op : block) {
131  // If this is a global load, remap to a previous value if known
132  // and delete this load. Remember that this value is the currently
133  // known load.
134  if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
135  auto ref = load.getGlobal();
136  symbolLoad.insert(ref);
137  if (previousLoads.contains(ref)) {
138  toDelete.push_back(&op);
139  load.getResult().replaceAllUsesWith(previousLoads[ref]);
140  } else {
141  previousLoads[ref] = load.getResult();
142  }
143  continue;
144  }
145 
146  // Delete a previous store if it exists and is not needed, update
147  // the most recent known value for this global ref.
148  if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
149  auto ref = store.getGlobal();
150  symbolStore.insert(ref);
151  auto it = previousStores.find(ref);
152  if (it != previousStores.end()) {
153  toDelete.push_back(it->getSecond());
154  }
155 
156  previousLoads[ref] = store.getValue();
157  previousStores[ref] = &op;
158  continue;
159  }
160 
161  // If a function is called, clear known values for loads/stores used by
162  // the function or its sub-functions.
163  if (auto call = mlir::dyn_cast<CallOpInterface>(op)) {
164  auto loadSymbols =
165  loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
166  auto storeSymbols =
167  storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
168 
169  for (auto sym : loadSymbols) {
170  previousStores.erase(sym);
171  }
172 
173  for (auto sym : storeSymbols) {
174  previousLoads.erase(sym);
175  previousStores.erase(sym);
176  }
177  continue;
178  }
179 
180  // If the op has sub-regions, recurse inside. We make no guarantees whether
181  // the recursion occurs.
182  llvm::DenseSet<SymbolRefAttr> opSymbolLoad;
183  llvm::DenseSet<SymbolRefAttr> opSymbolStore;
184  for (auto &region : op.getRegions()) {
185  for (auto &block : region) {
186  processBlock(block, opSymbolLoad, opSymbolStore);
187  }
188  }
189 
190  // Update current state from the subblock.
191  for (auto change : opSymbolLoad) {
192  symbolLoad.insert(change);
193  previousStores.erase(change);
194  }
195 
196  for (auto change : opSymbolStore) {
197  symbolStore.insert(change);
198  previousLoads.erase(change);
199  previousStores.erase(change);
200  }
201  }
202 
203  for (auto *op : toDelete) {
204  op->erase();
205  }
206 }
207 
208 void MLProgramPipelineGlobals::runOnOperation() {
209  auto targetOp = getOperation();
210  if (failed(buildGlobalMap(targetOp))) {
211  return;
212  }
213 
214  for (auto &funcOp : *targetOp.getBody()) {
215  for (auto &region : funcOp.getRegions()) {
216  for (auto &block : region.getBlocks()) {
217  llvm::DenseSet<SymbolRefAttr> symbolsLoaded;
218  llvm::DenseSet<SymbolRefAttr> symbolsStored;
219  processBlock(block, symbolsLoaded, symbolsStored);
220  }
221  }
222  }
223 }
224 
225 } // namespace
226 
227 std::unique_ptr<OperationPass<mlir::ModuleOp>>
229  return std::make_unique<MLProgramPipelineGlobals>();
230 }
231 
232 } // namespace ml_program
233 } // namespace mlir
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
std::unique_ptr< OperationPass< ModuleOp > > createMLProgramPipelineGlobalsPass()
Include the generated interface declarations.