21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/Hashing.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/StringMap.h"
33 static StringAttr
renameSymbol(StringRef oldSymName,
unsigned &lastUsedID,
34 spirv::ModuleOp module) {
36 newSymName.push_back(
'_');
52 spirv::ModuleOp target,
53 spirv::ModuleOp source,
54 unsigned &lastUsedID) {
58 StringRef oldSymName = op.getName();
59 StringAttr newSymName =
renameSymbol(oldSymName, lastUsedID, target);
62 return op.emitError(
"unable to update all symbol uses for ")
63 << oldSymName <<
" to " << newSymName;
76 static llvm::hash_code
computeHash(SymbolOpInterface symbolOp) {
78 llvm::make_filter_range(symbolOp->getAttrs(), [](
NamedAttribute attr) {
79 return attr.getName() != SymbolTable::getSymbolAttrName();
82 return llvm::hash_combine(symbolOp->getName(),
83 llvm::hash_combine_range(range));
92 if (inputModules.empty())
95 spirv::ModuleOp firstModule = inputModules.front();
96 auto addressingModel = firstModule.getAddressingModel();
97 auto memoryModel = firstModule.getMemoryModel();
98 auto vceTriple = firstModule.getVceTriple();
102 for (
auto module : inputModules) {
103 if (module.getAddressingModel() != addressingModel ||
104 module.getMemoryModel() != memoryModel ||
105 module.getVceTriple() != vceTriple) {
106 module.emitError(
"input modules differ in addressing model, memory "
107 "model, and/or VCE triple");
112 auto combinedModule = combinedModuleBuilder.
create<spirv::ModuleOp>(
113 firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
123 llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
125 unsigned lastUsedID = 0;
127 for (
auto inputModule : inputModules) {
135 for (
auto &op : *combinedModule.getBody()) {
136 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
140 StringRef oldSymName = symbolOp.getName();
142 if (!isa<FuncOp>(op) &&
147 StringRef newSymName = symbolOp.getName();
149 if (symRenameListener && oldSymName != newSymName) {
150 spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
152 if (!originalModule) {
153 inputModule.emitError(
154 "unable to find original spirv::ModuleOp for symbol ")
159 symRenameListener(originalModule, oldSymName, newSymName);
163 symNameToModuleMap.erase(oldSymName);
166 symNameToModuleMap[newSymName] = originalModule;
172 for (
auto &op : *moduleClone->getBody()) {
173 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
177 StringRef oldSymName = symbolOp.getName();
183 StringRef newSymName = symbolOp.getName();
185 if (symRenameListener) {
186 if (oldSymName != newSymName)
187 symRenameListener(inputModule, oldSymName, newSymName);
191 symNameToModuleMap.try_emplace(newSymName, inputModule);
196 if (!emplaceResult.second) {
197 inputModule.emitError(
"did not expect to find an entry for symbol ")
198 << symbolOp.getName();
205 for (
auto &op : *moduleClone->getBody())
206 combinedModuleBuilder.
insert(op.clone());
213 for (
auto &op : *combinedModule.getBody()) {
214 SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
221 if (op.getNumOperands() != 0 || op.getNumResults() != 0)
228 auto result = hashToSymbolOp.try_emplace(
computeHash(symbolOp), symbolOp);
232 SymbolOpInterface replacementSymOp = result.first->second;
235 symbolOp, replacementSymOp.getNameAttr(), combinedModule))) {
236 symbolOp.emitError(
"unable to update all symbol uses for ")
237 << symbolOp.getName() <<
" to " << replacementSymOp.getName();
241 eraseList.push_back(symbolOp);
244 for (
auto symbolOp : eraseList)
247 return combinedModule;
static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID, spirv::ModuleOp module)
Returns an unused symbol name in module for oldSymName, found by trying numeric suffixes starting from lastUsedID.
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, spirv::ModuleOp target, spirv::ModuleOp source, unsigned &lastUsedID)
Checks if a symbol with the same name as op already exists in source. If so, renames op to an unused name and updates all of its symbol uses.
static constexpr unsigned maxFreeID
static llvm::hash_code computeHash(SymbolOpInterface symbolOp)
Computes a hash code to represent symbolOp based on its operation name and all of its attributes except for the symbol name.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that are nested within the given operation 'from'.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'op'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.