MLIR  22.0.0git
PipelineGlobalOps.cpp
Go to the documentation of this file.
1 //===- PipelineGlobalOpsPass.cpp - Pipeline Global Ops Pass ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
12 #include "mlir/IR/BuiltinOps.h"
13 
14 namespace mlir {
15 namespace ml_program {
16 #define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALSPASS
17 #include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
18 
19 namespace {
20 
21 class MLProgramPipelineGlobals
22  : public impl::MLProgramPipelineGlobalsPassBase<MLProgramPipelineGlobals> {
23 public:
24  void runOnOperation() override;
25 
26 private:
27  LogicalResult buildGlobalMap(ModuleOp op);
28 
29  void processBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
30  llvm::DenseSet<SymbolRefAttr> &symbolStore);
31 
34 };
35 
36 // Traverses upwards searchign for the operation mapped by the symbol.
37 static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
38  for (auto *op = baseOp; op; op = op->getParentOp()) {
39  auto *lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol);
40  if (lookup)
41  return lookup;
42  }
43  return nullptr;
44 }
45 
46 // Builds map from a symbol to MLProgram global symbols loaded or stored
47 // during processing.
48 LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
50  auto res = module->walk([&](Operation *op) {
51  if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
52  auto callable = caller.getCallableForCallee();
53  // For now we do not know how to handle Value based tracing, so fail.
54  if (mlir::isa<Value>(callable)) {
55  return WalkResult::interrupt();
56  }
57 
58  auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
59  auto *func = getFromSymbol(op, symbol);
60  callableMap[symbol] = func;
61  }
62  return WalkResult::advance();
63  });
64 
65  if (res.wasInterrupted()) {
66  return failure();
67  }
68 
69  // First grab all symbols loaded or stored by each function. This
70  // will not handle calls initially.
73  for (auto callable : callableMap) {
75  llvm::DenseSet<SymbolRefAttr> storeSymbols;
76 
77  callable.getSecond()->walk(
78  [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
79 
80  callable.getSecond()->walk(
81  [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
82 
83  opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
84  opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
85  }
86 
87  // For each callable function we find each global loaded/stored within the
88  // function or a nested called function. This includes recursion checking to
89  // avoid infinitely recursing.
90  for (auto callable : callableMap) {
91  SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
92  llvm::SmallVector<SymbolRefAttr> work = {thisSymbol};
93  llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol};
95  llvm::DenseSet<SymbolRefAttr> storeSymbols;
96 
97  for (size_t i = 0; i < work.size(); ++i) {
98  callableMap[work[i]]->walk([&](CallOpInterface call) {
99  auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
100  if (visited.insert(symbol).second)
101  work.push_back(symbol);
102  });
103 
104  loadSymbols.insert_range(opLoadSymbols[work[i]]);
105 
106  storeSymbols.insert_range(opStoreSymbols[work[i]]);
107  }
108 
109  loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
110  storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
111  }
112 
113  return success();
114 }
115 
116 // Process each operation in the block deleting unneeded loads / stores,
117 // recursing on subblocks and checking function calls.
118 void MLProgramPipelineGlobals::processBlock(
119  Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
120  llvm::DenseSet<SymbolRefAttr> &symbolStore) {
121 
125  for (auto &op : block) {
126  // If this is a global load, remap to a previous value if known
127  // and delete this load. Remember that this value is the currently
128  // known load.
129  if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
130  auto ref = load.getGlobal();
131  symbolLoad.insert(ref);
132  if (previousLoads.contains(ref)) {
133  toDelete.push_back(&op);
134  load.getResult().replaceAllUsesWith(previousLoads[ref]);
135  } else {
136  previousLoads[ref] = load.getResult();
137  }
138  continue;
139  }
140 
141  // Delete a previous store if it exists and is not needed, update
142  // the most recent known value for this global ref.
143  if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
144  auto ref = store.getGlobal();
145  symbolStore.insert(ref);
146  auto it = previousStores.find(ref);
147  if (it != previousStores.end()) {
148  toDelete.push_back(it->getSecond());
149  }
150 
151  previousLoads[ref] = store.getValue();
152  previousStores[ref] = &op;
153  continue;
154  }
155 
156  // If a function is called, clear known values for loads/stores used by
157  // the function or its sub-functions.
158  if (auto call = mlir::dyn_cast<CallOpInterface>(op)) {
159  auto loadSymbols =
160  loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
161  auto storeSymbols =
162  storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
163 
164  for (auto sym : loadSymbols) {
165  previousStores.erase(sym);
166  }
167 
168  for (auto sym : storeSymbols) {
169  previousLoads.erase(sym);
170  previousStores.erase(sym);
171  }
172  continue;
173  }
174 
175  // If the op has sub-regions, recurse inside. We make no guarantees whether
176  // the recursion occurs.
177  llvm::DenseSet<SymbolRefAttr> opSymbolLoad;
178  llvm::DenseSet<SymbolRefAttr> opSymbolStore;
179  for (auto &region : op.getRegions()) {
180  for (auto &block : region) {
181  processBlock(block, opSymbolLoad, opSymbolStore);
182  }
183  }
184 
185  // Update current state from the subblock.
186  for (auto change : opSymbolLoad) {
187  symbolLoad.insert(change);
188  previousStores.erase(change);
189  }
190 
191  for (auto change : opSymbolStore) {
192  symbolStore.insert(change);
193  previousLoads.erase(change);
194  previousStores.erase(change);
195  }
196  }
197 
198  for (auto *op : toDelete) {
199  op->erase();
200  }
201 }
202 
203 void MLProgramPipelineGlobals::runOnOperation() {
204  auto targetOp = getOperation();
205  if (failed(buildGlobalMap(targetOp))) {
206  return;
207  }
208 
209  for (auto &funcOp : *targetOp.getBody()) {
210  for (auto &region : funcOp.getRegions()) {
211  for (auto &block : region.getBlocks()) {
212  llvm::DenseSet<SymbolRefAttr> symbolsLoaded;
213  llvm::DenseSet<SymbolRefAttr> symbolsStored;
214  processBlock(block, symbolsLoaded, symbolsStored);
215  }
216  }
217  }
218 }
219 
220 } // namespace
221 
222 } // namespace ml_program
223 } // namespace mlir
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.