36static LogicalResult
mergeInto(FunctionOpInterface func1,
37 FunctionOpInterface func2) {
39 assert(func1->getParentOp() == func2->getParentOp() &&
40 "expected func1 and func2 to be in the same parent op");
43 if (func1.getFunctionType() != func2.getFunctionType()) {
44 return func1.emitError()
45 <<
"external definition has a mismatching signature ("
46 << func2.getFunctionType() <<
")";
52 StringAttr consumedName = td->getConsumedAttrName();
53 StringAttr readOnlyName = td->getReadOnlyAttrName();
54 for (
unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
55 bool isExternalConsumed = func2.getArgAttr(i, consumedName) !=
nullptr;
56 bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) !=
nullptr;
57 bool isConsumed = func1.getArgAttr(i, consumedName) !=
nullptr;
58 bool isReadonly = func1.getArgAttr(i, readOnlyName) !=
nullptr;
59 if (!isExternalConsumed && !isExternalReadonly) {
61 func2.setArgAttr(i, consumedName, UnitAttr::get(context));
63 func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
67 if ((isExternalConsumed && !isConsumed) ||
68 (isExternalReadonly && !isReadonly)) {
69 return func1.emitError()
70 <<
"external definition has mismatching consumption "
71 "annotations for argument #"
77 assert(func1.isExternal());
118 "requires target to implement the 'SymbolTable' trait");
120 "requires target to implement the 'SymbolTable' trait");
128 target->walk([&](CallOpInterface call) {
131 if (
auto symRef = dyn_cast<SymbolRefAttr>(callee)) {
134 if (isa<FlatSymbolRefAttr>(symRef))
135 callable = targetSymbolTable.
lookup(symRef.getLeafReference());
138 }
else if (
auto value = dyn_cast<Value>(callee)) {
139 callable = value.getDefiningOp();
146 noInlineCalls.insert(call.getOperation());
157 LDBG() <<
"renaming private symbols to resolve conflicts:";
159 for (
auto &&[symbolTable, otherSymbolTable] : llvm::zip(
162 &targetSymbolTable})) {
163 Operation *symbolTableOp = symbolTable->getOp();
165 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
168 StringAttr name = symbolOp.getNameAttr();
169 LDBG() <<
" found @" << name.getValue();
173 cast_or_null<SymbolOpInterface>(otherSymbolTable->
lookup(name));
177 LDBG() <<
" collision found for @" << name.getValue();
180 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op),
182 dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
183 funcOp && collidingFuncOp) {
186 LDBG() <<
" but both ops are functions and will be merged";
191 LDBG() <<
" and both ops are function definitions";
195 auto renameToUnique =
196 [&](SymbolOpInterface op, SymbolOpInterface otherOp,
199 LDBG() <<
", renaming";
200 FailureOr<StringAttr> maybeNewName =
201 symbolTable.renameToUnique(op, {&otherSymbolTable});
202 if (failed(maybeNewName)) {
204 diag.attachNote(otherOp->getLoc())
205 <<
"attempted renaming due to collision with this op";
208 LDBG() <<
" renamed to @" << maybeNewName->getValue();
212 if (symbolOp.isPrivate()) {
213 if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
218 if (collidingOp.isPrivate()) {
219 if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
224 LDBG() <<
", emitting error";
226 <<
"doubly defined symbol @" << name.getValue();
227 diag.attachNote(collidingOp->getLoc()) <<
"previously defined here";
236 return op->emitError()
237 <<
"failed to verify symbol table after symbol renaming";
243 LDBG() <<
"moving all symbols into target";
246 for (
Operation &op : other->getRegion(0).front()) {
247 if (
auto symbol = dyn_cast<SymbolOpInterface>(op))
248 opsToMove.push_back(symbol);
251 for (SymbolOpInterface op : opsToMove) {
253 auto collidingOp = cast_or_null<SymbolOpInterface>(
254 targetSymbolTable.
lookup(op.getNameAttr()));
257 LDBG() <<
" moving @" << op.getName();
258 op->moveBefore(&
target->getRegion(0).front(),
259 target->getRegion(0).front().end());
265 LDBG() <<
" without collision";
266 targetSymbolTable.
insert(op);
272 auto funcOp = cast<FunctionOpInterface>(op.getOperation());
273 auto collidingFuncOp =
274 cast<FunctionOpInterface>(collidingOp.getOperation());
280 std::swap(funcOp, collidingFuncOp);
284 LDBG() <<
" with collision, trying to keep op at "
285 << collidingFuncOp.getLoc() <<
":\n"
289 targetSymbolTable.
remove(funcOp);
290 targetSymbolTable.
insert(collidingFuncOp);
291 assert(targetSymbolTable.
lookup(funcOp.getName()) == collidingFuncOp);
294 if (failed(
mergeInto(funcOp, collidingFuncOp)))
313 if (!isa<CallableOpInterface, CallOpInterface>(nested))
325 if (
auto symRef = dyn_cast<SymbolRefAttr>(callee)) {
328 if (isa<FlatSymbolRefAttr>(symRef))
329 callable = targetSymbolTable.
lookup(symRef.getLeafReference());
332 }
else if (
auto value = dyn_cast<Value>(callee)) {
333 callable = value.getDefiningOp();
338 if (!noInlineCalls.contains(call.getOperation()) &&
342 <<
"merged call is not legal to inline into its caller";
343 diag.attachNote(callable->
getLoc()) <<
"callee defined here";
351 LDBG() <<
"done merging ops";
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.