22#include "llvm/ADT/SCCIterator.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/DebugLog.h"
26#define DEBUG_TYPE "inlining"
42 assert(symbolUses &&
"expected uses to be valid");
46 auto refIt = resolvedRefs.try_emplace(use.getSymbolRef());
54 auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
57 node = cg.
lookupNode(callableOp.getCallableRegion());
60 callback(node, use.getUser());
86 CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
90 void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);
93 void eraseNode(CallGraphNode *node);
96 bool isDead(CallGraphNode *node)
const;
100 bool hasOneUseAndDiscardable(CallGraphNode *node)
const;
103 void recomputeUses(CallGraphNode *node, CallGraph &cg);
107 void mergeUsesAfterInlining(CallGraphNode *
lhs, CallGraphNode *
rhs);
111 void decrementDiscardableUses(CGUser &uses);
122 SymbolTableCollection &symbolTable;
128 : symbolTable(symbolTable) {
133 auto walkFn = [&](
Operation *symbolTableOp,
bool allUsesVisible) {
136 if (
auto callable = dyn_cast<CallableOpInterface>(&op)) {
137 if (
auto *node = cg.
lookupNode(callable.getCallableRegion())) {
138 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
139 if (symbol && (allUsesVisible || symbol.isPrivate()) &&
140 symbol.canDiscardOnUseEmpty()) {
141 discardableSymNodeUses.try_emplace(node, 0);
155 for (
auto &it : alwaysLiveNodes)
156 discardableSymNodeUses.erase(it.second);
160 recomputeUses(node, cg);
165 auto &userRefs = nodeUses[userNode].innerUses;
166 auto walkFn = [&](CallGraphNode *node, Operation *user) {
167 auto parentIt = userRefs.find(node);
168 if (parentIt == userRefs.end())
171 --discardableSymNodeUses[node];
177void CGUseList::eraseNode(CallGraphNode *node) {
179 for (
auto &edge : *node)
181 eraseNode(edge.getTarget());
184 auto useIt = nodeUses.find(node);
185 assert(useIt != nodeUses.end() &&
"expected node to be valid");
186 decrementDiscardableUses(useIt->getSecond());
187 nodeUses.erase(useIt);
188 discardableSymNodeUses.erase(node);
191bool CGUseList::isDead(CallGraphNode *node)
const {
194 if (!isa<SymbolOpInterface>(nodeOp))
198 auto symbolIt = discardableSymNodeUses.find(node);
199 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
202bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node)
const {
205 if (!isa<SymbolOpInterface>(nodeOp))
209 auto symbolIt = discardableSymNodeUses.find(node);
210 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
213void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
215 CGUser &uses = nodeUses[node];
216 decrementDiscardableUses(uses);
221 auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
222 auto discardSymIt = discardableSymNodeUses.find(refNode);
223 if (discardSymIt == discardableSymNodeUses.end())
226 if (user != parentOp)
227 ++uses.innerUses[refNode];
228 else if (!uses.topLevelUses.insert(refNode).second)
230 ++discardSymIt->second;
235void CGUseList::mergeUsesAfterInlining(CallGraphNode *
lhs, CallGraphNode *
rhs) {
236 auto &lhsUses = nodeUses[
lhs], &rhsUses = nodeUses[
rhs];
237 for (
auto &useIt : lhsUses.innerUses) {
238 rhsUses.innerUses[useIt.first] += useIt.second;
239 discardableSymNodeUses[useIt.first] += useIt.second;
243void CGUseList::decrementDiscardableUses(CGUser &uses) {
244 for (CallGraphNode *node : uses.topLevelUses)
245 --discardableSymNodeUses[node];
246 for (
auto &it : uses.innerUses)
247 discardableSymNodeUses[it.first] -= it.second;
258 CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
259 : parentIterator(parentIterator) {}
261 std::vector<CallGraphNode *>::iterator begin() {
return nodes.begin(); }
262 std::vector<CallGraphNode *>::iterator end() {
return nodes.end(); }
265 void reset(
const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
268 void remove(CallGraphNode *node) {
269 auto it = llvm::find(nodes, node);
270 if (it != nodes.end()) {
272 parentIterator.ReplaceNode(node,
nullptr);
277 std::vector<CallGraphNode *> nodes;
278 llvm::scc_iterator<const CallGraph *> &parentIterator;
286 function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
287 llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
288 CallGraphSCC currentSCC(cgi);
289 while (!cgi.isAtEnd()) {
292 currentSCC.reset(*cgi);
294 if (failed(sccTransformer(currentSCC)))
307 bool traverseNestedCGNodes) {
311 for (
Block &block : blocks)
312 worklist.emplace_back(&block, node);
315 addToWorklist(sourceNode, blocks);
316 while (!worklist.empty()) {
318 std::tie(block, sourceNode) = worklist.pop_back_val();
321 if (
auto call = dyn_cast<CallOpInterface>(op)) {
324 if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) {
325 if (!isa<FlatSymbolRefAttr>(symRef))
331 calls.emplace_back(call, sourceNode, targetNode);
340 if (traverseNestedCGNodes || !nestedNode)
341 addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
352 if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
354 return "_unnamed_callee_";
363 while (inlineHistoryID.has_value()) {
364 assert(*inlineHistoryID < inlineHistory.size() &&
365 "Invalid inline history ID");
366 if (inlineHistory[*inlineHistoryID].first == node)
368 inlineHistoryID = inlineHistory[*inlineHistoryID].second;
375struct InlinerInterfaceImpl :
public InlinerInterface {
376 InlinerInterfaceImpl(MLIRContext *context, CallGraph &cg,
377 SymbolTableCollection &symbolTable)
378 : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
383 processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks)
final {
386 Region *region = inlinedBlocks.
begin()->getParent();
389 assert(region &&
"expected valid parent node");
397 void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
401 void eraseDeadCallables() {
402 for (CallGraphNode *node : deadNodes)
407 SmallPtrSet<CallGraphNode *, 8> deadNodes;
410 SmallVector<ResolvedCall, 8> calls;
416 SymbolTableCollection &symbolTable;
430 LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface,
431 CGUseList &useList, CallGraphSCC ¤tSCC,
438 LogicalResult optimizeSCC(
CallGraph &cg, CGUseList &useList,
451 llvm::StringMap<OpPassManager> &pipelines);
455 LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
456 CGUseList &useList, CallGraphSCC ¤tSCC);
468 CallGraphSCC ¤tSCC,
473 unsigned iterationCount = 0;
475 if (failed(optimizeSCC(inlinerIface.cg, useList, currentSCC, context)))
477 if (failed(inlineCallsInSCC(inlinerIface, useList, currentSCC)))
479 }
while (++iterationCount < inliner.config.getMaxInliningIterations());
483LogicalResult Inliner::Impl::optimizeSCC(
CallGraph &cg, CGUseList &useList,
484 CallGraphSCC ¤tSCC,
488 for (
auto *node : currentSCC) {
503 nodesToVisit.push_back(node);
505 if (nodesToVisit.empty())
509 if (failed(optimizeSCCAsync(nodesToVisit, context)))
514 useList.recomputeUses(node, cg);
527 const auto &opPipelines = inliner.config.getOpPipelines();
528 if (pipelines.size() < numThreads) {
529 pipelines.reserve(numThreads);
530 pipelines.resize(numThreads, opPipelines);
535 for (CallGraphNode *node : nodesToVisit)
539 std::vector<std::atomic<bool>> activePMs(pipelines.size());
540 llvm::fill(activePMs,
false);
543 auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
544 bool expectedInactive =
false;
545 return isActive.compare_exchange_strong(expectedInactive,
true);
547 assert(it != activePMs.end() &&
548 "could not find inactive pass manager for thread");
549 unsigned pmIndex = it - activePMs.begin();
552 LogicalResult
result = optimizeCallable(node, pipelines[pmIndex]);
555 activePMs[pmIndex].store(
false);
561Inliner::Impl::optimizeCallable(CallGraphNode *node,
562 llvm::StringMap<OpPassManager> &pipelines) {
565 auto pipelineIt = pipelines.find(opName);
566 const auto &defaultPipeline = inliner.config.getDefaultPipeline();
567 if (pipelineIt == pipelines.end()) {
569 if (!defaultPipeline)
572 OpPassManager defaultPM(opName);
573 defaultPipeline(defaultPM);
574 pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
576 return inliner.runPipelineHelper(inliner.pass, pipelineIt->second, callable);
582Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
583 CGUseList &useList, CallGraphSCC ¤tSCC) {
584 CallGraph &cg = inlinerIface.cg;
585 auto &calls = inlinerIface.calls;
588 llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;
593 for (CallGraphNode *node : currentSCC) {
598 if (useList.isDead(node)) {
599 deadNodes.insert(node);
602 inlinerIface.symbolTable, calls,
610 using InlineHistoryT = std::optional<size_t>;
611 SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
612 std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
615 LDBG() <<
"* Inliner: Initial calls in SCC are: {";
616 for (
unsigned i = 0, e = calls.size(); i < e; ++i)
617 LDBG() <<
" " << i <<
". " << calls[i].call <<
",";
623 bool inlinedAnyCalls =
false;
624 for (
unsigned i = 0; i < calls.size(); ++i) {
625 if (deadNodes.contains(calls[i].sourceNode))
629 InlineHistoryT inlineHistoryID = callHistory[i];
632 bool doInline = !inHistory && shouldInline(it);
633 CallOpInterface call = it.
call;
636 LDBG() <<
"* Inlining call: " << i <<
". " << call;
638 LDBG() <<
"* Not inlining call: " << i <<
". " << call;
643 unsigned prevSize = calls.size();
648 bool inlineInPlace = useList.hasOneUseAndDiscardable(it.
targetNode);
650 LogicalResult inlineResult =
651 inlineCall(inlinerIface, inliner.config.getCloneCallback(), call,
652 cast<CallableOpInterface>(targetRegion->
getParentOp()),
653 targetRegion, !inlineInPlace);
654 if (
failed(inlineResult)) {
655 LDBG() <<
"** Failed to inline";
658 inlinedAnyCalls =
true;
662 InlineHistoryT newInlineHistoryID{inlineHistory.size()};
663 inlineHistory.push_back(std::make_pair(it.
targetNode, inlineHistoryID));
665 auto historyToString = [](InlineHistoryT h) {
666 return h.has_value() ? std::to_string(*h) :
"root";
668 LDBG() <<
"* new inlineHistory entry: " << newInlineHistoryID <<
". ["
669 <<
getNodeName(call) <<
", " << historyToString(inlineHistoryID)
672 for (
unsigned k = prevSize; k != calls.size(); ++k) {
673 callHistory.push_back(newInlineHistoryID);
674 LDBG() <<
"* new call " << k <<
" {" << calls[k].call
675 <<
"}\n with historyID = " << newInlineHistoryID
676 <<
", added due to inlining of\n call {" << call
677 <<
"}\n with historyID = " << historyToString(inlineHistoryID);
681 useList.dropCallUses(it.
sourceNode, call.getOperation(), cg);
694 for (CallGraphNode *node : deadNodes) {
695 currentSCC.remove(node);
696 inlinerIface.markForDeletion(node);
699 return success(inlinedAnyCalls);
703bool Inliner::Impl::shouldInline(
ResolvedCall &resolvedCall) {
706 if (resolvedCall.
call->hasTrait<OpTrait::IsTerminator>())
712 [&](CallGraphNode::Edge
const &edge) ->
bool {
713 return edge.getTarget() == resolvedCall.targetNode ||
714 edge.getTarget() == resolvedCall.sourceNode;
721 if (callableRegion->
isAncestor(resolvedCall.
call->getParentRegion()))
726 if (!inliner.config.getCanHandleMultipleBlocks()) {
727 bool calleeHasMultipleBlocks =
728 llvm::hasNItemsOrMore(*callableRegion, 2);
733 auto callerRegionSupportsMultipleBlocks = [&]() {
735 resolvedCall.
call->getParentOp()->getName() ||
736 !resolvedCall.
call->getParentOp()
737 ->mightHaveTrait<OpTrait::SingleBlock>();
739 if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
743 if (!inliner.isProfitableToInline(resolvedCall))
752 auto *context = op->getContext();
758 InlinerInterfaceImpl inlinerIface(context, cg, symbolTable);
759 CGUseList useList(op, cg, symbolTable);
761 return impl.inlineSCC(inlinerIface, useList, scc, context);
767 inlinerIface.eraseDeadCallables();
static void collectCallOps(iterator_range< Region::iterator > blocks, CallGraphNode *sourceNode, CallGraph &cg, SymbolTableCollection &symbolTable, SmallVectorImpl< ResolvedCall > &calls, bool traverseNestedCGNodes)
Collect all of the callable operations within the given range of blocks.
Inliner::ResolvedCall ResolvedCall
static void walkReferencedSymbolNodes(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable, DenseMap< Attribute, CallGraphNode * > &resolvedRefs, function_ref< void(CallGraphNode *, Operation *)> callback)
Walk all of the used symbol callgraph nodes referenced with the given op.
static std::string getNodeName(CallOpInterface op)
static bool inlineHistoryIncludes(CallGraphNode *node, std::optional< size_t > inlineHistoryID, MutableArrayRef< std::pair< CallGraphNode *, std::optional< size_t > > > inlineHistory)
Return true if the specified inlineHistoryID indicates an inline history that already includes node.
static LogicalResult runTransformOnCGSCCs(const CallGraph &cg, function_ref< LogicalResult(CallGraphSCC &)> sccTransformer)
Run a given transformation over the SCCs of the callgraph in a bottom up traversal.
Block represents an ordered list of Operations.
This class represents a single callable in the callgraph.
bool isExternal() const
Returns true if this node is an external node.
bool hasChildren() const
Returns true if this node has any child edges.
Region * getCallableRegion() const
Returns the callable region this node represents.
CallGraphNode * resolveCallable(CallOpInterface call, SymbolTableCollection &symbolTable) const
Resolve the callable for given callee to a node in the callgraph, or the external node if a valid nod...
CallGraphNode * lookupNode(Region *region) const
Lookup a call graph node for the given region, or nullptr if none is registered.
LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface, CGUseList &useList, CallGraphSCC ¤tSCC, MLIRContext *context)
Attempt to inline calls within the given scc, and run simplifications, until a fixed point is reached...
Inliner(Operation *op, CallGraph &cg, Pass &pass, AnalysisManager am, RunPipelineHelperTy runPipelineHelper, const InlinerConfig &config, ProfitabilityCallbackTy isProfitableToInline)
LogicalResult doInlining()
Perform inlining on a OpTrait::SymbolTable operation.
MLIRContext is the top-level object for a collection of MLIR operations.
unsigned getNumThreads()
Return the number of threads used by the thread pool in this context.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool hasOneUse()
Returns true if this operation has exactly one use.
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
void erase()
Remove this operation from its parent block and delete it.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
iterator_range< OpIterator > getOps()
Operation * getParentOp()
Return the parent operation this region is attached to.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class represents a specific symbol use.
static void walkSymbolTables(Operation *op, bool allSymUsesVisible, function_ref< void(Operation *, bool)> callback)
Walks all symbol table operations nested within, and including, op.
static std::optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
Include the generated interface declarations.
LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin, IteratorT end, FuncT &&func)
Invoke the given function on the elements between [begin, end) asynchronously.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
static std::string debugString(T &&op)
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
LogicalResult inlineCall(InlinerInterface &interface, function_ref< InlinerInterface::CloneCallbackSigTy > cloneCallback, CallOpInterface call, CallableOpInterface callable, Region *src, bool shouldCloneInlinedRegion=true)
This function inlines a given region, 'src', of a callable operation, 'callable', into the location d...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation.
This struct represents a resolved call to a given callgraph node.
CallGraphNode * sourceNode
CallGraphNode * targetNode