22 #include "llvm/ADT/SCCIterator.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/DebugLog.h"
26 #define DEBUG_TYPE "inlining"
42 assert(symbolUses &&
"expected uses to be valid");
46 auto refIt = resolvedRefs.try_emplace(use.getSymbolRef());
54 auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
57 node = cg.
lookupNode(callableOp.getCallableRegion());
60 callback(node, use.getUser());
/// Decrement the use counts tracked in `discardableSymNodeUses` for the
/// callgraph nodes recorded in `uses`; each inner use is subtracted by its
/// recorded count.
111 void decrementDiscardableUses(CGUser &uses);
128 : symbolTable(symbolTable) {
133 auto walkFn = [&](
Operation *symbolTableOp,
bool allUsesVisible) {
136 if (
auto callable = dyn_cast<CallableOpInterface>(&op)) {
137 if (
auto *node = cg.
lookupNode(callable.getCallableRegion())) {
138 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
139 if (symbol && (allUsesVisible || symbol.isPrivate()) &&
140 symbol.canDiscardOnUseEmpty()) {
141 discardableSymNodeUses.try_emplace(node, 0);
151 SymbolTable::walkSymbolTables(op, !op->
getBlock(),
155 for (
auto &it : alwaysLiveNodes)
156 discardableSymNodeUses.erase(it.second);
160 recomputeUses(node, cg);
165 auto &userRefs = nodeUses[userNode].innerUses;
167 auto parentIt = userRefs.find(node);
168 if (parentIt == userRefs.end())
171 --discardableSymNodeUses[node];
179 for (
auto &edge : *node)
181 eraseNode(edge.getTarget());
184 auto useIt = nodeUses.find(node);
185 assert(useIt != nodeUses.end() &&
"expected node to be valid");
186 decrementDiscardableUses(useIt->getSecond());
187 nodeUses.erase(useIt);
188 discardableSymNodeUses.erase(node);
194 if (!isa<SymbolOpInterface>(nodeOp))
198 auto symbolIt = discardableSymNodeUses.find(node);
199 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
202 bool CGUseList::hasOneUseAndDiscardable(
CallGraphNode *node)
const {
205 if (!isa<SymbolOpInterface>(nodeOp))
209 auto symbolIt = discardableSymNodeUses.find(node);
210 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
215 CGUser &uses = nodeUses[node];
216 decrementDiscardableUses(uses);
222 auto discardSymIt = discardableSymNodeUses.find(refNode);
223 if (discardSymIt == discardableSymNodeUses.end())
226 if (user != parentOp)
227 ++uses.innerUses[refNode];
228 else if (!uses.topLevelUses.insert(refNode).second)
230 ++discardSymIt->second;
236 auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
237 for (
auto &useIt : lhsUses.innerUses) {
238 rhsUses.innerUses[useIt.first] += useIt.second;
239 discardableSymNodeUses[useIt.first] += useIt.second;
243 void CGUseList::decrementDiscardableUses(CGUser &uses) {
245 --discardableSymNodeUses[node];
246 for (
auto &it : uses.innerUses)
247 discardableSymNodeUses[it.first] -= it.second;
/// Create an SCC wrapper that keeps a reference to `parentIterator`. Only the
/// reference is stored, so the iterator must outlive this object.
258 CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
259 : parentIterator(parentIterator) {}
/// Return an iterator to the first node held in `nodes`.
261 std::vector<CallGraphNode *>::iterator begin() {
return nodes.begin(); }
/// Return the past-the-end iterator of `nodes`.
262 std::vector<CallGraphNode *>::iterator end() {
return nodes.end(); }
/// Replace the nodes held by this SCC with a copy of `newNodes`.
265 void reset(
const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
269 auto it = llvm::find(nodes, node);
270 if (it != nodes.end()) {
272 parentIterator.ReplaceNode(node,
nullptr);
/// The callgraph nodes currently held by this SCC.
277 std::vector<CallGraphNode *> nodes;
/// The SCC iterator supplied at construction, held by reference.
278 llvm::scc_iterator<const CallGraph *> &parentIterator;
286 function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
287 llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
288 CallGraphSCC currentSCC(cgi);
289 while (!cgi.isAtEnd()) {
292 currentSCC.reset(*cgi);
294 if (
failed(sccTransformer(currentSCC)))
307 bool traverseNestedCGNodes) {
311 for (
Block &block : blocks)
312 worklist.emplace_back(&block, node);
315 addToWorklist(sourceNode, blocks);
316 while (!worklist.empty()) {
318 std::tie(block, sourceNode) = worklist.pop_back_val();
321 if (
auto call = dyn_cast<CallOpInterface>(op)) {
324 if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) {
325 if (!isa<FlatSymbolRefAttr>(symRef))
331 calls.emplace_back(call, sourceNode, targetNode);
340 if (traverseNestedCGNodes || !nestedNode)
341 addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
352 if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
354 return "_unnamed_callee_";
363 while (inlineHistoryID.has_value()) {
364 assert(*inlineHistoryID < inlineHistory.size() &&
365 "Invalid inline history ID");
366 if (inlineHistory[*inlineHistoryID].first == node)
368 inlineHistoryID = inlineHistory[*inlineHistoryID].second;
386 Region *region = inlinedBlocks.
begin()->getParent();
389 assert(region &&
"expected valid parent node");
/// Record `node` in the `deadNodes` set.
// NOTE(review): presumably consumed later by eraseDeadCallables(); its body
// is not visible here -- confirm.
397 void markForDeletion(
CallGraphNode *node) { deadNodes.insert(node); }
401 void eraseDeadCallables() {
430 LogicalResult
inlineSCC(InlinerInterfaceImpl &inlinerIface,
431 CGUseList &useList, CallGraphSCC &currentSCC,
438 LogicalResult optimizeSCC(
CallGraph &cg, CGUseList &useList,
451 llvm::StringMap<OpPassManager> &pipelines);
455 LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
456 CGUseList &useList, CallGraphSCC &currentSCC);
468 CallGraphSCC &currentSCC,
473 unsigned iterationCount = 0;
475 if (
failed(optimizeSCC(inlinerIface.cg, useList, currentSCC, context)))
477 if (
failed(inlineCallsInSCC(inlinerIface, useList, currentSCC)))
483 LogicalResult Inliner::Impl::optimizeSCC(
CallGraph &cg, CGUseList &useList,
484 CallGraphSCC &currentSCC,
488 for (
auto *node : currentSCC) {
503 nodesToVisit.push_back(node);
505 if (nodesToVisit.empty())
509 if (
failed(optimizeSCCAsync(nodesToVisit, context)))
514 useList.recomputeUses(node, cg);
527 const auto &opPipelines = inliner.config.getOpPipelines();
528 if (pipelines.size() < numThreads) {
529 pipelines.reserve(numThreads);
530 pipelines.resize(numThreads, opPipelines);
539 std::vector<std::atomic<bool>> activePMs(pipelines.size());
540 llvm::fill(activePMs,
false);
543 auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
544 bool expectedInactive =
false;
545 return isActive.compare_exchange_strong(expectedInactive,
true);
547 assert(it != activePMs.end() &&
548 "could not find inactive pass manager for thread");
549 unsigned pmIndex = it - activePMs.begin();
552 LogicalResult result = optimizeCallable(node, pipelines[pmIndex]);
555 activePMs[pmIndex].store(
false);
562 llvm::StringMap<OpPassManager> &pipelines) {
565 auto pipelineIt = pipelines.find(opName);
566 const auto &defaultPipeline = inliner.config.getDefaultPipeline();
567 if (pipelineIt == pipelines.end()) {
569 if (!defaultPipeline)
573 defaultPipeline(defaultPM);
574 pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
576 return inliner.runPipelineHelper(inliner.pass, pipelineIt->second, callable);
582 Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
583 CGUseList &useList, CallGraphSCC &currentSCC) {
585 auto &calls = inlinerIface.calls;
588 llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;
598 if (useList.isDead(node)) {
599 deadNodes.insert(node);
602 inlinerIface.symbolTable, calls,
610 using InlineHistoryT = std::optional<size_t>;
612 std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
615 LDBG() <<
"* Inliner: Initial calls in SCC are: {";
616 for (
unsigned i = 0, e = calls.size(); i < e; ++i)
617 LDBG() <<
" " << i <<
". " << calls[i].call <<
",";
623 bool inlinedAnyCalls =
false;
624 for (
unsigned i = 0; i < calls.size(); ++i) {
625 if (deadNodes.contains(calls[i].sourceNode))
629 InlineHistoryT inlineHistoryID = callHistory[i];
632 bool doInline = !inHistory && shouldInline(it);
633 CallOpInterface call = it.
call;
636 LDBG() <<
"* Inlining call: " << i <<
". " << call;
638 LDBG() <<
"* Not inlining call: " << i <<
". " << call;
643 unsigned prevSize = calls.size();
648 bool inlineInPlace = useList.hasOneUseAndDiscardable(it.
targetNode);
650 LogicalResult inlineResult =
651 inlineCall(inlinerIface, inliner.config.getCloneCallback(), call,
652 cast<CallableOpInterface>(targetRegion->
getParentOp()),
653 targetRegion, !inlineInPlace);
654 if (
failed(inlineResult)) {
655 LDBG() <<
"** Failed to inline";
658 inlinedAnyCalls =
true;
662 InlineHistoryT newInlineHistoryID{inlineHistory.size()};
663 inlineHistory.push_back(std::make_pair(it.
targetNode, inlineHistoryID));
665 auto historyToString = [](InlineHistoryT h) {
666 return h.has_value() ? std::to_string(*h) :
"root";
668 LDBG() <<
"* new inlineHistory entry: " << newInlineHistoryID <<
". ["
669 <<
getNodeName(call) <<
", " << historyToString(inlineHistoryID)
672 for (
unsigned k = prevSize; k != calls.size(); ++k) {
673 callHistory.push_back(newInlineHistoryID);
674 LDBG() <<
"* new call " << k <<
" {" << calls[k].call
675 <<
"}\n with historyID = " << newInlineHistoryID
676 <<
", added due to inlining of\n call {" << call
677 <<
"}\n with historyID = " << historyToString(inlineHistoryID);
681 useList.dropCallUses(it.
sourceNode, call.getOperation(), cg);
695 currentSCC.remove(node);
696 inlinerIface.markForDeletion(node);
699 return success(inlinedAnyCalls);
703 bool Inliner::Impl::shouldInline(
ResolvedCall &resolvedCall) {
713 return edge.getTarget() == resolvedCall.targetNode ||
714 edge.getTarget() == resolvedCall.sourceNode;
721 if (callableRegion->
isAncestor(resolvedCall.
call->getParentRegion()))
726 if (!inliner.config.getCanHandleMultipleBlocks()) {
727 bool calleeHasMultipleBlocks =
728 llvm::hasNItemsOrMore(*callableRegion, 2);
733 auto callerRegionSupportsMultipleBlocks = [&]() {
735 resolvedCall.
call->getParentOp()->getName() ||
736 !resolvedCall.
call->getParentOp()
739 if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
743 if (!inliner.isProfitableToInline(resolvedCall))
758 InlinerInterfaceImpl inlinerIface(context, cg, symbolTable);
759 CGUseList useList(op, cg, symbolTable);
761 return impl.inlineSCC(inlinerIface, useList, scc, context);
767 inlinerIface.eraseDeadCallables();
static void collectCallOps(iterator_range< Region::iterator > blocks, CallGraphNode *sourceNode, CallGraph &cg, SymbolTableCollection &symbolTable, SmallVectorImpl< ResolvedCall > &calls, bool traverseNestedCGNodes)
Collect all of the callable operations within the given range of blocks.
Inliner::ResolvedCall ResolvedCall
static void walkReferencedSymbolNodes(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable, DenseMap< Attribute, CallGraphNode * > &resolvedRefs, function_ref< void(CallGraphNode *, Operation *)> callback)
Walk all of the used symbol callgraph nodes referenced with the given op.
static bool inlineHistoryIncludes(CallGraphNode *node, std::optional< size_t > inlineHistoryID, MutableArrayRef< std::pair< CallGraphNode *, std::optional< size_t >>> inlineHistory)
Return true if the specified inlineHistoryID indicates an inline history that already includes node.
static std::string getNodeName(CallOpInterface op)
static LogicalResult runTransformOnCGSCCs(const CallGraph &cg, function_ref< LogicalResult(CallGraphSCC &)> sccTransformer)
Run a given transformation over the SCCs of the callgraph in a bottom up traversal.
Block represents an ordered list of Operations.
This class represents a directed edge between two nodes in the callgraph.
This class represents a single callable in the callgraph.
bool isExternal() const
Returns true if this node is an external node.
bool hasChildren() const
Returns true if this node has any child edges.
Region * getCallableRegion() const
Returns the callable region this node represents.
CallGraphNode * resolveCallable(CallOpInterface call, SymbolTableCollection &symbolTable) const
Resolve the callable for given callee to a node in the callgraph, or the external node if a valid nod...
CallGraphNode * lookupNode(Region *region) const
Lookup a call graph node for the given region, or nullptr if none is registered.
unsigned getMaxInliningIterations() const
This interface provides the hooks into the inlining interface.
virtual void processInlinedBlocks(iterator_range< Region::iterator > inlinedBlocks)
Process a set of blocks that have been inlined.
LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface, CGUseList &useList, CallGraphSCC &currentSCC, MLIRContext *context)
Attempt to inline calls within the given scc, and run simplifications, until a fixed point is reached...
This is an implementation of the inliner that operates bottom up over the Strongly Connected Componen...
LogicalResult doInlining()
Perform inlining on a OpTrait::SymbolTable operation.
MLIRContext is the top-level object for a collection of MLIR operations.
unsigned getNumThreads()
Return the number of threads used by the thread pool in this context.
This class represents a pass manager that runs passes on either a specific operation type,...
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool hasOneUse()
Returns true if this operation has exactly one use.
MLIRContext * getContext()
Return the context this operation is associated with.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class represents a specific symbol use.
static std::optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
Include the generated interface declarations.
LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin, IteratorT end, FuncT &&func)
Invoke the given function on the elements between [begin, end) asynchronously.
static std::string debugString(T &&op)
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
LogicalResult inlineCall(InlinerInterface &interface, function_ref< InlinerInterface::CloneCallbackSigTy > cloneCallback, CallOpInterface call, CallableOpInterface callable, Region *src, bool shouldCloneInlinedRegion=true)
This function inlines a given region, 'src', of a callable operation, 'callable', into the location d...
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation.
This struct represents a resolved call to a given callgraph node.
CallGraphNode * sourceNode
CallGraphNode * targetNode
This class provides APIs and verifiers for ops with regions having a single block.