13 #include "llvm/Support/Debug.h"
14 #include "llvm/Support/DebugLog.h"
18 #define DEBUG_TYPE "transform-dialect-utils"
19 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
/// Returns true when `func1` is external (per `isExternal()`) and `func2` is
/// either public or external.
/// NOTE(review): the name suggests this decides whether `func1` may be merged
/// into `func2`; the call sites are not visible in this excerpt -- confirm.
25 static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
26 return func1.isExternal() && (func2.isPublic() || func2.isExternal());
33 FunctionOpInterface func2) {
35 assert(func1->getParentOp() == func2->getParentOp() &&
36 "expected func1 and func2 to be in the same parent op");
39 if (func1.getFunctionType() != func2.getFunctionType()) {
40 return func1.emitError()
41 <<
"external definition has a mismatching signature ("
42 << func2.getFunctionType() <<
")";
48 StringAttr consumedName = td->getConsumedAttrName();
49 StringAttr readOnlyName = td->getReadOnlyAttrName();
50 for (
unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
51 bool isExternalConsumed = func2.getArgAttr(i, consumedName) !=
nullptr;
52 bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) !=
nullptr;
53 bool isConsumed = func1.getArgAttr(i, consumedName) !=
nullptr;
54 bool isReadonly = func1.getArgAttr(i, readOnlyName) !=
nullptr;
55 if (!isExternalConsumed && !isExternalReadonly) {
63 if ((isExternalConsumed && !isConsumed) ||
64 (isExternalReadonly && !isReadonly)) {
65 return func1.emitError()
66 <<
"external definition has mismatching consumption "
67 "annotations for argument #"
73 assert(func1.isExternal());
83 "requires target to implement the 'SymbolTable' trait");
85 "requires target to implement the 'SymbolTable' trait");
94 LDBG() <<
"renaming private symbols to resolve conflicts:";
96 for (
auto &&[symbolTable, otherSymbolTable] : llvm::zip(
99 &targetSymbolTable})) {
100 Operation *symbolTableOp = symbolTable->getOp();
102 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
105 StringAttr name = symbolOp.getNameAttr();
106 LDBG() <<
" found @" << name.getValue();
110 cast_or_null<SymbolOpInterface>(otherSymbolTable->
lookup(name));
114 LDBG() <<
" collision found for @" << name.getValue();
117 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op),
119 dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
120 funcOp && collidingFuncOp) {
123 LDBG() <<
" but both ops are functions and will be merged";
128 LDBG() <<
" and both ops are function definitions";
132 auto renameToUnique =
133 [&](SymbolOpInterface op, SymbolOpInterface otherOp,
136 LDBG() <<
", renaming";
137 FailureOr<StringAttr> maybeNewName =
138 symbolTable.renameToUnique(op, {&otherSymbolTable});
139 if (
failed(maybeNewName)) {
141 diag.attachNote(otherOp->getLoc())
142 <<
"attempted renaming due to collision with this op";
145 LDBG() <<
" renamed to @" << maybeNewName->getValue();
149 if (symbolOp.isPrivate()) {
151 symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
156 if (collidingOp.isPrivate()) {
158 collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
163 LDBG() <<
", emitting error";
165 <<
"doubly defined symbol @" << name.getValue();
166 diag.attachNote(collidingOp->getLoc()) <<
"previously defined here";
175 return op->emitError() <<
"failed to verify input op after renaming";
181 LDBG() <<
"moving all symbols into target";
184 for (
Operation &op : other->getRegion(0).front()) {
185 if (
auto symbol = dyn_cast<SymbolOpInterface>(op))
186 opsToMove.push_back(symbol);
189 for (SymbolOpInterface op : opsToMove) {
191 auto collidingOp = cast_or_null<SymbolOpInterface>(
192 targetSymbolTable.
lookup(op.getNameAttr()));
195 LDBG() <<
" moving @" << op.getName();
201 LDBG() <<
" without collision";
207 auto funcOp = cast<FunctionOpInterface>(op.getOperation());
208 auto collidingFuncOp =
209 cast<FunctionOpInterface>(collidingOp.getOperation());
215 std::swap(funcOp, collidingFuncOp);
219 LDBG() <<
" with collision, trying to keep op at "
220 << collidingFuncOp.getLoc() <<
":\n"
224 targetSymbolTable.
remove(funcOp);
225 targetSymbolTable.
insert(collidingFuncOp);
226 assert(targetSymbolTable.
lookup(funcOp.getName()) == collidingFuncOp);
239 <<
"failed to verify target op after merging symbols";
241 LDBG() <<
"done merging ops";
static std::string diag(const llvm::Value &value)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
This class acts as an owning reference to an op, and will automatically destroy the held op on destruction if the held op is valid.
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
void remove(Operation *op)
Remove the given symbol from the table, without deleting it.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.