19 #include "llvm/ADT/Hashing.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/StringMap.h"
29 static StringAttr
renameSymbol(StringRef oldSymName,
unsigned &lastUsedID,
30 spirv::ModuleOp module) {
32 newSymName.push_back(
'_');
48 spirv::ModuleOp target,
49 spirv::ModuleOp source,
50 unsigned &lastUsedID) {
54 StringRef oldSymName = op.getName();
55 StringAttr newSymName =
renameSymbol(oldSymName, lastUsedID, target);
58 return op.emitError(
"unable to update all symbol uses for ")
59 << oldSymName <<
" to " << newSymName;
72 static llvm::hash_code
computeHash(SymbolOpInterface symbolOp) {
74 llvm::make_filter_range(symbolOp->getAttrs(), [](
NamedAttribute attr) {
75 return attr.getName() != SymbolTable::getSymbolAttrName();
78 return llvm::hash_combine(symbolOp->getName(),
79 llvm::hash_combine_range(range));
88 if (inputModules.empty())
91 spirv::ModuleOp firstModule = inputModules.front();
92 auto addressingModel = firstModule.getAddressingModel();
93 auto memoryModel = firstModule.getMemoryModel();
94 auto vceTriple = firstModule.getVceTriple();
98 for (
auto module : inputModules) {
99 if (module.getAddressingModel() != addressingModel ||
100 module.getMemoryModel() != memoryModel ||
101 module.getVceTriple() != vceTriple) {
102 module.emitError(
"input modules differ in addressing model, memory "
103 "model, and/or VCE triple");
108 auto combinedModule =
109 spirv::ModuleOp::create(combinedModuleBuilder, firstModule.getLoc(),
110 addressingModel, memoryModel, vceTriple);
120 llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
122 unsigned lastUsedID = 0;
124 for (
auto inputModule : inputModules) {
132 for (
auto &op : *combinedModule.getBody()) {
133 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
137 StringRef oldSymName = symbolOp.getName();
139 if (!isa<FuncOp>(op) &&
144 StringRef newSymName = symbolOp.getName();
146 if (symRenameListener && oldSymName != newSymName) {
147 spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
149 if (!originalModule) {
150 inputModule.emitError(
151 "unable to find original spirv::ModuleOp for symbol ")
156 symRenameListener(originalModule, oldSymName, newSymName);
160 symNameToModuleMap.erase(oldSymName);
163 symNameToModuleMap[newSymName] = originalModule;
169 for (
auto &op : *moduleClone->getBody()) {
170 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
174 StringRef oldSymName = symbolOp.getName();
180 StringRef newSymName = symbolOp.getName();
182 if (symRenameListener) {
183 if (oldSymName != newSymName)
184 symRenameListener(inputModule, oldSymName, newSymName);
188 symNameToModuleMap.try_emplace(newSymName, inputModule);
193 if (!emplaceResult.second) {
194 inputModule.emitError(
"did not expect to find an entry for symbol ")
195 << symbolOp.getName();
202 for (
auto &op : *moduleClone->getBody())
203 combinedModuleBuilder.
insert(op.clone());
210 for (
auto &op : *combinedModule.getBody()) {
211 SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
218 if (op.getNumOperands() != 0 || op.getNumResults() != 0)
225 auto result = hashToSymbolOp.try_emplace(
computeHash(symbolOp), symbolOp);
229 SymbolOpInterface replacementSymOp = result.first->second;
232 symbolOp, replacementSymOp.getNameAttr(), combinedModule))) {
233 symbolOp.emitError(
"unable to update all symbol uses for ")
234 << symbolOp.getName() <<
" to " << replacementSymOp.getName();
238 eraseList.push_back(symbolOp);
241 for (
auto symbolOp : eraseList)
244 return combinedModule;
static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID, spirv::ModuleOp module)
Returns an unused symbol in module for oldSymbolName by trying numeric suffix in lastUsedID.
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, spirv::ModuleOp target, spirv::ModuleOp source, unsigned &lastUsedID)
Checks if a symbol with the same name as op already exists in source.
static constexpr unsigned maxFreeID
static llvm::hash_code computeHash(SymbolOpInterface symbolOp)
Computes a hash code to represent symbolOp based on all its attributes except for the symbol name.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...