13 #include "llvm/Support/Debug.h"
17 #define DEBUG_TYPE "transform-dialect-utils"
18 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
24 static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
25 return func1.isExternal() && (func2.isPublic() || func2.isExternal());
32 FunctionOpInterface func2) {
34 assert(func1->getParentOp() == func2->getParentOp() &&
35 "expected func1 and func2 to be in the same parent op");
38 if (func1.getFunctionType() != func2.getFunctionType()) {
39 return func1.emitError()
40 <<
"external definition has a mismatching signature ("
41 << func2.getFunctionType() <<
")";
47 StringAttr consumedName = td->getConsumedAttrName();
48 StringAttr readOnlyName = td->getReadOnlyAttrName();
49 for (
unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
50 bool isExternalConsumed = func2.getArgAttr(i, consumedName) !=
nullptr;
51 bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) !=
nullptr;
52 bool isConsumed = func1.getArgAttr(i, consumedName) !=
nullptr;
53 bool isReadonly = func1.getArgAttr(i, readOnlyName) !=
nullptr;
54 if (!isExternalConsumed && !isExternalReadonly) {
62 if ((isExternalConsumed && !isConsumed) ||
63 (isExternalReadonly && !isReadonly)) {
64 return func1.emitError()
65 <<
"external definition has mismatching consumption "
66 "annotations for argument #"
72 assert(func1.isExternal());
82 "requires target to implement the 'SymbolTable' trait");
84 "requires target to implement the 'SymbolTable' trait");
93 LLVM_DEBUG(
DBGS() <<
"renaming private symbols to resolve conflicts:\n");
95 for (
auto &&[symbolTable, otherSymbolTable] : llvm::zip(
98 &targetSymbolTable})) {
99 Operation *symbolTableOp = symbolTable->getOp();
101 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
104 StringAttr name = symbolOp.getNameAttr();
105 LLVM_DEBUG(
DBGS() <<
" found @" << name.getValue() <<
"\n");
109 cast_or_null<SymbolOpInterface>(otherSymbolTable->
lookup(name));
113 LLVM_DEBUG(
DBGS() <<
" collision found for @" << name.getValue());
116 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op),
118 dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
119 funcOp && collidingFuncOp) {
122 LLVM_DEBUG(llvm::dbgs() <<
" but both ops are functions and "
128 LLVM_DEBUG(llvm::dbgs() <<
" and both ops are function definitions");
132 auto renameToUnique =
133 [&](SymbolOpInterface op, SymbolOpInterface otherOp,
136 LLVM_DEBUG(llvm::dbgs() <<
", renaming\n");
137 FailureOr<StringAttr> maybeNewName =
138 symbolTable.renameToUnique(op, {&otherSymbolTable});
139 if (failed(maybeNewName)) {
141 diag.attachNote(otherOp->getLoc())
142 <<
"attempted renaming due to collision with this op";
145 LLVM_DEBUG(
DBGS() <<
" renamed to @" << maybeNewName->getValue()
150 if (symbolOp.isPrivate()) {
152 symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
157 if (collidingOp.isPrivate()) {
159 collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
164 LLVM_DEBUG(llvm::dbgs() <<
", emitting error\n");
166 <<
"doubly defined symbol @" << name.getValue();
167 diag.attachNote(collidingOp->getLoc()) <<
"previously defined here";
176 return op->emitError() <<
"failed to verify input op after renaming";
182 LLVM_DEBUG(
DBGS() <<
"moving all symbols into target\n");
185 for (
Operation &op : other->getRegion(0).front()) {
186 if (
auto symbol = dyn_cast<SymbolOpInterface>(op))
187 opsToMove.push_back(symbol);
190 for (SymbolOpInterface op : opsToMove) {
192 auto collidingOp = cast_or_null<SymbolOpInterface>(
193 targetSymbolTable.
lookup(op.getNameAttr()));
196 LLVM_DEBUG(
DBGS() <<
" moving @" << op.getName());
202 LLVM_DEBUG(llvm::dbgs() <<
" without collision\n");
208 auto funcOp = cast<FunctionOpInterface>(op.getOperation());
209 auto collidingFuncOp =
210 cast<FunctionOpInterface>(collidingOp.getOperation());
216 std::swap(funcOp, collidingFuncOp);
220 LLVM_DEBUG(llvm::dbgs() <<
" with collision, trying to keep op at "
221 << collidingFuncOp.getLoc() <<
":\n"
222 << collidingFuncOp <<
"\n");
225 targetSymbolTable.
remove(funcOp);
226 targetSymbolTable.
insert(collidingFuncOp);
227 assert(targetSymbolTable.
lookup(funcOp.getName()) == collidingFuncOp);
240 <<
"failed to verify target op after merging symbols";
242 LLVM_DEBUG(
DBGS() <<
"done merging ops\n");
static std::string diag(const llvm::Value &value)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
void remove(Operation *op)
Remove the given symbol from the table, without deleting it.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...