MLIR  21.0.0git
ModuleCombiner.cpp
Go to the documentation of this file.
1 //===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SPIR-V module combiner library.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
26 
27 using namespace mlir;
28 
29 static constexpr unsigned maxFreeID = 1 << 20;
30 
31 /// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
32 /// suffix in `lastUsedID`.
33 static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
34  spirv::ModuleOp module) {
35  SmallString<64> newSymName(oldSymName);
36  newSymName.push_back('_');
37 
38  MLIRContext *ctx = module->getContext();
39 
40  while (lastUsedID < maxFreeID) {
41  auto possible = StringAttr::get(ctx, newSymName + Twine(++lastUsedID));
42  if (!SymbolTable::lookupSymbolIn(module, possible))
43  return possible;
44  }
45 
46  return StringAttr::get(ctx, newSymName);
47 }
48 
49 /// Checks if a symbol with the same name as `op` already exists in `source`.
50 /// If so, renames `op` and updates all its references in `target`.
51 static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
52  spirv::ModuleOp target,
53  spirv::ModuleOp source,
54  unsigned &lastUsedID) {
55  if (!SymbolTable::lookupSymbolIn(source, op.getName()))
56  return success();
57 
58  StringRef oldSymName = op.getName();
59  StringAttr newSymName = renameSymbol(oldSymName, lastUsedID, target);
60 
61  if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
62  return op.emitError("unable to update all symbol uses for ")
63  << oldSymName << " to " << newSymName;
64 
65  SymbolTable::setSymbolName(op, newSymName);
66  return success();
67 }
68 
69 /// Computes a hash code to represent `symbolOp` based on all its attributes
70 /// except for the symbol name.
71 ///
72 /// Note: We use the operation's name (not the symbol name) as part of the hash
73 /// computation. This prevents, for example, mistakenly considering a global
74 /// variable and a spec constant as duplicates because their descriptor set +
75 /// binding and spec_id, respectively, happen to hash to the same value.
76 static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
77  auto range =
78  llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) {
79  return attr.getName() != SymbolTable::getSymbolAttrName();
80  });
81 
82  return llvm::hash_combine(symbolOp->getName(),
83  llvm::hash_combine_range(range));
84 }
85 
86 namespace mlir {
87 namespace spirv {
88 
90  OpBuilder &combinedModuleBuilder,
91  SymbolRenameListener symRenameListener) {
92  if (inputModules.empty())
93  return nullptr;
94 
95  spirv::ModuleOp firstModule = inputModules.front();
96  auto addressingModel = firstModule.getAddressingModel();
97  auto memoryModel = firstModule.getMemoryModel();
98  auto vceTriple = firstModule.getVceTriple();
99 
100  // First check whether there are conflicts between addressing/memory model.
101  // Return early if so.
102  for (auto module : inputModules) {
103  if (module.getAddressingModel() != addressingModel ||
104  module.getMemoryModel() != memoryModel ||
105  module.getVceTriple() != vceTriple) {
106  module.emitError("input modules differ in addressing model, memory "
107  "model, and/or VCE triple");
108  return nullptr;
109  }
110  }
111 
112  auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
113  firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
114  combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
115 
116  // In some cases, a symbol in the (current state of the) combined module is
117  // renamed in order to enable the conflicting symbol in the input module
118  // being merged. For example, if the conflict is between a global variable in
119  // the current combined module and a function in the input module, the global
120  // variable is renamed. In order to notify listeners of the symbol updates in
121  // such cases, we need to keep track of the module from which the renamed
122  // symbol in the combined module originated. This map keeps such information.
123  llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
124 
125  unsigned lastUsedID = 0;
126 
127  for (auto inputModule : inputModules) {
128  OwningOpRef<spirv::ModuleOp> moduleClone = inputModule.clone();
129 
130  // In the combined module, rename all symbols that conflict with symbols
131  // from the current input module. This renaming applies to all ops except
132  // for spirv.funcs. This way, if the conflicting op in the input module is
133  // non-spirv.func, we rename that symbol instead and maintain the spirv.func
134  // in the combined module name as it is.
135  for (auto &op : *combinedModule.getBody()) {
136  auto symbolOp = dyn_cast<SymbolOpInterface>(op);
137  if (!symbolOp)
138  continue;
139 
140  StringRef oldSymName = symbolOp.getName();
141 
142  if (!isa<FuncOp>(op) &&
143  failed(updateSymbolAndAllUses(symbolOp, combinedModule, *moduleClone,
144  lastUsedID)))
145  return nullptr;
146 
147  StringRef newSymName = symbolOp.getName();
148 
149  if (symRenameListener && oldSymName != newSymName) {
150  spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
151 
152  if (!originalModule) {
153  inputModule.emitError(
154  "unable to find original spirv::ModuleOp for symbol ")
155  << oldSymName;
156  return nullptr;
157  }
158 
159  symRenameListener(originalModule, oldSymName, newSymName);
160 
161  // Since the symbol name is updated, there is no need to maintain the
162  // entry that associates the old symbol name with the original module.
163  symNameToModuleMap.erase(oldSymName);
164  // Instead, add a new entry to map the new symbol name to the original
165  // module in case it gets renamed again later.
166  symNameToModuleMap[newSymName] = originalModule;
167  }
168  }
169 
170  // In the current input module, rename all symbols that conflict with
171  // symbols from the combined module. This includes renaming spirv.funcs.
172  for (auto &op : *moduleClone->getBody()) {
173  auto symbolOp = dyn_cast<SymbolOpInterface>(op);
174  if (!symbolOp)
175  continue;
176 
177  StringRef oldSymName = symbolOp.getName();
178 
179  if (failed(updateSymbolAndAllUses(symbolOp, *moduleClone, combinedModule,
180  lastUsedID)))
181  return nullptr;
182 
183  StringRef newSymName = symbolOp.getName();
184 
185  if (symRenameListener) {
186  if (oldSymName != newSymName)
187  symRenameListener(inputModule, oldSymName, newSymName);
188 
189  // Insert the module associated with the symbol name.
190  auto emplaceResult =
191  symNameToModuleMap.try_emplace(newSymName, inputModule);
192 
193  // If an entry with the same symbol name is already present, this must
194  // be a problem with the implementation, specially clean-up of the map
195  // while iterating over the combined module above.
196  if (!emplaceResult.second) {
197  inputModule.emitError("did not expect to find an entry for symbol ")
198  << symbolOp.getName();
199  return nullptr;
200  }
201  }
202  }
203 
204  // Clone all the module's ops to the combined module.
205  for (auto &op : *moduleClone->getBody())
206  combinedModuleBuilder.insert(op.clone());
207  }
208 
209  // Deduplicate identical global variables, spec constants, and functions.
212 
213  for (auto &op : *combinedModule.getBody()) {
214  SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
215  if (!symbolOp)
216  continue;
217 
218  // Do not support ops with operands or results.
219  // Global variables, spec constants, and functions won't have
220  // operands/results, but just for safety here.
221  if (op.getNumOperands() != 0 || op.getNumResults() != 0)
222  continue;
223 
224  // Deduplicating functions are not supported yet.
225  if (isa<FuncOp>(op))
226  continue;
227 
228  auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp);
229  if (result.second)
230  continue;
231 
232  SymbolOpInterface replacementSymOp = result.first->second;
233 
235  symbolOp, replacementSymOp.getNameAttr(), combinedModule))) {
236  symbolOp.emitError("unable to update all symbol uses for ")
237  << symbolOp.getName() << " to " << replacementSymOp.getName();
238  return nullptr;
239  }
240 
241  eraseList.push_back(symbolOp);
242  }
243 
244  for (auto symbolOp : eraseList)
245  symbolOp.erase();
246 
247  return combinedModule;
248 }
249 
250 } // namespace spirv
251 } // namespace mlir
static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID, spirv::ModuleOp module)
Returns a symbol name derived from oldSymName that is unused in module, trying numeric suffixes taken from lastUsedID.
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, spirv::ModuleOp target, spirv::ModuleOp source, unsigned &lastUsedID)
Checks if a symbol with the same name as op already exists in source.
static constexpr unsigned maxFreeID
static llvm::hash_code computeHash(SymbolOpInterface symbolOp)
Computes a hash code to represent symbolOp based on all its attributes except for the symbol name.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition: Builders.cpp:417
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...