15 namespace ml_program {
16 #define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALSPASS
17 #include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
21 class MLProgramPipelineGlobals
22 :
public impl::MLProgramPipelineGlobalsPassBase<MLProgramPipelineGlobals> {
24 void runOnOperation()
override;
27 LogicalResult buildGlobalMap(ModuleOp op);
37 static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
38 for (
auto *op = baseOp; op; op = op->getParentOp()) {
48 LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
50 auto res = module->walk([&](Operation *op) {
51 if (
auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
52 auto callable = caller.getCallableForCallee();
54 if (mlir::isa<Value>(callable)) {
58 auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
59 auto *func = getFromSymbol(op, symbol);
60 callableMap[symbol] = func;
65 if (res.wasInterrupted()) {
73 for (
auto callable : callableMap) {
77 callable.getSecond()->walk(
78 [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
80 callable.getSecond()->walk(
81 [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
83 opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
84 opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
90 for (
auto callable : callableMap) {
91 SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
97 for (
size_t i = 0; i < work.size(); ++i) {
98 callableMap[work[i]]->walk([&](CallOpInterface call) {
99 auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
100 if (visited.insert(symbol).second)
101 work.push_back(symbol);
104 loadSymbols.insert_range(opLoadSymbols[work[i]]);
106 storeSymbols.insert_range(opStoreSymbols[work[i]]);
109 loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
110 storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
118 void MLProgramPipelineGlobals::processBlock(
125 for (
auto &op : block) {
129 if (
auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
130 auto ref = load.getGlobal();
131 symbolLoad.insert(ref);
132 if (previousLoads.contains(ref)) {
133 toDelete.push_back(&op);
134 load.getResult().replaceAllUsesWith(previousLoads[ref]);
136 previousLoads[ref] = load.getResult();
143 if (
auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
144 auto ref = store.getGlobal();
145 symbolStore.insert(ref);
146 auto it = previousStores.find(ref);
147 if (it != previousStores.end()) {
148 toDelete.push_back(it->getSecond());
151 previousLoads[ref] = store.getValue();
152 previousStores[ref] = &op;
158 if (
auto call = mlir::dyn_cast<CallOpInterface>(op)) {
160 loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
162 storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
164 for (
auto sym : loadSymbols) {
165 previousStores.erase(sym);
168 for (
auto sym : storeSymbols) {
169 previousLoads.erase(sym);
170 previousStores.erase(sym);
179 for (
auto ®ion : op.getRegions()) {
180 for (
auto &block : region) {
181 processBlock(block, opSymbolLoad, opSymbolStore);
186 for (
auto change : opSymbolLoad) {
187 symbolLoad.insert(change);
188 previousStores.erase(change);
191 for (
auto change : opSymbolStore) {
192 symbolStore.insert(change);
193 previousLoads.erase(change);
194 previousStores.erase(change);
198 for (
auto *op : toDelete) {
203 void MLProgramPipelineGlobals::runOnOperation() {
204 auto targetOp = getOperation();
205 if (
failed(buildGlobalMap(targetOp))) {
209 for (
auto &funcOp : *targetOp.getBody()) {
210 for (
auto ®ion : funcOp.getRegions()) {
211 for (
auto &block : region.getBlocks()) {
214 processBlock(block, symbolsLoaded, symbolsStored);
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static WalkResult advance()
static WalkResult interrupt()
Include the generated interface declarations.