MLIR 23.0.0git
PipelineGlobalOps.cpp
Go to the documentation of this file.
1//===- PipelineGlobalOpsPass.cpp - Pipeline Global Ops Pass ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
12#include "mlir/IR/BuiltinOps.h"
13
14namespace mlir {
15namespace ml_program {
16#define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALSPASS
17#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
18
19namespace {
20
// Pass implementation class. It derives from the tablegen-generated
// `impl::MLProgramPipelineGlobalsPassBase` (pulled in through
// GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALSPASS above).
//
// NOTE(review): this extract of the class is incomplete. The `processBlock`
// declaration is cut off after its second parameter, and no data members are
// visible even though the method definitions use `loadSymbolsMap` and
// `storeSymbolsMap`. Confirm the full declaration against the original file.
21class MLProgramPipelineGlobals
22 : public impl::MLProgramPipelineGlobalsPassBase<MLProgramPipelineGlobals> {
23public:
 // Pass entry point; runs on the operation returned by `getOperation()`.
24 void runOnOperation() override;
26private:
 // Collects, per called symbol, the globals loaded/stored by the callee and
 // by everything it calls. Returns failure if the module cannot be analyzed.
27 LogicalResult buildGlobalMap(ModuleOp op);
28
 // Rewrites a single block, reporting every global it loads/stores through
 // the output sets `symbolLoad` / `symbolStore`.
29 void processBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
34};
35
36// Traverses upwards searching for the operation mapped by the symbol.
37static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
38 for (auto *op = baseOp; op; op = op->getParentOp()) {
39 auto *lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol);
40 if (lookup)
41 return lookup;
42 }
43 return nullptr;
44}
46// Builds map from a symbol to MLProgram global symbols loaded or stored
47// during processing.
//
// The work happens in three steps:
//   1. Walk every operation under `module` and record, for each call, the
//      operation its callee symbol resolves to (`callableMap`).
//   2. For each recorded callee, collect the globals it loads/stores directly.
//   3. For each recorded callee, follow its calls transitively and union the
//      per-callee sets into `loadSymbolsMap` / `storeSymbolsMap`.
//
// Returns failure() when step 1 is interrupted (a call whose callee is a
// Value, or a callee symbol that cannot be resolved); success() otherwise.
//
// NOTE(review): the declarations of `callableMap`, `opLoadSymbols`,
// `opStoreSymbols` and the per-callee `loadSymbols` / `storeSymbols` used in
// step 2 are not visible in this extract; confirm their types and scope
// against the full file.
48LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
 // Step 1: resolve the callee of every call-like operation.
50 auto res = module->walk([&](Operation *op) {
51 if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
52 auto callable = caller.getCallableForCallee();
53 // For now we do not know how to handle Value based tracing, so fail.
54 if (mlir::isa<Value>(callable)) {
55 return WalkResult::interrupt();
56 }
57
 // Past the Value check above, the callable is looked up as a symbol.
58 auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
59 auto *func = getFromSymbol(op, symbol);
60 // If the callee cannot be resolved, we cannot safely analyze the IR.
61 if (!func)
62 return WalkResult::interrupt();
63 callableMap[symbol] = func;
64 }
65 return WalkResult::advance();
66 });
67
 // Any interruption above means the module cannot be analyzed.
68 if (res.wasInterrupted()) {
69 return failure();
70 }
71
72 // First grab all symbols loaded or stored by each function. This
 // step only records loads/stores found by walking the callee itself; the
 // effects of the calls it makes are merged in by the loop further below.
76 for (auto callable : callableMap) {
80 callable.getSecond()->walk(
81 [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
82
83 callable.getSecond()->walk(
84 [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
85
 // Key the direct load/store sets by the callee symbol.
86 opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
87 opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
88 }
89
90 // For each callable function we find each global loaded/stored within the
91 // function or a nested called function. This includes recursion checking to
92 // avoid infinitely recursing.
93 for (auto callable : callableMap) {
94 SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
 // `work` is a growing worklist scanned by index; `visited` ensures each
 // callee symbol is queued at most once, which is what stops recursion.
95 llvm::SmallVector<SymbolRefAttr> work = {thisSymbol};
96 llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol};
97 llvm::DenseSet<SymbolRefAttr> loadSymbols;
98 llvm::DenseSet<SymbolRefAttr> storeSymbols;
99
100 for (size_t i = 0; i < work.size(); ++i) {
101 // Defensive: symbols in `work` should always be in `callableMap` since
102 // buildGlobalMap interrupted on any unresolvable callee, but use find to
103 // avoid inserting null entries via operator[].
104 auto it = callableMap.find(work[i]);
105 assert(it != callableMap.end() && "Expected callable in callableMap");
 // Queue every not-yet-visited callee reached from the current function.
106 it->second->walk([&](CallOpInterface call) {
107 auto symbol = cast<SymbolRefAttr>(call.getCallableForCallee());
108 if (visited.insert(symbol).second)
109 work.push_back(symbol);
110 });
111
 // Merge the direct load/store sets of the function being visited.
112 loadSymbols.insert_range(opLoadSymbols[work[i]]);
113
114 storeSymbols.insert_range(opStoreSymbols[work[i]]);
115 }
116
 // Publish the transitive sets for this callee symbol.
117 loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
118 storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
119 }
120
121 return success();
122}
123
124// Process each operation in the block deleting unneeded loads / stores,
125// recursing on subblocks and checking function calls.
126void MLProgramPipelineGlobals::processBlock(
127 Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
128 llvm::DenseSet<SymbolRefAttr> &symbolStore) {
129
130 llvm::DenseMap<SymbolRefAttr, Value> previousLoads;
131 llvm::DenseMap<SymbolRefAttr, Operation *> previousStores;
132 llvm::SmallVector<Operation *> toDelete;
133 for (auto &op : block) {
134 // If this is a global load, remap to a previous value if known
135 // and delete this load. Remember that this value is the currently
136 // known load.
137 if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
138 auto ref = load.getGlobal();
139 symbolLoad.insert(ref);
140 if (previousLoads.contains(ref)) {
141 toDelete.push_back(&op);
142 load.getResult().replaceAllUsesWith(previousLoads[ref]);
143 } else {
144 previousLoads[ref] = load.getResult();
145 }
146 continue;
147 }
148
149 // Delete a previous store if it exists and is not needed, update
150 // the most recent known value for this global ref.
151 if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
152 auto ref = store.getGlobal();
153 symbolStore.insert(ref);
154 auto it = previousStores.find(ref);
155 if (it != previousStores.end()) {
156 toDelete.push_back(it->getSecond());
157 }
158
159 previousLoads[ref] = store.getValue();
160 previousStores[ref] = &op;
161 continue;
162 }
163
164 // If a function is called, clear known values for loads/stores used by
165 // the function or its sub-functions.
166 if (auto call = mlir::dyn_cast<CallOpInterface>(op)) {
167 auto loadSymbols =
168 loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
169 auto storeSymbols =
170 storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
171
172 for (auto sym : loadSymbols) {
173 previousStores.erase(sym);
174 }
175
176 for (auto sym : storeSymbols) {
177 previousLoads.erase(sym);
178 previousStores.erase(sym);
179 }
180 continue;
181 }
182
183 // If the op has sub-regions, recurse inside. We make no guarantees whether
184 // the recursion occurs.
185 llvm::DenseSet<SymbolRefAttr> opSymbolLoad;
186 llvm::DenseSet<SymbolRefAttr> opSymbolStore;
187 for (auto &region : op.getRegions()) {
188 for (auto &block : region) {
189 processBlock(block, opSymbolLoad, opSymbolStore);
190 }
191 }
192
193 // Update current state from the subblock.
194 for (auto change : opSymbolLoad) {
195 symbolLoad.insert(change);
196 previousStores.erase(change);
197 }
198
199 for (auto change : opSymbolStore) {
200 symbolStore.insert(change);
201 previousLoads.erase(change);
202 previousStores.erase(change);
203 }
204 }
205
206 for (auto *op : toDelete) {
207 op->erase();
208 }
209}
210
211void MLProgramPipelineGlobals::runOnOperation() {
212 auto targetOp = getOperation();
213 if (failed(buildGlobalMap(targetOp))) {
214 return;
215 }
216
217 for (auto &funcOp : *targetOp.getBody()) {
218 for (auto &region : funcOp.getRegions()) {
219 for (auto &block : region.getBlocks()) {
220 llvm::DenseSet<SymbolRefAttr> symbolsLoaded;
221 llvm::DenseSet<SymbolRefAttr> symbolsStored;
222 processBlock(block, symbolsLoaded, symbolsStored);
223 }
224 }
225 }
226}
227
228} // namespace
229
230} // namespace ml_program
231} // namespace mlir
return success()
auto load
Block represents an ordered list of Operations.
Definition Block.h:33
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition Operation.h:252
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait.
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.