18 namespace ml_program {
19 #define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS
20 #include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
24 class MLProgramPipelineGlobals
25 :
public impl::MLProgramPipelineGlobalsBase<MLProgramPipelineGlobals> {
27 void runOnOperation()
override;
30 LogicalResult buildGlobalMap(ModuleOp op);
40 static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
41 for (
auto *op = baseOp; op; op = op->getParentOp()) {
51 LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
53 auto res = module->walk([&](Operation *op) {
54 if (
auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
55 auto callable = caller.getCallableForCallee();
57 if (mlir::isa<Value>(callable)) {
61 auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
62 auto *func = getFromSymbol(op, symbol);
63 callableMap[symbol] = func;
68 if (res.wasInterrupted()) {
76 for (
auto callable : callableMap) {
80 callable.getSecond()->walk(
81 [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
83 callable.getSecond()->walk(
84 [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
86 opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
87 opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
93 for (
auto callable : callableMap) {
94 SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
100 for (
size_t i = 0; i < work.size(); ++i) {
101 callableMap[work[i]]->walk([&](CallOpInterface call) {
102 auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
103 if (!visited.contains(symbol)) {
104 visited.insert(symbol);
105 work.push_back(symbol);
109 for (
auto load : opLoadSymbols[work[i]])
110 loadSymbols.insert(load);
112 for (
auto store : opStoreSymbols[work[i]])
113 storeSymbols.insert(store);
116 loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
117 storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
125 void MLProgramPipelineGlobals::processBlock(
132 for (
auto &op : block) {
136 if (
auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
137 auto ref = load.getGlobal();
138 symbolLoad.insert(ref);
139 if (previousLoads.contains(ref)) {
140 toDelete.push_back(&op);
141 load.getResult().replaceAllUsesWith(previousLoads[ref]);
143 previousLoads[ref] = load.getResult();
150 if (
auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
151 auto ref = store.getGlobal();
152 symbolStore.insert(ref);
153 if (previousStores.contains(ref)) {
154 toDelete.push_back(previousStores.find(ref)->getSecond());
157 previousLoads[ref] = store.getValue();
158 previousStores[ref] = &op;
164 if (
auto call = mlir::dyn_cast<CallOpInterface>(op)) {
166 loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
168 storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
170 for (
auto sym : loadSymbols) {
171 previousStores.erase(sym);
174 for (
auto sym : storeSymbols) {
175 previousLoads.erase(sym);
176 previousStores.erase(sym);
185 for (
auto &region : op.getRegions()) {
186 for (
auto &block : region) {
187 processBlock(block, opSymbolLoad, opSymbolStore);
192 for (
auto change : opSymbolLoad) {
193 symbolLoad.insert(change);
194 previousStores.erase(change);
197 for (
auto change : opSymbolStore) {
198 symbolStore.insert(change);
199 previousLoads.erase(change);
200 previousStores.erase(change);
204 for (
auto *op : toDelete) {
209 void MLProgramPipelineGlobals::runOnOperation() {
210 auto targetOp = getOperation();
211 if (
failed(buildGlobalMap(targetOp))) {
215 for (
auto &funcOp : *targetOp.getBody()) {
216 for (
auto &region : funcOp.getRegions()) {
217 for (
auto &block : region.getBlocks()) {
220 processBlock(block, symbolsLoaded, symbolsStored);
228 std::unique_ptr<OperationPass<mlir::ModuleOp>>
230 return std::make_unique<MLProgramPipelineGlobals>();
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static WalkResult advance()
static WalkResult interrupt()
std::unique_ptr< OperationPass< ModuleOp > > createMLProgramPipelineGlobalsPass()
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.