MLIR 22.0.0git
Inliner.cpp
Go to the documentation of this file.
1//===- Inliner.cpp ---- SCC-based inliner ---------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Inliner, which uses a basic inlining algorithm
10// that operates bottom-up over the Strongly Connected Components (SCCs) of
11// the CallGraph. This enables a more incremental propagation of inlining
12// decisions from the leaves to the roots of the callgraph.
13//
14//===----------------------------------------------------------------------===//
15
17#include "mlir/IR/Threading.h"
22#include "llvm/ADT/SCCIterator.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/DebugLog.h"
25
26#define DEBUG_TYPE "inlining"
27
28using namespace mlir;
29
31
32//===----------------------------------------------------------------------===//
33// Symbol Use Tracking
34//===----------------------------------------------------------------------===//
35
36/// Walk all of the used symbol callgraph nodes referenced with the given op.
38 Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
40 function_ref<void(CallGraphNode *, Operation *)> callback) {
41 auto symbolUses = SymbolTable::getSymbolUses(op);
42 assert(symbolUses && "expected uses to be valid");
43
44 Operation *symbolTableOp = op->getParentOp();
45 for (const SymbolTable::SymbolUse &use : *symbolUses) {
46 auto refIt = resolvedRefs.try_emplace(use.getSymbolRef());
47 CallGraphNode *&node = refIt.first->second;
48
49 // If this is the first instance of this reference, try to resolve a
50 // callgraph node for it.
51 if (refIt.second) {
52 auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
53 use.getSymbolRef());
54 auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
55 if (!callableOp)
56 continue;
57 node = cg.lookupNode(callableOp.getCallableRegion());
58 }
59 if (node)
60 callback(node, use.getUser());
61 }
62}
63
64//===----------------------------------------------------------------------===//
65// CGUseList
66//===----------------------------------------------------------------------===//
67
68namespace {
69/// This struct tracks the uses of callgraph nodes that can be dropped when
70/// use_empty. It directly tracks and manages a use-list for all of the
71/// call-graph nodes. This is necessary because many callgraph nodes are
72/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
73/// class.
74struct CGUseList {
75 /// This struct tracks the uses of callgraph nodes within a specific
76 /// operation.
77 struct CGUser {
78 /// Any nodes referenced in the top-level attribute list of this user. We
79 /// use a set here because the number of references does not matter.
80 DenseSet<CallGraphNode *> topLevelUses;
81
82 /// Uses of nodes referenced by nested operations.
84 };
85
86 CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
87
88 /// Drop uses of nodes referred to by the given call operation that resides
89 /// within 'userNode'.
90 void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);
91
92 /// Remove the given node from the use list.
93 void eraseNode(CallGraphNode *node);
94
95 /// Returns true if the given callgraph node has no uses and can be pruned.
96 bool isDead(CallGraphNode *node) const;
97
98 /// Returns true if the given callgraph node has a single use and can be
99 /// discarded.
100 bool hasOneUseAndDiscardable(CallGraphNode *node) const;
101
102 /// Recompute the uses held by the given callgraph node.
103 void recomputeUses(CallGraphNode *node, CallGraph &cg);
104
105 /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
106 /// of 'lhs' into 'rhs'.
107 void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);
108
109private:
110 /// Decrement the uses of discardable nodes referenced by the given user.
111 void decrementDiscardableUses(CGUser &uses);
112
113 /// A mapping between a discardable callgraph node (that is a symbol) and the
114 /// number of uses for this node.
115 DenseMap<CallGraphNode *, int> discardableSymNodeUses;
116
117 /// A mapping between a callgraph node and the symbol callgraph nodes that it
118 /// uses.
120
121 /// A symbol table to use when resolving call lookups.
122 SymbolTableCollection &symbolTable;
123};
124} // namespace
125
126CGUseList::CGUseList(Operation *op, CallGraph &cg,
127 SymbolTableCollection &symbolTable)
128 : symbolTable(symbolTable) {
129 /// A set of callgraph nodes that are always known to be live during inlining.
131
132 // Walk each of the symbol tables looking for discardable callgraph nodes.
133 auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
134 for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
135 // If this is a callgraph operation, check to see if it is discardable.
136 if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
137 if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
138 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
139 if (symbol && (allUsesVisible || symbol.isPrivate()) &&
140 symbol.canDiscardOnUseEmpty()) {
141 discardableSymNodeUses.try_emplace(node, 0);
142 }
143 continue;
144 }
145 }
146 // Otherwise, check for any referenced nodes. These will be always-live.
147 walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
148 [](CallGraphNode *, Operation *) {});
149 }
150 };
151 SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
152 walkFn);
153
154 // Drop the use information for any discardable nodes that are always live.
155 for (auto &it : alwaysLiveNodes)
156 discardableSymNodeUses.erase(it.second);
157
158 // Compute the uses for each of the callable nodes in the graph.
159 for (CallGraphNode *node : cg)
160 recomputeUses(node, cg);
161}
162
163void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
164 CallGraph &cg) {
165 auto &userRefs = nodeUses[userNode].innerUses;
166 auto walkFn = [&](CallGraphNode *node, Operation *user) {
167 auto parentIt = userRefs.find(node);
168 if (parentIt == userRefs.end())
169 return;
170 --parentIt->second;
171 --discardableSymNodeUses[node];
172 };
174 walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
175}
176
177void CGUseList::eraseNode(CallGraphNode *node) {
178 // Drop all child nodes.
179 for (auto &edge : *node)
180 if (edge.isChild())
181 eraseNode(edge.getTarget());
182
183 // Drop the uses held by this node and erase it.
184 auto useIt = nodeUses.find(node);
185 assert(useIt != nodeUses.end() && "expected node to be valid");
186 decrementDiscardableUses(useIt->getSecond());
187 nodeUses.erase(useIt);
188 discardableSymNodeUses.erase(node);
189}
190
191bool CGUseList::isDead(CallGraphNode *node) const {
192 // If the parent operation isn't a symbol, simply check normal SSA deadness.
193 Operation *nodeOp = node->getCallableRegion()->getParentOp();
194 if (!isa<SymbolOpInterface>(nodeOp))
195 return isMemoryEffectFree(nodeOp) && nodeOp->use_empty();
196
197 // Otherwise, check the number of symbol uses.
198 auto symbolIt = discardableSymNodeUses.find(node);
199 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
200}
201
202bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
203 // If this isn't a symbol node, check for side-effects and SSA use count.
204 Operation *nodeOp = node->getCallableRegion()->getParentOp();
205 if (!isa<SymbolOpInterface>(nodeOp))
206 return isMemoryEffectFree(nodeOp) && nodeOp->hasOneUse();
207
208 // Otherwise, check the number of symbol uses.
209 auto symbolIt = discardableSymNodeUses.find(node);
210 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
211}
212
213void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
214 Operation *parentOp = node->getCallableRegion()->getParentOp();
215 CGUser &uses = nodeUses[node];
216 decrementDiscardableUses(uses);
217
218 // Collect the new discardable uses within this node.
219 uses = CGUser();
221 auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
222 auto discardSymIt = discardableSymNodeUses.find(refNode);
223 if (discardSymIt == discardableSymNodeUses.end())
224 return;
225
226 if (user != parentOp)
227 ++uses.innerUses[refNode];
228 else if (!uses.topLevelUses.insert(refNode).second)
229 return;
230 ++discardSymIt->second;
231 };
232 walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
233}
234
235void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
236 auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
237 for (auto &useIt : lhsUses.innerUses) {
238 rhsUses.innerUses[useIt.first] += useIt.second;
239 discardableSymNodeUses[useIt.first] += useIt.second;
240 }
241}
242
243void CGUseList::decrementDiscardableUses(CGUser &uses) {
244 for (CallGraphNode *node : uses.topLevelUses)
245 --discardableSymNodeUses[node];
246 for (auto &it : uses.innerUses)
247 discardableSymNodeUses[it.first] -= it.second;
248}
249
250//===----------------------------------------------------------------------===//
251// CallGraph traversal
252//===----------------------------------------------------------------------===//
253
254namespace {
255/// This class represents a specific callgraph SCC.
256class CallGraphSCC {
257public:
258 CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
259 : parentIterator(parentIterator) {}
260 /// Return a range over the nodes within this SCC.
261 std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
262 std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
263
264 /// Reset the nodes of this SCC with those provided.
265 void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
266
267 /// Remove the given node from this SCC.
268 void remove(CallGraphNode *node) {
269 auto it = llvm::find(nodes, node);
270 if (it != nodes.end()) {
271 nodes.erase(it);
272 parentIterator.ReplaceNode(node, nullptr);
273 }
274 }
275
276private:
277 std::vector<CallGraphNode *> nodes;
278 llvm::scc_iterator<const CallGraph *> &parentIterator;
279};
280} // namespace
281
282/// Run a given transformation over the SCCs of the callgraph in a bottom up
283/// traversal.
284static LogicalResult runTransformOnCGSCCs(
285 const CallGraph &cg,
286 function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
287 llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
288 CallGraphSCC currentSCC(cgi);
289 while (!cgi.isAtEnd()) {
290 // Copy the current SCC and increment so that the transformer can modify the
291 // SCC without invalidating our iterator.
292 currentSCC.reset(*cgi);
293 ++cgi;
294 if (failed(sccTransformer(currentSCC)))
295 return failure();
296 }
297 return success();
298}
299
300/// Collect all of the callable operations within the given range of blocks. If
301/// `traverseNestedCGNodes` is true, this will also collect call operations
302/// inside of nested callgraph nodes.
304 CallGraphNode *sourceNode, CallGraph &cg,
305 SymbolTableCollection &symbolTable,
307 bool traverseNestedCGNodes) {
309 auto addToWorklist = [&](CallGraphNode *node,
311 for (Block &block : blocks)
312 worklist.emplace_back(&block, node);
313 };
314
315 addToWorklist(sourceNode, blocks);
316 while (!worklist.empty()) {
317 Block *block;
318 std::tie(block, sourceNode) = worklist.pop_back_val();
319
320 for (Operation &op : *block) {
321 if (auto call = dyn_cast<CallOpInterface>(op)) {
322 // TODO: Support inlining nested call references.
323 CallInterfaceCallable callable = call.getCallableForCallee();
324 if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) {
325 if (!isa<FlatSymbolRefAttr>(symRef))
326 continue;
327 }
328
329 CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
330 if (!targetNode->isExternal())
331 calls.emplace_back(call, sourceNode, targetNode);
332 continue;
333 }
334
335 // If this is not a call, traverse the nested regions. If
336 // `traverseNestedCGNodes` is false, then don't traverse nested call graph
337 // regions.
338 for (auto &nestedRegion : op.getRegions()) {
339 CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
340 if (traverseNestedCGNodes || !nestedNode)
341 addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
342 }
343 }
344 }
345}
346
347//===----------------------------------------------------------------------===//
348// InlinerInterfaceImpl
349//===----------------------------------------------------------------------===//
350
351static std::string getNodeName(CallOpInterface op) {
352 if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
353 return debugString(op);
354 return "_unnamed_callee_";
355}
356
357/// Return true if the specified `inlineHistoryID` indicates an inline history
358/// that already includes `node`.
360 CallGraphNode *node, std::optional<size_t> inlineHistoryID,
361 MutableArrayRef<std::pair<CallGraphNode *, std::optional<size_t>>>
362 inlineHistory) {
363 while (inlineHistoryID.has_value()) {
364 assert(*inlineHistoryID < inlineHistory.size() &&
365 "Invalid inline history ID");
366 if (inlineHistory[*inlineHistoryID].first == node)
367 return true;
368 inlineHistoryID = inlineHistory[*inlineHistoryID].second;
369 }
370 return false;
371}
372
373namespace {
374/// This class provides a specialization of the main inlining interface.
375struct InlinerInterfaceImpl : public InlinerInterface {
376 InlinerInterfaceImpl(MLIRContext *context, CallGraph &cg,
377 SymbolTableCollection &symbolTable)
378 : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
379
380 /// Process a set of blocks that have been inlined. This callback is invoked
381 /// *before* inlined terminator operations have been processed.
382 void
383 processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
384 // Find the closest callgraph node from the first block.
385 CallGraphNode *node;
386 Region *region = inlinedBlocks.begin()->getParent();
387 while (!(node = cg.lookupNode(region))) {
388 region = region->getParentRegion();
389 assert(region && "expected valid parent node");
390 }
391
392 collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
393 /*traverseNestedCGNodes=*/true);
394 }
395
396 /// Mark the given callgraph node for deletion.
397 void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
398
399 /// This method properly disposes of callables that became dead during
400 /// inlining. This should not be called while iterating over the SCCs.
401 void eraseDeadCallables() {
402 for (CallGraphNode *node : deadNodes)
403 node->getCallableRegion()->getParentOp()->erase();
404 }
405
406 /// The set of callables known to be dead.
407 SmallPtrSet<CallGraphNode *, 8> deadNodes;
408
409 /// The current set of call instructions to consider for inlining.
410 SmallVector<ResolvedCall, 8> calls;
411
412 /// The callgraph being operated on.
413 CallGraph &cg;
414
415 /// A symbol table to use when resolving call lookups.
416 SymbolTableCollection &symbolTable;
417};
418} // namespace
419
420namespace mlir {
421
423public:
424 Impl(Inliner &inliner) : inliner(inliner) {}
425
426 /// Attempt to inline calls within the given scc, and run simplifications,
427 /// until a fixed point is reached. This allows for the inlining of newly
428 /// devirtualized calls. Returns failure if there was a fatal error during
429 /// inlining.
430 LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface,
431 CGUseList &useList, CallGraphSCC &currentSCC,
432 MLIRContext *context);
433
434private:
435 /// Optimize the nodes within the given SCC with one of the held optimization
436 /// pass pipelines. Returns failure if an error occurred during the
437 /// optimization of the SCC, success otherwise.
438 LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
439 CallGraphSCC &currentSCC, MLIRContext *context);
440
441 /// Optimize the nodes within the given SCC in parallel. Returns failure if an
442 /// error occurred during the optimization of the SCC, success otherwise.
443 LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
444 MLIRContext *context);
445
446 /// Optimize the given callable node with one of the pass managers provided
447 /// with `pipelines`, or the generic pre-inline pipeline. Returns failure if
448 /// an error occurred during the optimization of the callable, success
449 /// otherwise.
450 LogicalResult optimizeCallable(CallGraphNode *node,
451 llvm::StringMap<OpPassManager> &pipelines);
452
453 /// Attempt to inline calls within the given scc. This function returns
454 /// success if any calls were inlined, failure otherwise.
455 LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
456 CGUseList &useList, CallGraphSCC &currentSCC);
457
458 /// Returns true if the given call should be inlined.
459 bool shouldInline(ResolvedCall &resolvedCall);
460
461private:
462 Inliner &inliner;
464};
465
466LogicalResult Inliner::Impl::inlineSCC(InlinerInterfaceImpl &inlinerIface,
467 CGUseList &useList,
468 CallGraphSCC &currentSCC,
469 MLIRContext *context) {
470 // Continuously simplify and inline until we either reach a fixed point, or
471 // hit the maximum iteration count. Simplifying early helps to refine the cost
472 // model, and in future iterations may devirtualize new calls.
473 unsigned iterationCount = 0;
474 do {
475 if (failed(optimizeSCC(inlinerIface.cg, useList, currentSCC, context)))
476 return failure();
477 if (failed(inlineCallsInSCC(inlinerIface, useList, currentSCC)))
478 break;
479 } while (++iterationCount < inliner.config.getMaxInliningIterations());
480 return success();
481}
482
483LogicalResult Inliner::Impl::optimizeSCC(CallGraph &cg, CGUseList &useList,
484 CallGraphSCC &currentSCC,
485 MLIRContext *context) {
486 // Collect the sets of nodes to simplify.
488 for (auto *node : currentSCC) {
489 if (node->isExternal())
490 continue;
491
492 // Don't simplify nodes with children. Nodes with children require special
493 // handling as we may remove the node during simplification. In the future,
494 // we should be able to handle this case with proper node deletion tracking.
495 if (node->hasChildren())
496 continue;
497
498 // We also won't apply simplifications to nodes that can't have passes
499 // scheduled on them.
500 auto *region = node->getCallableRegion();
502 continue;
503 nodesToVisit.push_back(node);
504 }
505 if (nodesToVisit.empty())
506 return success();
507
508 // Optimize each of the nodes within the SCC in parallel.
509 if (failed(optimizeSCCAsync(nodesToVisit, context)))
510 return failure();
511
512 // Recompute the uses held by each of the nodes.
513 for (CallGraphNode *node : nodesToVisit)
514 useList.recomputeUses(node, cg);
515 return success();
516}
517
518LogicalResult
519Inliner::Impl::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
520 MLIRContext *ctx) {
521 // We must maintain a fixed pool of pass managers which is at least as large
522 // as the maximum parallelism of the failableParallelForEach below.
523 // Note: The number of pass managers here needs to remain constant
524 // to prevent issues with pass instrumentations that rely on having the same
525 // pass manager for the main thread.
526 size_t numThreads = ctx->getNumThreads();
527 const auto &opPipelines = inliner.config.getOpPipelines();
528 if (pipelines.size() < numThreads) {
529 pipelines.reserve(numThreads);
530 pipelines.resize(numThreads, opPipelines);
531 }
532
533 // Ensure an analysis manager has been constructed for each of the nodes.
534 // This prevents thread races when running the nested pipelines.
535 for (CallGraphNode *node : nodesToVisit)
536 inliner.am.nest(node->getCallableRegion()->getParentOp());
537
538 // An atomic failure variable for the async executors.
539 std::vector<std::atomic<bool>> activePMs(pipelines.size());
540 llvm::fill(activePMs, false);
541 return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
542 // Find a pass manager for this operation.
543 auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
544 bool expectedInactive = false;
545 return isActive.compare_exchange_strong(expectedInactive, true);
546 });
547 assert(it != activePMs.end() &&
548 "could not find inactive pass manager for thread");
549 unsigned pmIndex = it - activePMs.begin();
550
551 // Optimize this callable node.
552 LogicalResult result = optimizeCallable(node, pipelines[pmIndex]);
553
554 // Reset the active bit for this pass manager.
555 activePMs[pmIndex].store(false);
556 return result;
557 });
558}
559
560LogicalResult
561Inliner::Impl::optimizeCallable(CallGraphNode *node,
562 llvm::StringMap<OpPassManager> &pipelines) {
563 Operation *callable = node->getCallableRegion()->getParentOp();
564 StringRef opName = callable->getName().getStringRef();
565 auto pipelineIt = pipelines.find(opName);
566 const auto &defaultPipeline = inliner.config.getDefaultPipeline();
567 if (pipelineIt == pipelines.end()) {
568 // If a pipeline didn't exist, use the generic pipeline if possible.
569 if (!defaultPipeline)
570 return success();
571
572 OpPassManager defaultPM(opName);
573 defaultPipeline(defaultPM);
574 pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
575 }
576 return inliner.runPipelineHelper(inliner.pass, pipelineIt->second, callable);
577}
578
579/// Attempt to inline calls within the given scc. This function returns
580/// success if any calls were inlined, failure otherwise.
581LogicalResult
582Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
583 CGUseList &useList, CallGraphSCC &currentSCC) {
584 CallGraph &cg = inlinerIface.cg;
585 auto &calls = inlinerIface.calls;
586
587 // A set of dead nodes to remove after inlining.
588 llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;
589
590 // Collect all of the direct calls within the nodes of the current SCC. We
591 // don't traverse nested callgraph nodes, because they are handled separately
592 // likely within a different SCC.
593 for (CallGraphNode *node : currentSCC) {
594 if (node->isExternal())
595 continue;
596
597 // Don't collect calls if the node is already dead.
598 if (useList.isDead(node)) {
599 deadNodes.insert(node);
600 } else {
601 collectCallOps(*node->getCallableRegion(), node, cg,
602 inlinerIface.symbolTable, calls,
603 /*traverseNestedCGNodes=*/false);
604 }
605 }
606
607 // When inlining a callee produces new call sites, we want to keep track of
608 // the fact that they were inlined from the callee. This allows us to avoid
609 // infinite inlining.
610 using InlineHistoryT = std::optional<size_t>;
611 SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
612 std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
613
614 LLVM_DEBUG({
615 LDBG() << "* Inliner: Initial calls in SCC are: {";
616 for (unsigned i = 0, e = calls.size(); i < e; ++i)
617 LDBG() << " " << i << ". " << calls[i].call << ",";
618 LDBG() << "}";
619 });
620
621 // Try to inline each of the call operations. Don't cache the end iterator
622 // here as more calls may be added during inlining.
623 bool inlinedAnyCalls = false;
624 for (unsigned i = 0; i < calls.size(); ++i) {
625 if (deadNodes.contains(calls[i].sourceNode))
626 continue;
627 ResolvedCall it = calls[i];
628
629 InlineHistoryT inlineHistoryID = callHistory[i];
630 bool inHistory =
631 inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
632 bool doInline = !inHistory && shouldInline(it);
633 CallOpInterface call = it.call;
634 LLVM_DEBUG({
635 if (doInline)
636 LDBG() << "* Inlining call: " << i << ". " << call;
637 else
638 LDBG() << "* Not inlining call: " << i << ". " << call;
639 });
640 if (!doInline)
641 continue;
642
643 unsigned prevSize = calls.size();
644 Region *targetRegion = it.targetNode->getCallableRegion();
645
646 // If this is the last call to the target node and the node is discardable,
647 // then inline it in-place and delete the node if successful.
648 bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);
649
650 LogicalResult inlineResult =
651 inlineCall(inlinerIface, inliner.config.getCloneCallback(), call,
652 cast<CallableOpInterface>(targetRegion->getParentOp()),
653 targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
654 if (failed(inlineResult)) {
655 LDBG() << "** Failed to inline";
656 continue;
657 }
658 inlinedAnyCalls = true;
659
660 // Create a inline history entry for this inlined call, so that we remember
661 // that new callsites came about due to inlining Callee.
662 InlineHistoryT newInlineHistoryID{inlineHistory.size()};
663 inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));
664
665 auto historyToString = [](InlineHistoryT h) {
666 return h.has_value() ? std::to_string(*h) : "root";
667 };
668 LDBG() << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
669 << getNodeName(call) << ", " << historyToString(inlineHistoryID)
670 << "]";
671
672 for (unsigned k = prevSize; k != calls.size(); ++k) {
673 callHistory.push_back(newInlineHistoryID);
674 LDBG() << "* new call " << k << " {" << calls[k].call
675 << "}\n with historyID = " << newInlineHistoryID
676 << ", added due to inlining of\n call {" << call
677 << "}\n with historyID = " << historyToString(inlineHistoryID);
678 }
679
680 // If the inlining was successful, Merge the new uses into the source node.
681 useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
682 useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
683
684 // then erase the call.
685 call.erase();
686
687 // If we inlined in place, mark the node for deletion.
688 if (inlineInPlace) {
689 useList.eraseNode(it.targetNode);
690 deadNodes.insert(it.targetNode);
691 }
692 }
693
694 for (CallGraphNode *node : deadNodes) {
695 currentSCC.remove(node);
696 inlinerIface.markForDeletion(node);
697 }
698 calls.clear();
699 return success(inlinedAnyCalls);
700}
701
702/// Returns true if the given call should be inlined.
703bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
704 // Don't allow inlining terminator calls. We currently don't support this
705 // case.
706 if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
707 return false;
708
709 // Don't allow inlining if the target is a self-recursive function.
710 // Don't allow inlining if the call graph is like A->B->A.
711 if (llvm::count_if(*resolvedCall.targetNode,
712 [&](CallGraphNode::Edge const &edge) -> bool {
713 return edge.getTarget() == resolvedCall.targetNode ||
714 edge.getTarget() == resolvedCall.sourceNode;
715 }) > 0)
716 return false;
717
718 // Don't allow inlining if the target is an ancestor of the call. This
719 // prevents inlining recursively.
720 Region *callableRegion = resolvedCall.targetNode->getCallableRegion();
721 if (callableRegion->isAncestor(resolvedCall.call->getParentRegion()))
722 return false;
723
724 // Don't allow inlining if the callee has multiple blocks (unstructured
725 // control flow) but we cannot be sure that the caller region supports that.
726 if (!inliner.config.getCanHandleMultipleBlocks()) {
727 bool calleeHasMultipleBlocks =
728 llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
729 // If both parent ops have the same type, it is safe to inline. Otherwise,
730 // decide based on whether the op has the SingleBlock trait or not.
731 // Note: This check does currently not account for
732 // SizedRegion/MaxSizedRegion.
733 auto callerRegionSupportsMultipleBlocks = [&]() {
734 return callableRegion->getParentOp()->getName() ==
735 resolvedCall.call->getParentOp()->getName() ||
736 !resolvedCall.call->getParentOp()
737 ->mightHaveTrait<OpTrait::SingleBlock>();
738 };
739 if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
740 return false;
741 }
742
743 if (!inliner.isProfitableToInline(resolvedCall))
744 return false;
745
746 // Otherwise, inline.
747 return true;
748}
749
750LogicalResult Inliner::doInlining() {
751 Impl impl(*this);
752 auto *context = op->getContext();
753 // Run the inline transform in post-order over the SCCs in the callgraph.
754 SymbolTableCollection symbolTable;
755 // FIXME: some clean-up can be done for the arguments
756 // of the Impl's methods, if the inlinerIface and useList
757 // become the states of the Impl.
758 InlinerInterfaceImpl inlinerIface(context, cg, symbolTable);
759 CGUseList useList(op, cg, symbolTable);
760 LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
761 return impl.inlineSCC(inlinerIface, useList, scc, context);
762 });
763 if (failed(result))
764 return result;
765
766 // After inlining, make sure to erase any callables proven to be dead.
767 inlinerIface.eraseDeadCallables();
768 return success();
769}
770} // namespace mlir
return success()
lhs
static void collectCallOps(iterator_range< Region::iterator > blocks, CallGraphNode *sourceNode, CallGraph &cg, SymbolTableCollection &symbolTable, SmallVectorImpl< ResolvedCall > &calls, bool traverseNestedCGNodes)
Collect all of the callable operations within the given range of blocks.
Definition Inliner.cpp:303
Inliner::ResolvedCall ResolvedCall
Definition Inliner.cpp:30
static void walkReferencedSymbolNodes(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable, DenseMap< Attribute, CallGraphNode * > &resolvedRefs, function_ref< void(CallGraphNode *, Operation *)> callback)
Walk all of the used symbol callgraph nodes referenced with the given op.
Definition Inliner.cpp:37
static std::string getNodeName(CallOpInterface op)
Definition Inliner.cpp:351
static bool inlineHistoryIncludes(CallGraphNode *node, std::optional< size_t > inlineHistoryID, MutableArrayRef< std::pair< CallGraphNode *, std::optional< size_t > > > inlineHistory)
Return true if the specified inlineHistoryID indicates an inline history that already includes node.
Definition Inliner.cpp:359
static LogicalResult runTransformOnCGSCCs(const CallGraph &cg, function_ref< LogicalResult(CallGraphSCC &)> sccTransformer)
Run a given transformation over the SCCs of the callgraph in a bottom up traversal.
Definition Inliner.cpp:284
Block represents an ordered list of Operations.
Definition Block.h:33
This class represents a single callable in the callgraph.
Definition CallGraph.h:40
bool isExternal() const
Returns true if this node is an external node.
Definition CallGraph.cpp:32
bool hasChildren() const
Returns true if this node has any child edges.
Definition CallGraph.cpp:59
Region * getCallableRegion() const
Returns the callable region this node represents.
Definition CallGraph.cpp:36
iterator begin() const
Definition CallGraph.h:111
CallGraphNode * resolveCallable(CallOpInterface call, SymbolTableCollection &symbolTable) const
Resolve the callable for given callee to a node in the callgraph, or the external node if a valid nod...
CallGraphNode * lookupNode(Region *region) const
Lookup a call graph node for the given region, or nullptr if none is registered.
LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface, CGUseList &useList, CallGraphSCC &currentSCC, MLIRContext *context)
Attempt to inline calls within the given scc, and run simplifications, until a fixed point is reached...
Definition Inliner.cpp:466
Impl(Inliner &inliner)
Definition Inliner.cpp:424
Inliner(Operation *op, CallGraph &cg, Pass &pass, AnalysisManager am, RunPipelineHelperTy runPipelineHelper, const InlinerConfig &config, ProfitabilityCallbackTy isProfitableToInline)
Definition Inliner.h:127
LogicalResult doInlining()
Perform inlining on a OpTrait::SymbolTable operation.
Definition Inliner.cpp:750
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
unsigned getNumThreads()
Return the number of threads used by the thread pool in this context.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:852
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition Operation.h:849
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
void erase()
Remove this operation from its parent block and delete it.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition Region.cpp:45
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:222
iterator_range< OpIterator > getOps()
Definition Region.h:172
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class represents a specific symbol use.
static void walkSymbolTables(Operation *op, bool allSymUsesVisible, function_ref< void(Operation *, bool)> callback)
Walks all symbol table operations nested within, and including, op.
static std::optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin, IteratorT end, FuncT &&func)
Invoke the given function on the elements between [begin, end) asynchronously.
Definition Threading.h:36
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
static std::string debugString(T &&op)
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
LogicalResult inlineCall(InlinerInterface &interface, function_ref< InlinerInterface::CloneCallbackSigTy > cloneCallback, CallOpInterface call, CallableOpInterface callable, Region *src, bool shouldCloneInlinedRegion=true)
This function inlines a given region, 'src', of a callable operation, 'callable', into the location d...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation.
This struct represents a resolved call to a given callgraph node.
Definition Inliner.h:109
CallGraphNode * sourceNode
Definition Inliner.h:114
CallOpInterface call
Definition Inliner.h:113
CallGraphNode * targetNode
Definition Inliner.h:114