MLIR 22.0.0git
PipelineGlobalOps.cpp
Go to the documentation of this file.
1//===- PipelineGlobalOpsPass.cpp - Pipeline Global Ops Pass ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
12#include "mlir/IR/BuiltinOps.h"
13
14namespace mlir {
15namespace ml_program {
16#define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALSPASS
17#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
18
19namespace {
20
// Pass class built on the generated MLProgramPipelineGlobalsPassBase. The
// methods defined later in this file use it to summarize which globals each
// callee loads/stores and then to drop redundant global loads and stores
// block by block.
21class MLProgramPipelineGlobals
22 : public impl::MLProgramPipelineGlobalsPassBase<MLProgramPipelineGlobals> {
23public:
// Pass entry point: builds the global map for the root operation and then
// processes every block of every region of the ops in its body.
24 void runOnOperation() override;
26private:
// Builds the per-callee load/store summaries; returns failure when a call
// whose callee is a Value is encountered.
27 LogicalResult buildGlobalMap(ModuleOp op);
28
// Removes redundant global loads/stores inside `block`; every global loaded
// in the block (including nested regions) is added to `symbolLoad`.
29 void processBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
// NOTE(review): this listing skips original source lines 30-33. The rest of
// the processBlock declaration (its definition takes a third parameter,
// `symbolStore`) and presumably the loadSymbolsMap / storeSymbolsMap members
// used by the methods are not visible here -- confirm against the real file.
34};
35
36// Traverses upwards searchign for the operation mapped by the symbol.
37static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
38 for (auto *op = baseOp; op; op = op->getParentOp()) {
39 auto *lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol);
40 if (lookup)
41 return lookup;
42 }
43 return nullptr;
44}
46// Builds map from a symbol to MLProgram global symbols loaded or stored
47// during processing.
48LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
50 auto res = module->walk([&](Operation *op) {
51 if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
52 auto callable = caller.getCallableForCallee();
53 // For now we do not know how to handle Value based tracing, so fail.
54 if (mlir::isa<Value>(callable)) {
55 return WalkResult::interrupt();
56 }
57
58 auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
59 auto *func = getFromSymbol(op, symbol);
60 callableMap[symbol] = func;
61 }
62 return WalkResult::advance();
63 });
64
65 if (res.wasInterrupted()) {
66 return failure();
67 }
68
69 // First grab all symbols loaded or stored by each function. This
70 // will not handle calls initially.
73 for (auto callable : callableMap) {
76
77 callable.getSecond()->walk(
78 [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
80 callable.getSecond()->walk(
81 [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
82
83 opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
84 opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
85 }
86
87 // For each callable function we find each global loaded/stored within the
88 // function or a nested called function. This includes recursion checking to
89 // avoid infinitely recursing.
90 for (auto callable : callableMap) {
91 SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
92 llvm::SmallVector<SymbolRefAttr> work = {thisSymbol};
93 llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol};
94 llvm::DenseSet<SymbolRefAttr> loadSymbols;
95 llvm::DenseSet<SymbolRefAttr> storeSymbols;
96
97 for (size_t i = 0; i < work.size(); ++i) {
98 callableMap[work[i]]->walk([&](CallOpInterface call) {
99 auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
100 if (visited.insert(symbol).second)
101 work.push_back(symbol);
102 });
103
104 loadSymbols.insert_range(opLoadSymbols[work[i]]);
105
106 storeSymbols.insert_range(opStoreSymbols[work[i]]);
107 }
108
109 loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
110 storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
111 }
112
113 return success();
114}
115
116// Process each operation in the block deleting unneeded loads / stores,
117// recursing on subblocks and checking function calls.
118void MLProgramPipelineGlobals::processBlock(
119 Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
120 llvm::DenseSet<SymbolRefAttr> &symbolStore) {
121
122 llvm::DenseMap<SymbolRefAttr, Value> previousLoads;
123 llvm::DenseMap<SymbolRefAttr, Operation *> previousStores;
124 llvm::SmallVector<Operation *> toDelete;
125 for (auto &op : block) {
126 // If this is a global load, remap to a previous value if known
127 // and delete this load. Remember that this value is the currently
128 // known load.
129 if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
130 auto ref = load.getGlobal();
131 symbolLoad.insert(ref);
132 if (previousLoads.contains(ref)) {
133 toDelete.push_back(&op);
134 load.getResult().replaceAllUsesWith(previousLoads[ref]);
135 } else {
136 previousLoads[ref] = load.getResult();
137 }
138 continue;
139 }
140
141 // Delete a previous store if it exists and is not needed, update
142 // the most recent known value for this global ref.
143 if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
144 auto ref = store.getGlobal();
145 symbolStore.insert(ref);
146 auto it = previousStores.find(ref);
147 if (it != previousStores.end()) {
148 toDelete.push_back(it->getSecond());
149 }
150
151 previousLoads[ref] = store.getValue();
152 previousStores[ref] = &op;
153 continue;
154 }
155
156 // If a function is called, clear known values for loads/stores used by
157 // the function or its sub-functions.
158 if (auto call = mlir::dyn_cast<CallOpInterface>(op)) {
159 auto loadSymbols =
160 loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
161 auto storeSymbols =
162 storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
163
164 for (auto sym : loadSymbols) {
165 previousStores.erase(sym);
166 }
167
168 for (auto sym : storeSymbols) {
169 previousLoads.erase(sym);
170 previousStores.erase(sym);
171 }
172 continue;
173 }
174
175 // If the op has sub-regions, recurse inside. We make no guarantees whether
176 // the recursion occurs.
177 llvm::DenseSet<SymbolRefAttr> opSymbolLoad;
178 llvm::DenseSet<SymbolRefAttr> opSymbolStore;
179 for (auto &region : op.getRegions()) {
180 for (auto &block : region) {
181 processBlock(block, opSymbolLoad, opSymbolStore);
182 }
183 }
184
185 // Update current state from the subblock.
186 for (auto change : opSymbolLoad) {
187 symbolLoad.insert(change);
188 previousStores.erase(change);
189 }
190
191 for (auto change : opSymbolStore) {
192 symbolStore.insert(change);
193 previousLoads.erase(change);
194 previousStores.erase(change);
195 }
196 }
197
198 for (auto *op : toDelete) {
199 op->erase();
200 }
201}
202
203void MLProgramPipelineGlobals::runOnOperation() {
204 auto targetOp = getOperation();
205 if (failed(buildGlobalMap(targetOp))) {
206 return;
207 }
208
209 for (auto &funcOp : *targetOp.getBody()) {
210 for (auto &region : funcOp.getRegions()) {
211 for (auto &block : region.getBlocks()) {
212 llvm::DenseSet<SymbolRefAttr> symbolsLoaded;
213 llvm::DenseSet<SymbolRefAttr> symbolsStored;
214 processBlock(block, symbolsLoaded, symbolsStored);
215 }
216 }
217 }
218}
219
220} // namespace
221
222} // namespace ml_program
223} // namespace mlir
return success()
auto load
Block represents an ordered list of Operations.
Definition Block.h:33
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition Operation.h:234
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.