36static LogicalResult
mergeInto(FunctionOpInterface func1,
37 FunctionOpInterface func2) {
39 assert(func1->getParentOp() == func2->getParentOp() &&
40 "expected func1 and func2 to be in the same parent op");
43 if (func1.getFunctionType() != func2.getFunctionType()) {
44 return func1.emitError()
45 <<
"external definition has a mismatching signature ("
46 << func2.getFunctionType() <<
")";
52 StringAttr consumedName = td->getConsumedAttrName();
53 StringAttr readOnlyName = td->getReadOnlyAttrName();
54 for (
unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
55 bool isExternalConsumed = func2.getArgAttr(i, consumedName) !=
nullptr;
56 bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) !=
nullptr;
57 bool isConsumed = func1.getArgAttr(i, consumedName) !=
nullptr;
58 bool isReadonly = func1.getArgAttr(i, readOnlyName) !=
nullptr;
59 if (!isExternalConsumed && !isExternalReadonly) {
61 func2.setArgAttr(i, consumedName, UnitAttr::get(context));
63 func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
67 if ((isExternalConsumed && !isConsumed) ||
68 (isExternalReadonly && !isReadonly)) {
69 return func1.emitError()
70 <<
"external definition has mismatching consumption "
71 "annotations for argument #"
77 assert(func1.isExternal());
118 "requires target to implement the 'SymbolTable' trait");
120 "requires target to implement the 'SymbolTable' trait");
129 LDBG() <<
"renaming private symbols to resolve conflicts:";
131 for (
auto &&[symbolTable, otherSymbolTable] : llvm::zip(
134 &targetSymbolTable})) {
135 Operation *symbolTableOp = symbolTable->getOp();
137 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
140 StringAttr name = symbolOp.getNameAttr();
141 LDBG() <<
" found @" << name.getValue();
145 cast_or_null<SymbolOpInterface>(otherSymbolTable->
lookup(name));
149 LDBG() <<
" collision found for @" << name.getValue();
152 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op),
154 dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
155 funcOp && collidingFuncOp) {
158 LDBG() <<
" but both ops are functions and will be merged";
163 LDBG() <<
" and both ops are function definitions";
167 auto renameToUnique =
168 [&](SymbolOpInterface op, SymbolOpInterface otherOp,
171 LDBG() <<
", renaming";
172 FailureOr<StringAttr> maybeNewName =
173 symbolTable.renameToUnique(op, {&otherSymbolTable});
174 if (failed(maybeNewName)) {
176 diag.attachNote(otherOp->getLoc())
177 <<
"attempted renaming due to collision with this op";
180 LDBG() <<
" renamed to @" << maybeNewName->getValue();
184 if (symbolOp.isPrivate()) {
185 if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
190 if (collidingOp.isPrivate()) {
191 if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
196 LDBG() <<
", emitting error";
198 <<
"doubly defined symbol @" << name.getValue();
199 diag.attachNote(collidingOp->getLoc()) <<
"previously defined here";
208 return op->emitError()
209 <<
"failed to verify symbol table after symbol renaming";
215 LDBG() <<
"moving all symbols into target";
217 for (
Operation &op : other->getRegion(0).front()) {
218 if (
auto symbol = dyn_cast<SymbolOpInterface>(op))
219 processedSymbols.push_back(symbol);
222 for (SymbolOpInterface &op : processedSymbols) {
224 auto collidingOp = cast_or_null<SymbolOpInterface>(
225 targetSymbolTable.
lookup(op.getNameAttr()));
228 LDBG() <<
" moving @" << op.getName();
229 op->moveBefore(&
target->getRegion(0).front(),
230 target->getRegion(0).front().end());
236 LDBG() <<
" without collision";
237 targetSymbolTable.
insert(op);
243 auto funcOp = cast<FunctionOpInterface>(op.getOperation());
244 auto collidingFuncOp =
245 cast<FunctionOpInterface>(collidingOp.getOperation());
251 std::swap(funcOp, collidingFuncOp);
255 LDBG() <<
" with collision, trying to keep op at "
256 << collidingFuncOp.getLoc() <<
":\n"
260 targetSymbolTable.
remove(funcOp);
261 targetSymbolTable.
insert(collidingFuncOp);
262 assert(targetSymbolTable.
lookup(funcOp.getName()) == collidingFuncOp);
265 if (failed(
mergeInto(funcOp, collidingFuncOp)))
270 op = collidingFuncOp;
287 if (!isa<CallableOpInterface, CallOpInterface>(nested))
300 if (
auto symRef = dyn_cast<SymbolRefAttr>(callee)) {
303 if (isa<FlatSymbolRefAttr>(symRef))
304 callable = targetSymbolTable.
lookup(symRef.getLeafReference());
307 }
else if (
auto value = dyn_cast<Value>(callee)) {
308 callable = value.getDefiningOp();
315 if (!llvm::is_contained(processedSymbols, callable))
320 <<
"merged call is not legal to inline into its caller";
321 diag.attachNote(callable->
getLoc()) <<
"callee defined here";
329 LDBG() <<
"done merging ops";
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.