18 namespace ml_program {
19 #define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS
20 #include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
// Implementation class for the pass selected by
// GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS. It derives from the generated base
// impl::MLProgramPipelineGlobalsBase, which is parameterized on this class.
// NOTE(review): the embedded source numbering jumps inside this declaration,
// so further members (and the closing brace) are not visible in this listing;
// only the members shown here are described.
24 class MLProgramPipelineGlobals
25 :
public impl::MLProgramPipelineGlobalsBase<MLProgramPipelineGlobals> {
// Overrides the base-class hook; the definition appears out of line below.
27 void runOnOperation()
override;
// Populates the symbol maps for the given module. Returns a LogicalResult,
// which runOnOperation() tests with failed().
30 LogicalResult buildGlobalMap(ModuleOp op);
// Looks up the operation that `symbol` refers to, starting from `baseOp`.
// The loop visits `baseOp` first and then each enclosing operation in turn
// via getParentOp(), stopping when there is no further parent.
// NOTE(review): the loop body and the return statements are not part of this
// listing. Presumably each visited operation is queried for `symbol` and the
// first hit is returned, with nullptr when nothing matches — confirm against
// the complete file.
40 static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
41 for (
auto *op = baseOp; op; op = op->getParentOp()) {
// Builds, for every callable reached through a call in `module`, the set of
// global symbols it loads and the set it stores, including those reached
// through the calls it makes. Results are written to loadSymbolsMap and
// storeSymbolsMap, keyed by the callable's symbol.
// NOTE(review): the declarations of callableMap, opLoadSymbols,
// opStoreSymbols, loadSymbols, storeSymbols, work and visited, as well as
// several branch bodies and the return statements, are missing from this
// listing (the embedded numbering jumps); comments below describe only the
// visible statements.
51 LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
// Step 1: walk every operation and record the callee of each call op.
53 auto res = module->walk([&](Operation *op) {
54 if (
auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
55 auto callable = caller.getCallableForCallee();
// A callee given as an SSA Value is handled separately from a symbol.
// NOTE(review): the body of this branch is not visible; presumably it
// interrupts the walk because the target cannot be resolved — confirm.
57 if (mlir::isa<Value>(callable)) {
// Otherwise the callee is treated as a symbol reference and resolved
// relative to the call site.
// NOTE(review): no null check on `symbol` or `func` is visible here; a
// null `func` would be dereferenced by the walks below — verify that the
// missing lines guard against it.
61 auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
62 auto *func = getFromSymbol(op, symbol);
63 callableMap[symbol] = func;
// The walk result is checked for interruption; the handling itself is not
// visible in this listing.
68 if (res.wasInterrupted()) {
// Step 2: collect the globals each callable loads and stores directly.
// NOTE(review): `auto callable` copies each map entry per iteration;
// `const auto &` would avoid that.
76 for (
auto callable : callableMap) {
80 callable.getSecond()->walk(
81 [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
83 callable.getSecond()->walk(
84 [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
86 opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
87 opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
// Step 3: for each callable, accumulate the loads/stores of everything it
// can reach through calls.
93 for (
auto callable : callableMap) {
94 SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
// `work` is a growing worklist: the index loop re-reads work.size() every
// iteration, so symbols appended inside the walk are processed later in the
// same loop. `visited` ensures each callee is appended at most once.
100 for (
size_t i = 0; i < work.size(); ++i) {
101 callableMap[work[i]]->walk([&](CallOpInterface call) {
// NOTE(review): the dyn_cast result is used without a visible null check.
102 auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
103 if (visited.insert(symbol).second)
104 work.push_back(symbol);
// Fold the direct loads and stores of the current worklist entry into the
// accumulated sets for `thisSymbol`.
107 for (
auto load : opLoadSymbols[work[i]])
108 loadSymbols.insert(load);
110 for (
auto store : opStoreSymbols[work[i]])
111 storeSymbols.insert(store);
// Publish the accumulated sets for this callable.
114 loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
115 storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
// Processes the operations of one block in order, tracking per global symbol
// the last known value (previousLoads) and the last pending store op
// (previousStores). Loads whose value is already known are replaced and
// queued for deletion; stores superseded by a later store to the same symbol
// are queued for deletion. Symbols touched are reported through symbolLoad
// and symbolStore.
// NOTE(review): the parameter list, the declarations of previousLoads,
// previousStores, toDelete, opSymbolLoad and opSymbolStore, and several
// closing/else lines are missing from this listing (the embedded numbering
// jumps); comments below describe only the visible statements.
123 void MLProgramPipelineGlobals::processBlock(
130 for (
auto &op : block) {
// Case 1: load of a global.
134 if (
auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
135 auto ref = load.getGlobal();
136 symbolLoad.insert(ref);
// A value for this symbol is already known: reuse it and drop this load.
137 if (previousLoads.contains(ref)) {
138 toDelete.push_back(&op);
139 load.getResult().replaceAllUsesWith(previousLoads[ref]);
// NOTE(review): source line 140 is not visible; presumably this assignment
// is the else-branch, recording the first load as the known value — confirm.
141 previousLoads[ref] = load.getResult();
// Case 2: store to a global.
148 if (
auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
149 auto ref = store.getGlobal();
150 symbolStore.insert(ref);
// An earlier store to the same symbol is still pending, so it is queued for
// deletion in favour of this one.
151 auto it = previousStores.find(ref);
152 if (it != previousStores.end()) {
153 toDelete.push_back(it->getSecond());
// The stored value becomes the known value for later loads, and this op
// becomes the pending store for the symbol.
156 previousLoads[ref] = store.getValue();
157 previousStores[ref] = &op;
// Case 3: call. Use the precomputed load/store symbol sets of the callee.
// NOTE(review): source lines 164 and 166 (the left-hand sides of these two
// lookups) are not visible; the loops below read them as `loadSymbols` and
// `storeSymbols`. The dyn_cast results are used as map keys without a
// visible null check.
163 if (
auto call = mlir::dyn_cast<CallOpInterface>(op)) {
165 loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
167 storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
// Symbols the callee loads: forget the pending store so it is no longer a
// candidate for deletion.
169 for (
auto sym : loadSymbols) {
170 previousStores.erase(sym);
// Symbols the callee stores: forget both the known value and the pending
// store.
173 for (
auto sym : storeSymbols) {
174 previousLoads.erase(sym);
175 previousStores.erase(sym);
// Nested regions: process each nested block recursively, collecting the
// symbols it touches into opSymbolLoad / opSymbolStore.
184 for (
auto &region : op.getRegions()) {
185 for (
auto &block : region) {
186 processBlock(block, opSymbolLoad, opSymbolStore);
// Merge the nested results into this block's state, applying the same
// invalidation rules as for calls.
191 for (
auto change : opSymbolLoad) {
192 symbolLoad.insert(change);
193 previousStores.erase(change);
196 for (
auto change : opSymbolStore) {
197 symbolStore.insert(change);
198 previousLoads.erase(change);
199 previousStores.erase(change);
// Visit every operation queued for deletion.
// NOTE(review): the loop body is not visible; presumably it erases `op`.
203 for (
auto *op : toDelete) {
// Runs the pass on the operation returned by getOperation(): first builds the
// symbol maps, then processes every block of every region of each operation
// in the target's body.
// NOTE(review): source lines 211-213 and 217-218 are missing from this
// listing (the embedded numbering jumps), so the failure handling and the
// declarations of symbolsLoaded / symbolsStored are not visible here.
208 void MLProgramPipelineGlobals::runOnOperation() {
209 auto targetOp = getOperation();
// Stop early when the symbol maps cannot be built; the handling itself is
// on lines that are not visible.
210 if (failed(buildGlobalMap(targetOp))) {
// Visit each operation in the target's body and descend into its blocks.
214 for (
auto &funcOp : *targetOp.getBody()) {
215 for (
auto &region : funcOp.getRegions()) {
216 for (
auto &block : region.getBlocks()) {
219 processBlock(block, symbolsLoaded, symbolsStored);
// Factory: returns a newly created MLProgramPipelineGlobals instance held by
// a std::unique_ptr to a module-level OperationPass.
// NOTE(review): the line carrying the function name and parameter list
// (embedded source line 228) is missing from this listing.
227 std::unique_ptr<OperationPass<mlir::ModuleOp>>
229 return std::make_unique<MLProgramPipelineGlobals>();
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static WalkResult advance()
static WalkResult interrupt()
std::unique_ptr< OperationPass< ModuleOp > > createMLProgramPipelineGlobalsPass()
Include the generated interface declarations.