MLIR  19.0.0git
PipelineGlobalOps.cpp
Go to the documentation of this file.
1 //===- PipelineGlobalOpsPass.cpp - Pipeline Global Ops Pass ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/Pass/Pass.h"
16 
17 namespace mlir {
18 namespace ml_program {
19 #define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS
20 #include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
21 
22 namespace {
23 
24 class MLProgramPipelineGlobals
25  : public impl::MLProgramPipelineGlobalsBase<MLProgramPipelineGlobals> {
26 public:
27  void runOnOperation() override;
28 
29 private:
30  LogicalResult buildGlobalMap(ModuleOp op);
31 
32  void processBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
33  llvm::DenseSet<SymbolRefAttr> &symbolStore);
34 
37 };
38 
39 // Traverses upwards searchign for the operation mapped by the symbol.
40 static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
41  for (auto *op = baseOp; op; op = op->getParentOp()) {
42  auto *lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol);
43  if (lookup)
44  return lookup;
45  }
46  return nullptr;
47 }
48 
49 // Builds map from a symbol to MLProgram global symbols loaded or stored
50 // during processing.
51 LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
53  auto res = module->walk([&](Operation *op) {
54  if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
55  auto callable = caller.getCallableForCallee();
56  // For now we do not know how to handle Value based tracing, so fail.
57  if (mlir::isa<Value>(callable)) {
58  return WalkResult::interrupt();
59  }
60 
61  auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
62  auto *func = getFromSymbol(op, symbol);
63  callableMap[symbol] = func;
64  }
65  return WalkResult::advance();
66  });
67 
68  if (res.wasInterrupted()) {
69  return failure();
70  }
71 
72  // First grab all symbols loaded or stored by each function. This
73  // will not handle calls initially.
76  for (auto callable : callableMap) {
78  llvm::DenseSet<SymbolRefAttr> storeSymbols;
79 
80  callable.getSecond()->walk(
81  [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
82 
83  callable.getSecond()->walk(
84  [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
85 
86  opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
87  opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
88  }
89 
90  // For each callable function we find each global loaded/stored within the
91  // function or a nested called function. This includes recursion checking to
92  // avoid infinitely recursing.
93  for (auto callable : callableMap) {
94  SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
95  llvm::SmallVector<SymbolRefAttr> work = {thisSymbol};
96  llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol};
98  llvm::DenseSet<SymbolRefAttr> storeSymbols;
99 
100  for (size_t i = 0; i < work.size(); ++i) {
101  callableMap[work[i]]->walk([&](CallOpInterface call) {
102  auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
103  if (!visited.contains(symbol)) {
104  visited.insert(symbol);
105  work.push_back(symbol);
106  }
107  });
108 
109  for (auto load : opLoadSymbols[work[i]])
110  loadSymbols.insert(load);
111 
112  for (auto store : opStoreSymbols[work[i]])
113  storeSymbols.insert(store);
114  }
115 
116  loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
117  storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
118  }
119 
120  return success();
121 }
122 
123 // Process each operation in the block deleting unneeded loads / stores,
124 // recursing on subblocks and checking function calls.
125 void MLProgramPipelineGlobals::processBlock(
126  Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
127  llvm::DenseSet<SymbolRefAttr> &symbolStore) {
128 
132  for (auto &op : block) {
133  // If this is a global load, remap to a previous value if known
134  // and delete this load. Remember that this value is the currently
135  // known load.
136  if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
137  auto ref = load.getGlobal();
138  symbolLoad.insert(ref);
139  if (previousLoads.contains(ref)) {
140  toDelete.push_back(&op);
141  load.getResult().replaceAllUsesWith(previousLoads[ref]);
142  } else {
143  previousLoads[ref] = load.getResult();
144  }
145  continue;
146  }
147 
148  // Delete a previous store if it exists and is not needed, update
149  // the most recent known value for this global ref.
150  if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
151  auto ref = store.getGlobal();
152  symbolStore.insert(ref);
153  if (previousStores.contains(ref)) {
154  toDelete.push_back(previousStores.find(ref)->getSecond());
155  }
156 
157  previousLoads[ref] = store.getValue();
158  previousStores[ref] = &op;
159  continue;
160  }
161 
162  // If a function is called, clear known values for loads/stores used by
163  // the function or its sub-functions.
164  if (auto call = mlir::dyn_cast<CallOpInterface>(op)) {
165  auto loadSymbols =
166  loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
167  auto storeSymbols =
168  storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
169 
170  for (auto sym : loadSymbols) {
171  previousStores.erase(sym);
172  }
173 
174  for (auto sym : storeSymbols) {
175  previousLoads.erase(sym);
176  previousStores.erase(sym);
177  }
178  continue;
179  }
180 
181  // If the op has sub-regions, recurse inside. We make no guarantees whether
182  // the recursion occurs.
183  llvm::DenseSet<SymbolRefAttr> opSymbolLoad;
184  llvm::DenseSet<SymbolRefAttr> opSymbolStore;
185  for (auto &region : op.getRegions()) {
186  for (auto &block : region) {
187  processBlock(block, opSymbolLoad, opSymbolStore);
188  }
189  }
190 
191  // Update current state from the subblock.
192  for (auto change : opSymbolLoad) {
193  symbolLoad.insert(change);
194  previousStores.erase(change);
195  }
196 
197  for (auto change : opSymbolStore) {
198  symbolStore.insert(change);
199  previousLoads.erase(change);
200  previousStores.erase(change);
201  }
202  }
203 
204  for (auto *op : toDelete) {
205  op->erase();
206  }
207 }
208 
209 void MLProgramPipelineGlobals::runOnOperation() {
210  auto targetOp = getOperation();
211  if (failed(buildGlobalMap(targetOp))) {
212  return;
213  }
214 
215  for (auto &funcOp : *targetOp.getBody()) {
216  for (auto &region : funcOp.getRegions()) {
217  for (auto &block : region.getBlocks()) {
218  llvm::DenseSet<SymbolRefAttr> symbolsLoaded;
219  llvm::DenseSet<SymbolRefAttr> symbolsStored;
220  processBlock(block, symbolsLoaded, symbolsStored);
221  }
222  }
223  }
224 }
225 
226 } // namespace
227 
228 std::unique_ptr<OperationPass<mlir::ModuleOp>>
230  return std::make_unique<MLProgramPipelineGlobals>();
231 }
232 
233 } // namespace ml_program
234 } // namespace mlir
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static WalkResult advance()
Definition: Visitors.h:52
static WalkResult interrupt()
Definition: Visitors.h:51
std::unique_ptr< OperationPass< ModuleOp > > createMLProgramPipelineGlobalsPass()
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72