16#define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALSPASS
17#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
21class MLProgramPipelineGlobals
24 void runOnOperation()
override;
27 LogicalResult buildGlobalMap(ModuleOp op);
38 for (
auto *op = baseOp; op; op = op->
getParentOp()) {
48LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
50 auto res =
module->walk([&](Operation *op) {
51 if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
52 auto callable = caller.getCallableForCallee();
54 if (mlir::isa<Value>(callable)) {
58 auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
59 auto *
func = getFromSymbol(op, symbol);
60 callableMap[symbol] =
func;
65 if (res.wasInterrupted()) {
73 for (
auto callable : callableMap) {
77 callable.getSecond()->walk(
78 [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
80 callable.getSecond()->walk(
81 [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
83 opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
84 opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
90 for (
auto callable : callableMap) {
91 SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
93 llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol};
94 llvm::DenseSet<SymbolRefAttr> loadSymbols;
95 llvm::DenseSet<SymbolRefAttr> storeSymbols;
97 for (
size_t i = 0; i < work.size(); ++i) {
98 callableMap[work[i]]->walk([&](CallOpInterface call) {
99 auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
100 if (visited.insert(symbol).second)
101 work.push_back(symbol);
104 loadSymbols.insert_range(opLoadSymbols[work[i]]);
106 storeSymbols.insert_range(opStoreSymbols[work[i]]);
109 loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
110 storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
118void MLProgramPipelineGlobals::processBlock(
119 Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
120 llvm::DenseSet<SymbolRefAttr> &symbolStore) {
122 llvm::DenseMap<SymbolRefAttr, Value> previousLoads;
123 llvm::DenseMap<SymbolRefAttr, Operation *> previousStores;
124 llvm::SmallVector<Operation *> toDelete;
125 for (
auto &op : block) {
129 if (
auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
130 auto ref =
load.getGlobal();
131 symbolLoad.insert(ref);
132 if (previousLoads.contains(ref)) {
133 toDelete.push_back(&op);
134 load.getResult().replaceAllUsesWith(previousLoads[ref]);
136 previousLoads[ref] =
load.getResult();
143 if (
auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
144 auto ref = store.getGlobal();
145 symbolStore.insert(ref);
146 auto it = previousStores.find(ref);
147 if (it != previousStores.end()) {
148 toDelete.push_back(it->getSecond());
151 previousLoads[ref] = store.getValue();
152 previousStores[ref] = &op;
158 if (
auto call = mlir::dyn_cast<CallOpInterface>(op)) {
160 loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
162 storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
164 for (
auto sym : loadSymbols) {
165 previousStores.erase(sym);
168 for (
auto sym : storeSymbols) {
169 previousLoads.erase(sym);
170 previousStores.erase(sym);
177 llvm::DenseSet<SymbolRefAttr> opSymbolLoad;
178 llvm::DenseSet<SymbolRefAttr> opSymbolStore;
179 for (
auto ®ion : op.getRegions()) {
180 for (
auto &block : region) {
181 processBlock(block, opSymbolLoad, opSymbolStore);
186 for (
auto change : opSymbolLoad) {
187 symbolLoad.insert(change);
188 previousStores.erase(change);
191 for (
auto change : opSymbolStore) {
192 symbolStore.insert(change);
193 previousLoads.erase(change);
194 previousStores.erase(change);
198 for (
auto *op : toDelete) {
203void MLProgramPipelineGlobals::runOnOperation() {
204 auto targetOp = getOperation();
205 if (
failed(buildGlobalMap(targetOp))) {
209 for (
auto &funcOp : *targetOp.getBody()) {
210 for (
auto ®ion : funcOp.getRegions()) {
211 for (
auto &block : region.getBlocks()) {
212 llvm::DenseSet<SymbolRefAttr> symbolsLoaded;
213 llvm::DenseSet<SymbolRefAttr> symbolsStored;
214 processBlock(block, symbolsLoaded, symbolsStored);
Block represents an ordered list of Operations.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static WalkResult advance()
static WalkResult interrupt()
Include the generated interface declarations.