MLIR 22.0.0git
ModuleCombiner.cpp
Go to the documentation of this file.
1//===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the SPIR-V module combiner library.
10//
11//===----------------------------------------------------------------------===//
12
14
16#include "mlir/IR/Attributes.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/SymbolTable.h"
19#include "llvm/ADT/Hashing.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/ADT/StringMap.h"
22
23using namespace mlir;
24
/// Upper limit (2^20) on the numeric suffix that may be appended to a symbol
/// name when it has to be renamed to resolve a conflict.
static constexpr unsigned maxFreeID = 1U << 20U;
26
27/// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
28/// suffix in `lastUsedID`.
29static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
30 spirv::ModuleOp module) {
31 SmallString<64> newSymName(oldSymName);
32 newSymName.push_back('_');
33
34 MLIRContext *ctx = module->getContext();
35
36 while (lastUsedID < maxFreeID) {
37 auto possible = StringAttr::get(ctx, newSymName + Twine(++lastUsedID));
38 if (!SymbolTable::lookupSymbolIn(module, possible))
39 return possible;
40 }
41
42 return StringAttr::get(ctx, newSymName);
43}
44
45/// Checks if a symbol with the same name as `op` already exists in `source`.
46/// If so, renames `op` and updates all its references in `target`.
47static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
48 spirv::ModuleOp target,
49 spirv::ModuleOp source,
50 unsigned &lastUsedID) {
51 if (!SymbolTable::lookupSymbolIn(source, op.getName()))
52 return success();
53
54 StringRef oldSymName = op.getName();
55 StringAttr newSymName = renameSymbol(oldSymName, lastUsedID, target);
56
57 if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
58 return op.emitError("unable to update all symbol uses for ")
59 << oldSymName << " to " << newSymName;
60
61 SymbolTable::setSymbolName(op, newSymName);
62 return success();
63}
64
65/// Computes a hash code to represent `symbolOp` based on all its attributes
66/// except for the symbol name.
67///
68/// Note: We use the operation's name (not the symbol name) as part of the hash
69/// computation. This prevents, for example, mistakenly considering a global
70/// variable and a spec constant as duplicates because their descriptor set +
71/// binding and spec_id, respectively, happen to hash to the same value.
72static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
73 auto range =
74 llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) {
75 return attr.getName() != SymbolTable::getSymbolAttrName();
76 });
77
78 return llvm::hash_combine(symbolOp->getName(),
79 llvm::hash_combine_range(range));
80}
81
82namespace mlir {
83namespace spirv {
84
/// Combines `inputModules` into one newly created spirv.module, built with
/// `combinedModuleBuilder`, and returns it.
///
/// - Returns null if `inputModules` is empty.
/// - Emits an error and returns null if the inputs disagree on addressing
///   model, memory model, or VCE triple.
/// - Returns null if any symbol rename or use update fails.
/// - Each input module is cloned before processing, so renames never touch
///   the caller's modules.
/// - Conflicting symbols are renamed by appending a numeric suffix.
/// - If `symRenameListener` is set, it is called for every rename with the
///   module the symbol originally came from, plus the old and new names.
/// - Finally, non-function symbol ops without operands/results whose
///   attributes (other than the symbol name) hash identically are merged.
///   Their uses are redirected to the first such op and the duplicates are
///   erased. These redirections are not reported to `symRenameListener`.
///
/// Side effect: moves `combinedModuleBuilder`'s insertion point to the start
/// of the new module's body.
86 OpBuilder &combinedModuleBuilder,
87 SymbolRenameListener symRenameListener) {
88 if (inputModules.empty())
89 return nullptr;
90
91 spirv::ModuleOp firstModule = inputModules.front();
92 auto addressingModel = firstModule.getAddressingModel();
93 auto memoryModel = firstModule.getMemoryModel();
94 auto vceTriple = firstModule.getVceTriple();
95
96 // First check whether the input modules disagree on addressing model, memory
97 // model, or VCE triple. Return early if so.
98 for (auto module : inputModules) {
99 if (module.getAddressingModel() != addressingModel ||
100 module.getMemoryModel() != memoryModel ||
101 module.getVceTriple() != vceTriple) {
102 module.emitError("input modules differ in addressing model, memory "
103 "model, and/or VCE triple");
104 return nullptr;
105 }
106 }
107
  // NOTE(review): the error paths below `return nullptr` without erasing
  // `combinedModule`. If `combinedModuleBuilder` had no insertion point, the
  // partially built module is presumably leaked. If it had one, the partial
  // module stays in the caller's IR. Consider taking ownership (e.g. via
  // OwningOpRef) right after creation. Confirm against callers.
108 auto combinedModule =
109 spirv::ModuleOp::create(combinedModuleBuilder, firstModule.getLoc(),
110 addressingModel, memoryModel, vceTriple);
  // This repositions the caller-provided builder; all ops cloned below are
  // inserted through it into the combined module.
111 combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
112
113 // In some cases, a symbol in the (current state of the) combined module is
114 // renamed in order to enable the conflicting symbol in the input module
115 // being merged. For example, if the conflict is between a global variable in
116 // the current combined module and a function in the input module, the global
117 // variable is renamed. In order to notify listeners of the symbol updates in
118 // such cases, we need to keep track of the module from which the renamed
119 // symbol in the combined module originated. This map keeps such information.
120 llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
121
  // Suffix counter shared by every rename in this call. It is only ever
  // incremented, so generated suffixes are not reused across modules.
122 unsigned lastUsedID = 0;
123
124 for (auto inputModule : inputModules) {
    // Work on a clone so that renames never mutate the caller's input module.
125 OwningOpRef<spirv::ModuleOp> moduleClone = inputModule.clone();
126
127 // In the combined module, rename all symbols that conflict with symbols
128 // from the current input module. This renaming applies to all ops except
129 // for spirv.funcs. This way, if the conflicting op in the input module is
130 // non-spirv.func, we rename that symbol instead and keep the name of the
131 // spirv.func in the combined module unchanged.
132 for (auto &op : *combinedModule.getBody()) {
133 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
134 if (!symbolOp)
135 continue;
136
137 StringRef oldSymName = symbolOp.getName();
138
139 if (!isa<FuncOp>(op) &&
140 failed(updateSymbolAndAllUses(symbolOp, combinedModule, *moduleClone,
141 lastUsedID)))
142 return nullptr;
143
144 StringRef newSymName = symbolOp.getName();
145
146 if (symRenameListener && oldSymName != newSymName) {
      // The renamed symbol came from an earlier input module; report the rename
      // against that module, not the one currently being merged.
147 spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
148
149 if (!originalModule) {
150 inputModule.emitError(
151 "unable to find original spirv::ModuleOp for symbol ")
152 << oldSymName;
153 return nullptr;
154 }
155
156 symRenameListener(originalModule, oldSymName, newSymName);
157
158 // Since the symbol name is updated, there is no need to maintain the
159 // entry that associates the old symbol name with the original module.
160 symNameToModuleMap.erase(oldSymName);
161 // Instead, add a new entry to map the new symbol name to the original
162 // module in case it gets renamed again later.
163 symNameToModuleMap[newSymName] = originalModule;
164 }
165 }
166
167 // In the current input module, rename all symbols that conflict with
168 // symbols from the combined module. This includes renaming spirv.funcs.
169 for (auto &op : *moduleClone->getBody()) {
170 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
171 if (!symbolOp)
172 continue;
173
174 StringRef oldSymName = symbolOp.getName();
175
176 if (failed(updateSymbolAndAllUses(symbolOp, *moduleClone, combinedModule,
177 lastUsedID)))
178 return nullptr;
179
180 StringRef newSymName = symbolOp.getName();
181
182 if (symRenameListener) {
183 if (oldSymName != newSymName)
184 symRenameListener(inputModule, oldSymName, newSymName);
185
186 // Insert the module associated with the symbol name.
187 auto emplaceResult =
188 symNameToModuleMap.try_emplace(newSymName, inputModule);
189
190 // If an entry with the same symbol name is already present, this must
191 // be a problem with the implementation, especially clean-up of the map
192 // while iterating over the combined module above.
193 if (!emplaceResult.second) {
194 inputModule.emitError("did not expect to find an entry for symbol ")
195 << symbolOp.getName();
196 return nullptr;
197 }
198 }
199 }
200
201 // Clone all the module's ops to the combined module.
202 for (auto &op : *moduleClone->getBody())
203 combinedModuleBuilder.insert(op.clone());
204 }
205
206 // Deduplicate identical symbol ops such as global variables and spec
  // constants. Functions are skipped; deduplicating them is not supported yet.
209
210 for (auto &op : *combinedModule.getBody()) {
211 SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
212 if (!symbolOp)
213 continue;
214
215 // Skip ops with operands or results.
216 // Global variables, spec constants, and functions won't have
217 // operands/results; this check is just for safety.
218 if (op.getNumOperands() != 0 || op.getNumResults() != 0)
219 continue;
220
221 // Deduplicating functions is not supported yet.
222 if (isa<FuncOp>(op))
223 continue;
224
    // `hashToSymbolOp` keeps the first symbol op seen for each hash. Any later
    // op with the same hash is redirected to it and queued in `eraseList`.
    // NOTE(review): a hash match is treated as proof of equality. A genuine
    // hash collision between two different ops would merge them incorrectly.
    // Consider also comparing the op name and the non-sym_name attributes
    // before redirecting uses.
225 auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp);
226 if (result.second)
227 continue;
228
229 SymbolOpInterface replacementSymOp = result.first->second;
230
232 symbolOp, replacementSymOp.getNameAttr(), combinedModule))) {
233 symbolOp.emitError("unable to update all symbol uses for ")
234 << symbolOp.getName() << " to " << replacementSymOp.getName();
235 return nullptr;
236 }
237
238 eraseList.push_back(symbolOp);
239 }
240
  // Erase duplicates only after the walk, so the block being iterated above is
  // not modified mid-iteration.
241 for (auto symbolOp : eraseList)
242 symbolOp.erase();
243
244 return combinedModule;
245}
246
247} // namespace spirv
248} // namespace mlir
return success()
static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID, spirv::ModuleOp module)
Returns an unused symbol in module for oldSymName by appending increasing numeric suffixes tracked in lastUsedID.
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, spirv::ModuleOp target, spirv::ModuleOp source, unsigned &lastUsedID)
Checks if a symbol with the same name as op already exists in source.
static constexpr unsigned maxFreeID
static llvm::hash_code computeHash(SymbolOpInterface symbolOp)
Computes a hash code to represent symbolOp based on all its attributes except for the symbol name.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
This class helps build Operations.
Definition Builders.h:207
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition Builders.cpp:421
This class acts as an owning reference to an op, and will automatically destroy the held op on destruction if the held op is valid.
Definition OwningOpRef.h:29
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that are nested within the given operation 'from'.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
function_ref< void( spirv::ModuleOp originalModule, StringRef oldSymbol, StringRef newSymbol)> SymbolRenameListener
The listener function to receive symbol renaming events.
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126