21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/Hashing.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/StringMap.h"
33 static StringAttr
renameSymbol(StringRef oldSymName,
unsigned &lastUsedID,
34 spirv::ModuleOp module) {
36 newSymName.push_back(
'_');
52 spirv::ModuleOp target,
53 spirv::ModuleOp source,
54 unsigned &lastUsedID) {
58 StringRef oldSymName = op.
getName();
59 StringAttr newSymName =
renameSymbol(oldSymName, lastUsedID, target);
62 return op.
emitError(
"unable to update all symbol uses for ")
63 << oldSymName <<
" to " << newSymName;
76 static llvm::hash_code
computeHash(SymbolOpInterface symbolOp) {
78 llvm::make_filter_range(symbolOp->getAttrs(), [](
NamedAttribute attr) {
79 return attr.getName() != SymbolTable::getSymbolAttrName();
82 return llvm::hash_combine(
84 llvm::hash_combine_range(range.begin(), range.end()));
93 if (inputModules.empty())
96 spirv::ModuleOp firstModule = inputModules.front();
97 auto addressingModel = firstModule.getAddressingModel();
98 auto memoryModel = firstModule.getMemoryModel();
99 auto vceTriple = firstModule.getVceTriple();
103 for (
auto module : inputModules) {
104 if (module.getAddressingModel() != addressingModel ||
105 module.getMemoryModel() != memoryModel ||
106 module.getVceTriple() != vceTriple) {
107 module.emitError(
"input modules differ in addressing model, memory "
108 "model, and/or VCE triple");
113 auto combinedModule = combinedModuleBuilder.
create<spirv::ModuleOp>(
114 firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
124 llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
126 unsigned lastUsedID = 0;
128 for (
auto inputModule : inputModules) {
136 for (
auto &op : *combinedModule.getBody()) {
137 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
141 StringRef oldSymName = symbolOp.getName();
143 if (!isa<FuncOp>(op) &&
148 StringRef newSymName = symbolOp.getName();
150 if (symRenameListener && oldSymName != newSymName) {
151 spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
153 if (!originalModule) {
154 inputModule.emitError(
155 "unable to find original spirv::ModuleOp for symbol ")
160 symRenameListener(originalModule, oldSymName, newSymName);
164 symNameToModuleMap.erase(oldSymName);
167 symNameToModuleMap[newSymName] = originalModule;
173 for (
auto &op : *moduleClone->getBody()) {
174 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
178 StringRef oldSymName = symbolOp.getName();
184 StringRef newSymName = symbolOp.getName();
186 if (symRenameListener) {
187 if (oldSymName != newSymName)
188 symRenameListener(inputModule, oldSymName, newSymName);
192 symNameToModuleMap.try_emplace(newSymName, inputModule);
197 if (!emplaceResult.second) {
198 inputModule.emitError(
"did not expect to find an entry for symbol ")
199 << symbolOp.getName();
206 for (
auto &op : *moduleClone->getBody())
214 for (
auto &op : *combinedModule.getBody()) {
215 SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
229 auto result = hashToSymbolOp.try_emplace(
computeHash(symbolOp), symbolOp);
233 SymbolOpInterface replacementSymOp = result.first->second;
236 symbolOp, replacementSymOp.getNameAttr(), combinedModule))) {
237 symbolOp.emitError(
"unable to update all symbol uses for ")
238 << symbolOp.getName() <<
" to " << replacementSymOp.getName();
242 eraseList.push_back(symbolOp);
245 for (
auto symbolOp : eraseList)
248 return combinedModule;
static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID, spirv::ModuleOp module)
Returns an unused symbol in module for oldSymbolName by trying numeric suffix in lastUsedID.
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, spirv::ModuleOp target, spirv::ModuleOp source, unsigned &lastUsedID)
Checks if a symbol with the same name as op already exists in source.
static constexpr unsigned maxFreeID
static llvm::hash_code computeHash(SymbolOpInterface symbolOp)
Computes a hash code to represent symbolOp based on all its attributes except for the symbol name.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
unsigned getNumResults()
Return the number of results held by this operation.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.