MLIR 22.0.0git
Utils.cpp
Go to the documentation of this file.
1//===- Utils.cpp - Utils related to the transform dialect -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/IR/Verifier.h"
13#include "llvm/Support/Debug.h"
14#include "llvm/Support/DebugLog.h"
15
16using namespace mlir;
17
18#define DEBUG_TYPE "transform-dialect-utils"
19#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
20
21/// Return whether `func1` can be merged into `func2`. For that to work
22/// `func1` has to be a declaration (aka has to be external) and `func2`
23/// either has to be a declaration as well, or it has to be public (otherwise,
24/// it wouldn't be visible by `func1`).
25static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
26 return func1.isExternal() && (func2.isPublic() || func2.isExternal());
27}
28
29/// Merge `func1` into `func2`. The two ops must be inside the same parent op
30/// and mergable according to `canMergeInto`. The function erases `func1` such
31/// that only `func2` exists when the function returns.
32static InFlightDiagnostic mergeInto(FunctionOpInterface func1,
33 FunctionOpInterface func2) {
34 assert(canMergeInto(func1, func2));
35 assert(func1->getParentOp() == func2->getParentOp() &&
36 "expected func1 and func2 to be in the same parent op");
37
38 // Check that function signatures match.
39 if (func1.getFunctionType() != func2.getFunctionType()) {
40 return func1.emitError()
41 << "external definition has a mismatching signature ("
42 << func2.getFunctionType() << ")";
43 }
44
45 // Check and merge argument attributes.
46 MLIRContext *context = func1->getContext();
47 auto *td = context->getLoadedDialect<transform::TransformDialect>();
48 StringAttr consumedName = td->getConsumedAttrName();
49 StringAttr readOnlyName = td->getReadOnlyAttrName();
50 for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
51 bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
52 bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
53 bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
54 bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
55 if (!isExternalConsumed && !isExternalReadonly) {
56 if (isConsumed)
57 func2.setArgAttr(i, consumedName, UnitAttr::get(context));
58 else if (isReadonly)
59 func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
60 continue;
61 }
62
63 if ((isExternalConsumed && !isConsumed) ||
64 (isExternalReadonly && !isReadonly)) {
65 return func1.emitError()
66 << "external definition has mismatching consumption "
67 "annotations for argument #"
68 << i;
69 }
70 }
71
72 // `func1` is the external one, so we can remove it.
73 assert(func1.isExternal());
74 func1->erase();
75
76 return InFlightDiagnostic();
77}
78
82 assert(target->hasTrait<OpTrait::SymbolTable>() &&
83 "requires target to implement the 'SymbolTable' trait");
84 assert(other->hasTrait<OpTrait::SymbolTable>() &&
85 "requires target to implement the 'SymbolTable' trait");
86
87 SymbolTable targetSymbolTable(target);
88 SymbolTable otherSymbolTable(*other);
89
90 // Step 1:
91 //
92 // Rename private symbols in both ops in order to resolve conflicts that can
93 // be resolved that way.
94 LDBG() << "renaming private symbols to resolve conflicts:";
95 // TODO: Do we *actually* need to test in both directions?
96 for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
97 SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
98 SmallVector<SymbolTable *, 2>{&otherSymbolTable,
99 &targetSymbolTable})) {
100 Operation *symbolTableOp = symbolTable->getOp();
101 for (Operation &op : symbolTableOp->getRegion(0).front()) {
102 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
103 if (!symbolOp)
104 continue;
105 StringAttr name = symbolOp.getNameAttr();
106 LDBG() << " found @" << name.getValue();
107
108 // Check if there is a colliding op in the other module.
109 auto collidingOp =
110 cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
111 if (!collidingOp)
112 continue;
113
114 LDBG() << " collision found for @" << name.getValue();
115
116 // Collisions are fine if both opt are functions and can be merged.
117 if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
118 collidingFuncOp =
119 dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
120 funcOp && collidingFuncOp) {
121 if (canMergeInto(funcOp, collidingFuncOp) ||
122 canMergeInto(collidingFuncOp, funcOp)) {
123 LDBG() << " but both ops are functions and will be merged";
124 continue;
125 }
126
127 // If they can't be merged, proceed like any other collision.
128 LDBG() << " and both ops are function definitions";
129 }
130
131 // Collision can be resolved by renaming if one of the ops is private.
132 auto renameToUnique =
133 [&](SymbolOpInterface op, SymbolOpInterface otherOp,
134 SymbolTable &symbolTable,
135 SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
136 LDBG() << ", renaming";
137 FailureOr<StringAttr> maybeNewName =
138 symbolTable.renameToUnique(op, {&otherSymbolTable});
139 if (failed(maybeNewName)) {
140 InFlightDiagnostic diag = op->emitError("failed to rename symbol");
141 diag.attachNote(otherOp->getLoc())
142 << "attempted renaming due to collision with this op";
143 return diag;
144 }
145 LDBG() << " renamed to @" << maybeNewName->getValue();
146 return InFlightDiagnostic();
147 };
148
149 if (symbolOp.isPrivate()) {
150 InFlightDiagnostic diag = renameToUnique(
151 symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
152 if (failed(diag))
153 return diag;
154 continue;
155 }
156 if (collidingOp.isPrivate()) {
157 InFlightDiagnostic diag = renameToUnique(
158 collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
159 if (failed(diag))
160 return diag;
161 continue;
162 }
163 LDBG() << ", emitting error";
164 InFlightDiagnostic diag = symbolOp.emitError()
165 << "doubly defined symbol @" << name.getValue();
166 diag.attachNote(collidingOp->getLoc()) << "previously defined here";
167 return diag;
168 }
169 }
170
171 // TODO: This duplicates pass infrastructure. We should split this pass into
172 // several and let the pass infrastructure do the verification.
173 for (auto *op : SmallVector<Operation *>{target, *other}) {
174 if (failed(mlir::verify(op)))
175 return op->emitError() << "failed to verify input op after renaming";
176 }
177
178 // Step 2:
179 //
180 // Move all ops from `other` into target and merge public symbols.
181 LDBG() << "moving all symbols into target";
182 {
184 for (Operation &op : other->getRegion(0).front()) {
185 if (auto symbol = dyn_cast<SymbolOpInterface>(op))
186 opsToMove.push_back(symbol);
187 }
188
189 for (SymbolOpInterface op : opsToMove) {
190 // Remember potentially colliding op in the target module.
191 auto collidingOp = cast_or_null<SymbolOpInterface>(
192 targetSymbolTable.lookup(op.getNameAttr()));
193
194 // Move op even if we get a collision.
195 LDBG() << " moving @" << op.getName();
196 op->moveBefore(&target->getRegion(0).front(),
197 target->getRegion(0).front().end());
198
199 // If there is no collision, we are done.
200 if (!collidingOp) {
201 LDBG() << " without collision";
202 continue;
203 }
204
205 // The two colliding ops must both be functions because we have already
206 // emitted errors otherwise earlier.
207 auto funcOp = cast<FunctionOpInterface>(op.getOperation());
208 auto collidingFuncOp =
209 cast<FunctionOpInterface>(collidingOp.getOperation());
210
211 // Both ops are in the target module now and can be treated
212 // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into
213 // `collidingFuncOp`.
214 if (!canMergeInto(funcOp, collidingFuncOp)) {
215 std::swap(funcOp, collidingFuncOp);
216 }
217 assert(canMergeInto(funcOp, collidingFuncOp));
218
219 LDBG() << " with collision, trying to keep op at "
220 << collidingFuncOp.getLoc() << ":\n"
221 << collidingFuncOp;
222
223 // Update symbol table. This works with or without the previous `swap`.
224 targetSymbolTable.remove(funcOp);
225 targetSymbolTable.insert(collidingFuncOp);
226 assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
227
228 // Do the actual merging.
229 {
230 InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp);
231 if (failed(diag))
232 return diag;
233 }
234 }
235 }
236
237 if (failed(mlir::verify(target)))
238 return target->emitError()
239 << "failed to verify target op after merging symbols";
240
241 LDBG() << "done merging ops";
242 return InFlightDiagnostic();
243}
static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2)
Return whether func1 can be merged into func2.
Definition Utils.cpp:25
static InFlightDiagnostic mergeInto(FunctionOpInterface func1, FunctionOpInterface func2)
Merge func1 into func2.
Definition Utils.cpp:32
static std::string diag(const llvm::Value &value)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
This class acts as an owning reference to an op, and will automatically destroy the held op on destruction if the held op is valid.
Definition OwningOpRef.h:29
Block & front()
Definition Region.h:65
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Definition SymbolTable.h:24
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
void remove(Operation *op)
Remove the given symbol from the table, without deleting it.
InFlightDiagnostic mergeSymbolsInto(Operation *target, OwningOpRef< Operation * > other)
Merge all symbols from other into target.
Definition Utils.cpp:80
Include the generated interface declarations.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition Verifier.cpp:423