16#define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALSPASS
17#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
21class MLProgramPipelineGlobals
24 void runOnOperation()
override;
27 LogicalResult buildGlobalMap(ModuleOp op);
38 for (
auto *op = baseOp; op; op = op->
getParentOp()) {
48LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
50 auto res =
module->walk([&](Operation *op) {
51 if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
52 auto callable = caller.getCallableForCallee();
54 if (mlir::isa<Value>(callable)) {
58 auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
59 auto *
func = getFromSymbol(op, symbol);
63 callableMap[symbol] =
func;
68 if (res.wasInterrupted()) {
76 for (
auto callable : callableMap) {
80 callable.getSecond()->walk(
81 [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
83 callable.getSecond()->walk(
84 [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
86 opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
87 opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
93 for (
auto callable : callableMap) {
94 SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
95 llvm::SmallVector<SymbolRefAttr> work = {thisSymbol};
96 llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol};
97 llvm::DenseSet<SymbolRefAttr> loadSymbols;
98 llvm::DenseSet<SymbolRefAttr> storeSymbols;
100 for (
size_t i = 0; i < work.size(); ++i) {
104 auto it = callableMap.find(work[i]);
105 assert(it != callableMap.end() &&
"Expected callable in callableMap");
106 it->second->walk([&](CallOpInterface call) {
107 auto symbol = cast<SymbolRefAttr>(call.getCallableForCallee());
108 if (visited.insert(symbol).second)
109 work.push_back(symbol);
112 loadSymbols.insert_range(opLoadSymbols[work[i]]);
114 storeSymbols.insert_range(opStoreSymbols[work[i]]);
117 loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
118 storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
126void MLProgramPipelineGlobals::processBlock(
127 Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
128 llvm::DenseSet<SymbolRefAttr> &symbolStore) {
130 llvm::DenseMap<SymbolRefAttr, Value> previousLoads;
131 llvm::DenseMap<SymbolRefAttr, Operation *> previousStores;
132 llvm::SmallVector<Operation *> toDelete;
133 for (
auto &op : block) {
137 if (
auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
138 auto ref =
load.getGlobal();
139 symbolLoad.insert(ref);
140 if (previousLoads.contains(ref)) {
141 toDelete.push_back(&op);
142 load.getResult().replaceAllUsesWith(previousLoads[ref]);
144 previousLoads[ref] =
load.getResult();
151 if (
auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
152 auto ref = store.getGlobal();
153 symbolStore.insert(ref);
154 auto it = previousStores.find(ref);
155 if (it != previousStores.end()) {
156 toDelete.push_back(it->getSecond());
159 previousLoads[ref] = store.getValue();
160 previousStores[ref] = &op;
166 if (
auto call = mlir::dyn_cast<CallOpInterface>(op)) {
168 loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
170 storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
172 for (
auto sym : loadSymbols) {
173 previousStores.erase(sym);
176 for (
auto sym : storeSymbols) {
177 previousLoads.erase(sym);
178 previousStores.erase(sym);
185 llvm::DenseSet<SymbolRefAttr> opSymbolLoad;
186 llvm::DenseSet<SymbolRefAttr> opSymbolStore;
187 for (
auto &region : op.getRegions()) {
188 for (
auto &block : region) {
189 processBlock(block, opSymbolLoad, opSymbolStore);
194 for (
auto change : opSymbolLoad) {
195 symbolLoad.insert(change);
196 previousStores.erase(change);
199 for (
auto change : opSymbolStore) {
200 symbolStore.insert(change);
201 previousLoads.erase(change);
202 previousStores.erase(change);
206 for (
auto *op : toDelete) {
211void MLProgramPipelineGlobals::runOnOperation() {
212 auto targetOp = getOperation();
213 if (
failed(buildGlobalMap(targetOp))) {
217 for (
auto &funcOp : *targetOp.getBody()) {
218 for (
auto &region : funcOp.getRegions()) {
219 for (
auto &block : region.getBlocks()) {
220 llvm::DenseSet<SymbolRefAttr> symbolsLoaded;
221 llvm::DenseSet<SymbolRefAttr> symbolsStored;
222 processBlock(block, symbolsLoaded, symbolsStored);
Block represents an ordered list of Operations.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait.
static WalkResult advance()
static WalkResult interrupt()
Include the generated interface declarations.