MLIR  19.0.0git
ModuleCombiner.cpp
Go to the documentation of this file.
1 //===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SPIR-V module combiner library.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/SymbolTable.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/Hashing.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/StringMap.h"
26 
27 using namespace mlir;
28 
29 static constexpr unsigned maxFreeID = 1 << 20;
30 
31 /// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
32 /// suffix in `lastUsedID`.
33 static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
34  spirv::ModuleOp module) {
35  SmallString<64> newSymName(oldSymName);
36  newSymName.push_back('_');
37 
38  MLIRContext *ctx = module->getContext();
39 
40  while (lastUsedID < maxFreeID) {
41  auto possible = StringAttr::get(ctx, newSymName + Twine(++lastUsedID));
42  if (!SymbolTable::lookupSymbolIn(module, possible))
43  return possible;
44  }
45 
46  return StringAttr::get(ctx, newSymName);
47 }
48 
49 /// Checks if a symbol with the same name as `op` already exists in `source`.
50 /// If so, renames `op` and updates all its references in `target`.
51 static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
52  spirv::ModuleOp target,
53  spirv::ModuleOp source,
54  unsigned &lastUsedID) {
55  if (!SymbolTable::lookupSymbolIn(source, op.getName()))
56  return success();
57 
58  StringRef oldSymName = op.getName();
59  StringAttr newSymName = renameSymbol(oldSymName, lastUsedID, target);
60 
61  if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
62  return op.emitError("unable to update all symbol uses for ")
63  << oldSymName << " to " << newSymName;
64 
65  SymbolTable::setSymbolName(op, newSymName);
66  return success();
67 }
68 
69 /// Computes a hash code to represent `symbolOp` based on all its attributes
70 /// except for the symbol name.
71 ///
72 /// Note: We use the operation's name (not the symbol name) as part of the hash
73 /// computation. This prevents, for example, mistakenly considering a global
74 /// variable and a spec constant as duplicates because their descriptor set +
75 /// binding and spec_id, respectively, happen to hash to the same value.
76 static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
77  auto range =
78  llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) {
79  return attr.getName() != SymbolTable::getSymbolAttrName();
80  });
81 
82  return llvm::hash_combine(
83  symbolOp->getName(),
84  llvm::hash_combine_range(range.begin(), range.end()));
85 }
86 
87 namespace mlir {
88 namespace spirv {
89 
91  OpBuilder &combinedModuleBuilder,
92  SymbolRenameListener symRenameListener) {
93  if (inputModules.empty())
94  return nullptr;
95 
96  spirv::ModuleOp firstModule = inputModules.front();
97  auto addressingModel = firstModule.getAddressingModel();
98  auto memoryModel = firstModule.getMemoryModel();
99  auto vceTriple = firstModule.getVceTriple();
100 
101  // First check whether there are conflicts between addressing/memory model.
102  // Return early if so.
103  for (auto module : inputModules) {
104  if (module.getAddressingModel() != addressingModel ||
105  module.getMemoryModel() != memoryModel ||
106  module.getVceTriple() != vceTriple) {
107  module.emitError("input modules differ in addressing model, memory "
108  "model, and/or VCE triple");
109  return nullptr;
110  }
111  }
112 
113  auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
114  firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
115  combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
116 
117  // In some cases, a symbol in the (current state of the) combined module is
118  // renamed in order to enable the conflicting symbol in the input module
119  // being merged. For example, if the conflict is between a global variable in
120  // the current combined module and a function in the input module, the global
121  // variable is renamed. In order to notify listeners of the symbol updates in
122  // such cases, we need to keep track of the module from which the renamed
123  // symbol in the combined module originated. This map keeps such information.
124  llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
125 
126  unsigned lastUsedID = 0;
127 
128  for (auto inputModule : inputModules) {
129  OwningOpRef<spirv::ModuleOp> moduleClone = inputModule.clone();
130 
131  // In the combined module, rename all symbols that conflict with symbols
132  // from the current input module. This renaming applies to all ops except
133  // for spirv.funcs. This way, if the conflicting op in the input module is
134  // non-spirv.func, we rename that symbol instead and maintain the spirv.func
135  // in the combined module name as it is.
136  for (auto &op : *combinedModule.getBody()) {
137  auto symbolOp = dyn_cast<SymbolOpInterface>(op);
138  if (!symbolOp)
139  continue;
140 
141  StringRef oldSymName = symbolOp.getName();
142 
143  if (!isa<FuncOp>(op) &&
144  failed(updateSymbolAndAllUses(symbolOp, combinedModule, *moduleClone,
145  lastUsedID)))
146  return nullptr;
147 
148  StringRef newSymName = symbolOp.getName();
149 
150  if (symRenameListener && oldSymName != newSymName) {
151  spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
152 
153  if (!originalModule) {
154  inputModule.emitError(
155  "unable to find original spirv::ModuleOp for symbol ")
156  << oldSymName;
157  return nullptr;
158  }
159 
160  symRenameListener(originalModule, oldSymName, newSymName);
161 
162  // Since the symbol name is updated, there is no need to maintain the
163  // entry that associates the old symbol name with the original module.
164  symNameToModuleMap.erase(oldSymName);
165  // Instead, add a new entry to map the new symbol name to the original
166  // module in case it gets renamed again later.
167  symNameToModuleMap[newSymName] = originalModule;
168  }
169  }
170 
171  // In the current input module, rename all symbols that conflict with
172  // symbols from the combined module. This includes renaming spirv.funcs.
173  for (auto &op : *moduleClone->getBody()) {
174  auto symbolOp = dyn_cast<SymbolOpInterface>(op);
175  if (!symbolOp)
176  continue;
177 
178  StringRef oldSymName = symbolOp.getName();
179 
180  if (failed(updateSymbolAndAllUses(symbolOp, *moduleClone, combinedModule,
181  lastUsedID)))
182  return nullptr;
183 
184  StringRef newSymName = symbolOp.getName();
185 
186  if (symRenameListener) {
187  if (oldSymName != newSymName)
188  symRenameListener(inputModule, oldSymName, newSymName);
189 
190  // Insert the module associated with the symbol name.
191  auto emplaceResult =
192  symNameToModuleMap.try_emplace(newSymName, inputModule);
193 
194  // If an entry with the same symbol name is already present, this must
195  // be a problem with the implementation, specially clean-up of the map
196  // while iterating over the combined module above.
197  if (!emplaceResult.second) {
198  inputModule.emitError("did not expect to find an entry for symbol ")
199  << symbolOp.getName();
200  return nullptr;
201  }
202  }
203  }
204 
205  // Clone all the module's ops to the combined module.
206  for (auto &op : *moduleClone->getBody())
207  combinedModuleBuilder.insert(op.clone());
208  }
209 
210  // Deduplicate identical global variables, spec constants, and functions.
213 
214  for (auto &op : *combinedModule.getBody()) {
215  SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
216  if (!symbolOp)
217  continue;
218 
219  // Do not support ops with operands or results.
220  // Global variables, spec constants, and functions won't have
221  // operands/results, but just for safety here.
222  if (op.getNumOperands() != 0 || op.getNumResults() != 0)
223  continue;
224 
225 // Deduplicating functions is not supported yet.
226  if (isa<FuncOp>(op))
227  continue;
228 
229  auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp);
230  if (result.second)
231  continue;
232 
233  SymbolOpInterface replacementSymOp = result.first->second;
234 
236  symbolOp, replacementSymOp.getNameAttr(), combinedModule))) {
237  symbolOp.emitError("unable to update all symbol uses for ")
238  << symbolOp.getName() << " to " << replacementSymOp.getName();
239  return nullptr;
240  }
241 
242  eraseList.push_back(symbolOp);
243  }
244 
245  for (auto symbolOp : eraseList)
246  symbolOp.erase();
247 
248  return combinedModule;
249 }
250 
251 } // namespace spirv
252 } // namespace mlir
static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID, spirv::ModuleOp module)
Returns an unused symbol name in module for oldSymName by trying numeric suffixes after lastUsedID.
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, spirv::ModuleOp target, spirv::ModuleOp source, unsigned &lastUsedID)
Checks if a symbol with the same name as op already exists in source.
static constexpr unsigned maxFreeID
static llvm::hash_code computeHash(SymbolOpInterface symbolOp)
Computes a hash code to represent symbolOp based on all its attributes except for the symbol name.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:202
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition: Builders.cpp:428
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
unsigned getNumOperands()
Definition: Operation.h:341
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of the symbol table operation 'op'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26