MLIR 23.0.0git
Utils.cpp
Go to the documentation of this file.
1//===- Utils.cpp - Utils related to the transform dialect -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/Verifier.h"
16#include "llvm/ADT/SCCIterator.h"
17#include "llvm/Support/Debug.h"
18#include "llvm/Support/DebugLog.h"
19
20using namespace mlir;
21
22#define DEBUG_TYPE "transform-dialect-utils"
23#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
24
25/// Return whether `func1` can be merged into `func2`. For that to work
26/// `func1` has to be a declaration (aka has to be external) and `func2`
27/// either has to be a declaration as well, or it has to be public (otherwise,
28/// it wouldn't be visible by `func1`).
29static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
30 return func1.isExternal() && (func2.isPublic() || func2.isExternal());
31}
32
33/// Merge `func1` into `func2`. The two ops must be inside the same parent op
34/// and mergable according to `canMergeInto`. The function erases `func1` such
35/// that only `func2` exists when the function returns.
36static LogicalResult mergeInto(FunctionOpInterface func1,
37 FunctionOpInterface func2) {
38 assert(canMergeInto(func1, func2));
39 assert(func1->getParentOp() == func2->getParentOp() &&
40 "expected func1 and func2 to be in the same parent op");
41
42 // Check that function signatures match.
43 if (func1.getFunctionType() != func2.getFunctionType()) {
44 return func1.emitError()
45 << "external definition has a mismatching signature ("
46 << func2.getFunctionType() << ")";
47 }
48
49 // Check and merge argument attributes.
50 MLIRContext *context = func1->getContext();
51 auto *td = context->getLoadedDialect<transform::TransformDialect>();
52 StringAttr consumedName = td->getConsumedAttrName();
53 StringAttr readOnlyName = td->getReadOnlyAttrName();
54 for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
55 bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
56 bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
57 bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
58 bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
59 if (!isExternalConsumed && !isExternalReadonly) {
60 if (isConsumed)
61 func2.setArgAttr(i, consumedName, UnitAttr::get(context));
62 else if (isReadonly)
63 func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
64 continue;
65 }
66
67 if ((isExternalConsumed && !isConsumed) ||
68 (isExternalReadonly && !isReadonly)) {
69 return func1.emitError()
70 << "external definition has mismatching consumption "
71 "annotations for argument #"
72 << i;
73 }
74 }
75
76 // `func1` is the external one, so we can remove it.
77 assert(func1.isExternal());
78 func1->erase();
79
80 return success();
81}
82
84 const mlir::CallGraph callgraph(root);
85 for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
86 if (!scc.hasCycle())
87 continue;
88
89 // Need to check this here additionally because this verification may run
90 // before we check the nested operations.
91 if ((*scc->begin())->isExternal())
92 return root->emitOpError() << "contains a call to an external "
93 "operation, which is not allowed";
94
95 Operation *first = (*scc->begin())->getCallableRegion()->getParentOp();
97 << "recursion not allowed in named sequences";
98 for (auto it = std::next(scc->begin()); it != scc->end(); ++it) {
99 // Need to check this here additionally because this verification may
100 // run before we check the nested operations.
101 if ((*it)->isExternal()) {
102 return root->emitOpError() << "contains a call to an external "
103 "operation, which is not allowed";
104 }
105
106 Operation *current = (*it)->getCallableRegion()->getParentOp();
107 diag.attachNote(current->getLoc()) << "operation on recursion stack";
108 }
109 return diag;
110 }
111 return success();
112}
113
114LogicalResult
117 assert(target->hasTrait<OpTrait::SymbolTable>() &&
118 "requires target to implement the 'SymbolTable' trait");
119 assert(other->hasTrait<OpTrait::SymbolTable>() &&
120 "requires target to implement the 'SymbolTable' trait");
121
122 SymbolTable targetSymbolTable(target);
123 InlinerInterface inliner(target->getContext());
124
125 // Collect all the functions that are called in `target` that cannot be
126 // inlined into `target`.
127 SmallPtrSet<Operation *, 1> noInlineCalls;
128 target->walk([&](CallOpInterface call) {
129 Operation *callable = nullptr;
130 CallInterfaceCallable callee = call.getCallableForCallee();
131 if (auto symRef = dyn_cast<SymbolRefAttr>(callee)) {
132 // Fall back to full resolution for nested symbols, the table is
133 // one-level only.
134 if (isa<FlatSymbolRefAttr>(symRef))
135 callable = targetSymbolTable.lookup(symRef.getLeafReference());
136 else
137 callable = SymbolTable::lookupNearestSymbolFrom(call, symRef);
138 } else if (auto value = dyn_cast<Value>(callee)) {
139 callable = value.getDefiningOp();
140 }
141
142 if (!callable)
143 return;
144
145 if (!inliner.isLegalToInline(call, callable, /*wouldBeCloned=*/false)) {
146 noInlineCalls.insert(call.getOperation());
147 }
148 return;
149 });
150
151 SymbolTable otherSymbolTable(*other);
152
153 // Step 1:
154 //
155 // Rename private symbols in both ops in order to resolve conflicts that can
156 // be resolved that way.
157 LDBG() << "renaming private symbols to resolve conflicts:";
158 // TODO: Do we *actually* need to test in both directions?
159 for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
160 SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
161 SmallVector<SymbolTable *, 2>{&otherSymbolTable,
162 &targetSymbolTable})) {
163 Operation *symbolTableOp = symbolTable->getOp();
164 for (Operation &op : symbolTableOp->getRegion(0).front()) {
165 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
166 if (!symbolOp)
167 continue;
168 StringAttr name = symbolOp.getNameAttr();
169 LDBG() << " found @" << name.getValue();
170
171 // Check if there is a colliding op in the other module.
172 auto collidingOp =
173 cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
174 if (!collidingOp)
175 continue;
176
177 LDBG() << " collision found for @" << name.getValue();
178
179 // Collisions are fine if both opt are functions and can be merged.
180 if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
181 collidingFuncOp =
182 dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
183 funcOp && collidingFuncOp) {
184 if (canMergeInto(funcOp, collidingFuncOp) ||
185 canMergeInto(collidingFuncOp, funcOp)) {
186 LDBG() << " but both ops are functions and will be merged";
187 continue;
188 }
189
190 // If they can't be merged, proceed like any other collision.
191 LDBG() << " and both ops are function definitions";
192 }
193
194 // Collision can be resolved by renaming if one of the ops is private.
195 auto renameToUnique =
196 [&](SymbolOpInterface op, SymbolOpInterface otherOp,
197 SymbolTable &symbolTable,
198 SymbolTable &otherSymbolTable) -> LogicalResult {
199 LDBG() << ", renaming";
200 FailureOr<StringAttr> maybeNewName =
201 symbolTable.renameToUnique(op, {&otherSymbolTable});
202 if (failed(maybeNewName)) {
203 InFlightDiagnostic diag = op->emitError("failed to rename symbol");
204 diag.attachNote(otherOp->getLoc())
205 << "attempted renaming due to collision with this op";
206 return diag;
207 }
208 LDBG() << " renamed to @" << maybeNewName->getValue();
209 return success();
210 };
211
212 if (symbolOp.isPrivate()) {
213 if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
214 *otherSymbolTable)))
215 return failure();
216 continue;
217 }
218 if (collidingOp.isPrivate()) {
219 if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
220 *symbolTable)))
221 return failure();
222 continue;
223 }
224 LDBG() << ", emitting error";
225 InFlightDiagnostic diag = symbolOp.emitError()
226 << "doubly defined symbol @" << name.getValue();
227 diag.attachNote(collidingOp->getLoc()) << "previously defined here";
228 return diag;
229 }
230 }
231
232 // We only modified symbols above, so there is no need to verify everything
233 // again, just the symbol table.
234 for (auto *op : SmallVector<Operation *>{target, *other}) {
235 if (failed(mlir::detail::verifySymbolTable(op)))
236 return op->emitError()
237 << "failed to verify symbol table after symbol renaming";
238 }
239
240 // Step 2:
241 //
242 // Move all ops from `other` into target and merge public symbols.
243 LDBG() << "moving all symbols into target";
244 {
246 for (Operation &op : other->getRegion(0).front()) {
247 if (auto symbol = dyn_cast<SymbolOpInterface>(op))
248 opsToMove.push_back(symbol);
249 }
250
251 for (SymbolOpInterface op : opsToMove) {
252 // Remember potentially colliding op in the target module.
253 auto collidingOp = cast_or_null<SymbolOpInterface>(
254 targetSymbolTable.lookup(op.getNameAttr()));
255
256 // Move op even if we get a collision.
257 LDBG() << " moving @" << op.getName();
258 op->moveBefore(&target->getRegion(0).front(),
259 target->getRegion(0).front().end());
260
261 // If there is no collision, we are done -- keep the target symbol
262 // table in sync with the moved op so that subsequent lookups (and the
263 // post-merge validation below) remain efficient.
264 if (!collidingOp) {
265 LDBG() << " without collision";
266 targetSymbolTable.insert(op);
267 continue;
268 }
269
270 // The two colliding ops must both be functions because we have already
271 // emitted errors otherwise earlier.
272 auto funcOp = cast<FunctionOpInterface>(op.getOperation());
273 auto collidingFuncOp =
274 cast<FunctionOpInterface>(collidingOp.getOperation());
275
276 // Both ops are in the target module now and can be treated
277 // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into
278 // `collidingFuncOp`.
279 if (!canMergeInto(funcOp, collidingFuncOp)) {
280 std::swap(funcOp, collidingFuncOp);
281 }
282 assert(canMergeInto(funcOp, collidingFuncOp));
283
284 LDBG() << " with collision, trying to keep op at "
285 << collidingFuncOp.getLoc() << ":\n"
286 << collidingFuncOp;
287
288 // Update symbol table. This works with or without the previous `swap`.
289 targetSymbolTable.remove(funcOp);
290 targetSymbolTable.insert(collidingFuncOp);
291 assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
292
293 // Do the actual merging.
294 if (failed(mergeInto(funcOp, collidingFuncOp)))
295 return failure();
296 }
297 }
298
299 // Symbol merging only moves callable ops between symbol tables; it does not
300 // alter the bodies that were already valid in the source modules. The only
301 // invariants that may newly be violated after merging are:
302 // 1. a call now refers to a callee whose body is structurally not legal to
303 // inline at the call site (caught by the transform dialect's
304 // `DialectInlinerInterface` implementation), or
305 // 2. the merged call graph contains a recursive cycle, which is forbidden
306 // for `transform.named_sequence` callables (caught by the shared
307 // `verifyNoRecursionInCallGraph` helper).
308 // Use the inliner interface methods directly (without running the inlining
309 // pass) to validate (1), and reuse the dialect's call-graph verifier for
310 // (2). The call graph builder requires call/callable ops to be well-formed,
311 // so pre-verify them here without recursing into their bodies.
312 WalkResult preVerify = target->walk([](Operation *nested) {
313 if (!isa<CallableOpInterface, CallOpInterface>(nested))
314 return WalkResult::advance();
315 if (failed(mlir::verify(nested, /*verifyRecursively=*/false)))
316 return WalkResult::interrupt();
317 return WalkResult::advance();
318 });
319 if (preVerify.wasInterrupted())
320 return failure();
321
322 WalkResult inlineCheck = target->walk([&](CallOpInterface call) {
323 Operation *callable = nullptr;
324 CallInterfaceCallable callee = call.getCallableForCallee();
325 if (auto symRef = dyn_cast<SymbolRefAttr>(callee)) {
326 // Fall back to full resolution for nested symbols, the table is
327 // one-level only.
328 if (isa<FlatSymbolRefAttr>(symRef))
329 callable = targetSymbolTable.lookup(symRef.getLeafReference());
330 else
331 callable = SymbolTable::lookupNearestSymbolFrom(call, symRef);
332 } else if (auto value = dyn_cast<Value>(callee)) {
333 callable = value.getDefiningOp();
334 }
335
336 if (!callable)
337 return WalkResult::advance();
338 if (!noInlineCalls.contains(call.getOperation()) &&
339 !inliner.isLegalToInline(call, callable, /*wouldBeCloned=*/false)) {
341 call->emitError()
342 << "merged call is not legal to inline into its caller";
343 diag.attachNote(callable->getLoc()) << "callee defined here";
344 return WalkResult::interrupt();
345 }
346 return WalkResult::advance();
347 });
348 if (inlineCheck.wasInterrupted())
349 return failure();
350
351 LDBG() << "done merging ops";
353}
return success()
static LogicalResult mergeInto(FunctionOpInterface func1, FunctionOpInterface func2)
Merge func1 into func2.
Definition Utils.cpp:36
static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2)
Return whether func1 can be merged into func2.
Definition Utils.cpp:29
static std::string diag(const llvm::Value &value)
This class represents a diagnostic that is inflight and set to be reported.
This interface provides the hooks into the inlining interface.
virtual bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const
These hooks mirror the hooks for the DialectInlinerInterface, with default implementations that call ...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition Operation.h:252
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class acts as an owning reference to an op, and will automatically destroy the held op on destruction if the held op is valid.
Definition OwningOpRef.h:29
Block & front()
Definition Region.h:65
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Definition SymbolTable.h:24
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait.
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
void remove(Operation *op)
Remove the given symbol from the table, without deleting it.
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
LogicalResult verifySymbolTable(Operation *op)
LogicalResult verifyNoRecursionInCallGraph(Operation *root)
Verify that the call graph inside root contains no cycles.
Definition Utils.cpp:83
LogicalResult mergeSymbolsInto(Operation *target, OwningOpRef< Operation * > other)
Merge all symbols from other into target.
Definition Utils.cpp:115
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on the given operation and, when `verifyRecursively` is true, on all nested operations.
Definition Verifier.cpp:480
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation.