18 namespace ml_program {
19 #define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALSPASS
20 #include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
24 class MLProgramPipelineGlobals
25 :
public impl::MLProgramPipelineGlobalsPassBase<MLProgramPipelineGlobals> {
27 void runOnOperation()
override;
30 LogicalResult buildGlobalMap(ModuleOp op);
40 static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
41 for (
auto *op = baseOp; op; op = op->getParentOp()) {
51 LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
53 auto res = module->walk([&](Operation *op) {
54 if (
auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
55 auto callable = caller.getCallableForCallee();
57 if (mlir::isa<Value>(callable)) {
61 auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
62 auto *func = getFromSymbol(op, symbol);
63 callableMap[symbol] = func;
68 if (res.wasInterrupted()) {
76 for (
auto callable : callableMap) {
80 callable.getSecond()->walk(
81 [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
83 callable.getSecond()->walk(
84 [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
86 opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
87 opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
93 for (
auto callable : callableMap) {
94 SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
100 for (
size_t i = 0; i < work.size(); ++i) {
101 callableMap[work[i]]->walk([&](CallOpInterface call) {
102 auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
103 if (visited.insert(symbol).second)
104 work.push_back(symbol);
107 loadSymbols.insert_range(opLoadSymbols[work[i]]);
109 storeSymbols.insert_range(opStoreSymbols[work[i]]);
112 loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
113 storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
121 void MLProgramPipelineGlobals::processBlock(
128 for (
auto &op : block) {
132 if (
auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
133 auto ref = load.getGlobal();
134 symbolLoad.insert(ref);
135 if (previousLoads.contains(ref)) {
136 toDelete.push_back(&op);
137 load.getResult().replaceAllUsesWith(previousLoads[ref]);
139 previousLoads[ref] = load.getResult();
146 if (
auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
147 auto ref = store.getGlobal();
148 symbolStore.insert(ref);
149 auto it = previousStores.find(ref);
150 if (it != previousStores.end()) {
151 toDelete.push_back(it->getSecond());
154 previousLoads[ref] = store.getValue();
155 previousStores[ref] = &op;
161 if (
auto call = mlir::dyn_cast<CallOpInterface>(op)) {
163 loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
165 storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
167 for (
auto sym : loadSymbols) {
168 previousStores.erase(sym);
171 for (
auto sym : storeSymbols) {
172 previousLoads.erase(sym);
173 previousStores.erase(sym);
182 for (
auto ®ion : op.getRegions()) {
183 for (
auto &block : region) {
184 processBlock(block, opSymbolLoad, opSymbolStore);
189 for (
auto change : opSymbolLoad) {
190 symbolLoad.insert(change);
191 previousStores.erase(change);
194 for (
auto change : opSymbolStore) {
195 symbolStore.insert(change);
196 previousLoads.erase(change);
197 previousStores.erase(change);
201 for (
auto *op : toDelete) {
206 void MLProgramPipelineGlobals::runOnOperation() {
207 auto targetOp = getOperation();
208 if (failed(buildGlobalMap(targetOp))) {
212 for (
auto &funcOp : *targetOp.getBody()) {
213 for (
auto ®ion : funcOp.getRegions()) {
214 for (
auto &block : region.getBlocks()) {
217 processBlock(block, symbolsLoaded, symbolsStored);
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static WalkResult advance()
static WalkResult interrupt()
Include the generated interface declarations.