21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/Hashing.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/StringMap.h"
33 static StringAttr
renameSymbol(StringRef oldSymName,
unsigned &lastUsedID,
34 spirv::ModuleOp module) {
36 newSymName.push_back(
'_');
52 spirv::ModuleOp target,
53 spirv::ModuleOp source,
54 unsigned &lastUsedID) {
58 StringRef oldSymName = op.getName();
59 StringAttr newSymName =
renameSymbol(oldSymName, lastUsedID, target);
62 return op.emitError(
"unable to update all symbol uses for ")
63 << oldSymName <<
" to " << newSymName;
76 static llvm::hash_code
computeHash(SymbolOpInterface symbolOp) {
78 llvm::make_filter_range(symbolOp->getAttrs(), [](
NamedAttribute attr) {
79 return attr.getName() != SymbolTable::getSymbolAttrName();
82 return llvm::hash_combine(
84 llvm::hash_combine_range(range.begin(), range.end()));
93 if (inputModules.empty())
96 spirv::ModuleOp firstModule = inputModules.front();
97 auto addressingModel = firstModule.getAddressingModel();
98 auto memoryModel = firstModule.getMemoryModel();
99 auto vceTriple = firstModule.getVceTriple();
103 for (
auto module : inputModules) {
104 if (module.getAddressingModel() != addressingModel ||
105 module.getMemoryModel() != memoryModel ||
106 module.getVceTriple() != vceTriple) {
107 module.emitError(
"input modules differ in addressing model, memory "
108 "model, and/or VCE triple");
113 auto combinedModule = combinedModuleBuilder.
create<spirv::ModuleOp>(
114 firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
124 llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
126 unsigned lastUsedID = 0;
128 for (
auto inputModule : inputModules) {
136 for (
auto &op : *combinedModule.getBody()) {
137 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
141 StringRef oldSymName = symbolOp.getName();
143 if (!isa<FuncOp>(op) &&
148 StringRef newSymName = symbolOp.getName();
150 if (symRenameListener && oldSymName != newSymName) {
151 spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
153 if (!originalModule) {
154 inputModule.emitError(
155 "unable to find original spirv::ModuleOp for symbol ")
160 symRenameListener(originalModule, oldSymName, newSymName);
164 symNameToModuleMap.erase(oldSymName);
167 symNameToModuleMap[newSymName] = originalModule;
173 for (
auto &op : *moduleClone->getBody()) {
174 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
178 StringRef oldSymName = symbolOp.getName();
184 StringRef newSymName = symbolOp.getName();
186 if (symRenameListener) {
187 if (oldSymName != newSymName)
188 symRenameListener(inputModule, oldSymName, newSymName);
192 symNameToModuleMap.try_emplace(newSymName, inputModule);
197 if (!emplaceResult.second) {
198 inputModule.emitError(
"did not expect to find an entry for symbol ")
199 << symbolOp.getName();
206 for (
auto &op : *moduleClone->getBody())
207 combinedModuleBuilder.
insert(op.clone());
214 for (
auto &op : *combinedModule.getBody()) {
215 SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
222 if (op.getNumOperands() != 0 || op.getNumResults() != 0)
229 auto result = hashToSymbolOp.try_emplace(
computeHash(symbolOp), symbolOp);
233 SymbolOpInterface replacementSymOp = result.first->second;
236 symbolOp, replacementSymOp.getNameAttr(), combinedModule))) {
237 symbolOp.emitError(
"unable to update all symbol uses for ")
238 << symbolOp.getName() <<
" to " << replacementSymOp.getName();
242 eraseList.push_back(symbolOp);
245 for (
auto symbolOp : eraseList)
248 return combinedModule;
static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID, spirv::ModuleOp module)
Returns an unused symbol in module for oldSymName by trying numeric suffixes in lastUsedID.
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, spirv::ModuleOp target, spirv::ModuleOp source, unsigned &lastUsedID)
Checks if a symbol with the same name as op already exists in source.
static constexpr unsigned maxFreeID
static llvm::hash_code computeHash(SymbolOpInterface symbolOp)
Computes a hash code to represent symbolOp based on all its attributes except for the symbol name.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'op'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...