21 #include "llvm/ADT/ArrayRef.h" 22 #include "llvm/ADT/Hashing.h" 23 #include "llvm/ADT/STLExtras.h" 24 #include "llvm/ADT/StringExtras.h" 25 #include "llvm/ADT/StringMap.h" 33 static StringAttr
renameSymbol(StringRef oldSymName,
unsigned &lastUsedID,
34 spirv::ModuleOp module) {
36 newSymName.push_back(
'_');
41 auto possible = StringAttr::get(ctx, newSymName + Twine(++lastUsedID));
46 return StringAttr::get(ctx, newSymName);
52 spirv::ModuleOp target,
53 spirv::ModuleOp source,
54 unsigned &lastUsedID) {
58 StringRef oldSymName = op.getName();
59 StringAttr newSymName =
renameSymbol(oldSymName, lastUsedID, target);
62 return op.emitError(
"unable to update all symbol uses for ")
63 << oldSymName <<
" to " << newSymName;
76 static llvm::hash_code
computeHash(SymbolOpInterface symbolOp) {
78 llvm::make_filter_range(symbolOp->getAttrs(), [](
NamedAttribute attr) {
82 return llvm::hash_combine(
84 llvm::hash_combine_range(range.begin(), range.end()));
93 if (inputModules.empty())
96 spirv::ModuleOp firstModule = inputModules.front();
97 auto addressingModel = firstModule.addressing_model();
98 auto memoryModel = firstModule.memory_model();
99 auto vceTriple = firstModule.vce_triple();
103 for (
auto module : inputModules) {
104 if (module.addressing_model() != addressingModel ||
105 module.memory_model() != memoryModel ||
106 module.vce_triple() != vceTriple) {
107 module.emitError(
"input modules differ in addressing model, memory " 108 "model, and/or VCE triple");
113 auto combinedModule = combinedModuleBuilder.
create<spirv::ModuleOp>(
114 firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
124 llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
126 unsigned lastUsedID = 0;
128 for (
auto inputModule : inputModules) {
136 for (
auto &op : *combinedModule.getBody()) {
137 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
141 StringRef oldSymName = symbolOp.getName();
143 if (!isa<FuncOp>(op) &&
148 StringRef newSymName = symbolOp.getName();
150 if (symRenameListener && oldSymName != newSymName) {
151 spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
153 if (!originalModule) {
154 inputModule.emitError(
155 "unable to find original spirv::ModuleOp for symbol ")
160 symRenameListener(originalModule, oldSymName, newSymName);
164 symNameToModuleMap.erase(oldSymName);
167 symNameToModuleMap[newSymName] = originalModule;
173 for (
auto &op : *moduleClone->getBody()) {
174 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
178 StringRef oldSymName = symbolOp.getName();
184 StringRef newSymName = symbolOp.getName();
186 if (symRenameListener) {
187 if (oldSymName != newSymName)
188 symRenameListener(inputModule, oldSymName, newSymName);
192 symNameToModuleMap.try_emplace(newSymName, inputModule);
197 if (!emplaceResult.second) {
198 inputModule.emitError(
"did not expect to find an entry for symbol ")
199 << symbolOp.getName();
206 for (
auto &op : *moduleClone->getBody())
207 combinedModuleBuilder.
insert(op.clone());
214 for (
auto &op : *combinedModule.getBody()) {
215 SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
222 if (op.getNumOperands() != 0 || op.getNumResults() != 0)
229 auto result = hashToSymbolOp.try_emplace(
computeHash(symbolOp), symbolOp);
233 SymbolOpInterface replacementSymOp = result.first->second;
236 symbolOp, replacementSymOp.getNameAttr(), combinedModule))) {
237 symbolOp.emitError(
"unable to update all symbol uses for ")
238 << symbolOp.getName() <<
" to " << replacementSymOp.getName();
242 eraseList.push_back(symbolOp);
245 for (
auto symbolOp : eraseList)
248 return combinedModule;
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID, spirv::ModuleOp module)
Returns an unsed symbol in module for oldSymbolName by trying numeric suffix in lastUsedID.
NamedAttribute represents a combination of a name and an Attribute value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class represents an efficient way to signal success or failure.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
MLIRContext is the top-level object for a collection of MLIR operations.
static constexpr unsigned maxFreeID
static llvm::hash_code computeHash(SymbolOpInterface symbolOp)
Computes a hash code to represent symbolOp based on all its attributes except for the symbol name...
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, spirv::ModuleOp target, spirv::ModuleOp source, unsigned &lastUsedID)
Checks if a symbol with the same name as op already exists in source.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
This class helps build Operations.