MLIR  22.0.0git
ModuleCombiner.cpp
Go to the documentation of this file.
1 //===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SPIR-V module combiner library.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/SymbolTable.h"
19 #include "llvm/ADT/Hashing.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/StringMap.h"
22 
23 using namespace mlir;
24 
/// Largest numeric suffix that `renameSymbol` will try when searching for an
/// unused symbol name; once the suffix counter reaches this value the search
/// stops.
static constexpr unsigned maxFreeID = 1u << 20;
26 
27 /// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
28 /// suffix in `lastUsedID`.
29 static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
30  spirv::ModuleOp module) {
31  SmallString<64> newSymName(oldSymName);
32  newSymName.push_back('_');
33 
34  MLIRContext *ctx = module->getContext();
35 
36  while (lastUsedID < maxFreeID) {
37  auto possible = StringAttr::get(ctx, newSymName + Twine(++lastUsedID));
38  if (!SymbolTable::lookupSymbolIn(module, possible))
39  return possible;
40  }
41 
42  return StringAttr::get(ctx, newSymName);
43 }
44 
45 /// Checks if a symbol with the same name as `op` already exists in `source`.
46 /// If so, renames `op` and updates all its references in `target`.
47 static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
48  spirv::ModuleOp target,
49  spirv::ModuleOp source,
50  unsigned &lastUsedID) {
51  if (!SymbolTable::lookupSymbolIn(source, op.getName()))
52  return success();
53 
54  StringRef oldSymName = op.getName();
55  StringAttr newSymName = renameSymbol(oldSymName, lastUsedID, target);
56 
57  if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
58  return op.emitError("unable to update all symbol uses for ")
59  << oldSymName << " to " << newSymName;
60 
61  SymbolTable::setSymbolName(op, newSymName);
62  return success();
63 }
64 
65 /// Computes a hash code to represent `symbolOp` based on all its attributes
66 /// except for the symbol name.
67 ///
68 /// Note: We use the operation's name (not the symbol name) as part of the hash
69 /// computation. This prevents, for example, mistakenly considering a global
70 /// variable and a spec constant as duplicates because their descriptor set +
71 /// binding and spec_id, respectively, happen to hash to the same value.
72 static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
73  auto range =
74  llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) {
75  return attr.getName() != SymbolTable::getSymbolAttrName();
76  });
77 
78  return llvm::hash_combine(symbolOp->getName(),
79  llvm::hash_combine_range(range));
80 }
81 
82 namespace mlir {
83 namespace spirv {
84 
86  OpBuilder &combinedModuleBuilder,
87  SymbolRenameListener symRenameListener) {
88  if (inputModules.empty())
89  return nullptr;
90 
91  spirv::ModuleOp firstModule = inputModules.front();
92  auto addressingModel = firstModule.getAddressingModel();
93  auto memoryModel = firstModule.getMemoryModel();
94  auto vceTriple = firstModule.getVceTriple();
95 
96  // First check whether there are conflicts between addressing/memory model.
97  // Return early if so.
98  for (auto module : inputModules) {
99  if (module.getAddressingModel() != addressingModel ||
100  module.getMemoryModel() != memoryModel ||
101  module.getVceTriple() != vceTriple) {
102  module.emitError("input modules differ in addressing model, memory "
103  "model, and/or VCE triple");
104  return nullptr;
105  }
106  }
107 
108  auto combinedModule =
109  spirv::ModuleOp::create(combinedModuleBuilder, firstModule.getLoc(),
110  addressingModel, memoryModel, vceTriple);
111  combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
112 
113  // In some cases, a symbol in the (current state of the) combined module is
114  // renamed in order to enable the conflicting symbol in the input module
115  // being merged. For example, if the conflict is between a global variable in
116  // the current combined module and a function in the input module, the global
117  // variable is renamed. In order to notify listeners of the symbol updates in
118  // such cases, we need to keep track of the module from which the renamed
119  // symbol in the combined module originated. This map keeps such information.
120  llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
121 
122  unsigned lastUsedID = 0;
123 
124  for (auto inputModule : inputModules) {
125  OwningOpRef<spirv::ModuleOp> moduleClone = inputModule.clone();
126 
127  // In the combined module, rename all symbols that conflict with symbols
128  // from the current input module. This renaming applies to all ops except
129  // for spirv.funcs. This way, if the conflicting op in the input module is
130  // non-spirv.func, we rename that symbol instead and maintain the spirv.func
131  // in the combined module name as it is.
132  for (auto &op : *combinedModule.getBody()) {
133  auto symbolOp = dyn_cast<SymbolOpInterface>(op);
134  if (!symbolOp)
135  continue;
136 
137  StringRef oldSymName = symbolOp.getName();
138 
139  if (!isa<FuncOp>(op) &&
140  failed(updateSymbolAndAllUses(symbolOp, combinedModule, *moduleClone,
141  lastUsedID)))
142  return nullptr;
143 
144  StringRef newSymName = symbolOp.getName();
145 
146  if (symRenameListener && oldSymName != newSymName) {
147  spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
148 
149  if (!originalModule) {
150  inputModule.emitError(
151  "unable to find original spirv::ModuleOp for symbol ")
152  << oldSymName;
153  return nullptr;
154  }
155 
156  symRenameListener(originalModule, oldSymName, newSymName);
157 
158  // Since the symbol name is updated, there is no need to maintain the
159  // entry that associates the old symbol name with the original module.
160  symNameToModuleMap.erase(oldSymName);
161  // Instead, add a new entry to map the new symbol name to the original
162  // module in case it gets renamed again later.
163  symNameToModuleMap[newSymName] = originalModule;
164  }
165  }
166 
167  // In the current input module, rename all symbols that conflict with
168  // symbols from the combined module. This includes renaming spirv.funcs.
169  for (auto &op : *moduleClone->getBody()) {
170  auto symbolOp = dyn_cast<SymbolOpInterface>(op);
171  if (!symbolOp)
172  continue;
173 
174  StringRef oldSymName = symbolOp.getName();
175 
176  if (failed(updateSymbolAndAllUses(symbolOp, *moduleClone, combinedModule,
177  lastUsedID)))
178  return nullptr;
179 
180  StringRef newSymName = symbolOp.getName();
181 
182  if (symRenameListener) {
183  if (oldSymName != newSymName)
184  symRenameListener(inputModule, oldSymName, newSymName);
185 
186  // Insert the module associated with the symbol name.
187  auto emplaceResult =
188  symNameToModuleMap.try_emplace(newSymName, inputModule);
189 
190  // If an entry with the same symbol name is already present, this must
191  // be a problem with the implementation, specially clean-up of the map
192  // while iterating over the combined module above.
193  if (!emplaceResult.second) {
194  inputModule.emitError("did not expect to find an entry for symbol ")
195  << symbolOp.getName();
196  return nullptr;
197  }
198  }
199  }
200 
201  // Clone all the module's ops to the combined module.
202  for (auto &op : *moduleClone->getBody())
203  combinedModuleBuilder.insert(op.clone());
204  }
205 
206  // Deduplicate identical global variables, spec constants, and functions.
209 
210  for (auto &op : *combinedModule.getBody()) {
211  SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
212  if (!symbolOp)
213  continue;
214 
215  // Do not support ops with operands or results.
216  // Global variables, spec constants, and functions won't have
217  // operands/results, but just for safety here.
218  if (op.getNumOperands() != 0 || op.getNumResults() != 0)
219  continue;
220 
221  // Deduplicating functions are not supported yet.
222  if (isa<FuncOp>(op))
223  continue;
224 
225  auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp);
226  if (result.second)
227  continue;
228 
229  SymbolOpInterface replacementSymOp = result.first->second;
230 
232  symbolOp, replacementSymOp.getNameAttr(), combinedModule))) {
233  symbolOp.emitError("unable to update all symbol uses for ")
234  << symbolOp.getName() << " to " << replacementSymOp.getName();
235  return nullptr;
236  }
237 
238  eraseList.push_back(symbolOp);
239  }
240 
241  for (auto symbolOp : eraseList)
242  symbolOp.erase();
243 
244  return combinedModule;
245 }
246 
247 } // namespace spirv
248 } // namespace mlir
static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID, spirv::ModuleOp module)
Returns an unused symbol in module for oldSymName by trying numeric suffixes after lastUsedID.
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, spirv::ModuleOp target, spirv::ModuleOp source, unsigned &lastUsedID)
Checks if a symbol with the same name as op already exists in source.
static constexpr unsigned maxFreeID
static llvm::hash_code computeHash(SymbolOpInterface symbolOp)
Computes a hash code to represent symbolOp based on all its attributes except for the symbol name.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition: Builders.cpp:416
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...