23 #include "llvm/ADT/SCCIterator.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 #include "llvm/Support/Debug.h"
// Debug tag consumed by the LLVM_DEBUG(...) statements in this file.
28 #define DEBUG_TYPE "inlining"
44 assert(symbolUses &&
"expected uses to be valid");
48 auto refIt = resolvedRefs.insert({use.getSymbolRef(),
nullptr});
56 auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
59 node = cg.
lookupNode(callableOp.getCallableRegion());
62 callback(node, use.getUser());
/// Decrement the `discardableSymNodeUses` entries for the callgraph nodes
/// referenced by `uses`; inner uses are subtracted by their recorded counts.
113 void decrementDiscardableUses(CGUser &uses);
130 : symbolTable(symbolTable) {
135 auto walkFn = [&](
Operation *symbolTableOp,
bool allUsesVisible) {
138 if (
auto callable = dyn_cast<CallableOpInterface>(&op)) {
139 if (
auto *node = cg.
lookupNode(callable.getCallableRegion())) {
140 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
141 if (symbol && (allUsesVisible || symbol.isPrivate()) &&
142 symbol.canDiscardOnUseEmpty()) {
143 discardableSymNodeUses.try_emplace(node, 0);
153 SymbolTable::walkSymbolTables(op, !op->
getBlock(),
157 for (
auto &it : alwaysLiveNodes)
158 discardableSymNodeUses.erase(it.second);
162 recomputeUses(node, cg);
167 auto &userRefs = nodeUses[userNode].innerUses;
169 auto parentIt = userRefs.find(node);
170 if (parentIt == userRefs.end())
173 --discardableSymNodeUses[node];
181 for (
auto &edge : *node)
183 eraseNode(edge.getTarget());
186 auto useIt = nodeUses.find(node);
187 assert(useIt != nodeUses.end() &&
"expected node to be valid");
188 decrementDiscardableUses(useIt->getSecond());
189 nodeUses.erase(useIt);
190 discardableSymNodeUses.erase(node);
196 if (!isa<SymbolOpInterface>(nodeOp))
200 auto symbolIt = discardableSymNodeUses.find(node);
201 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
204 bool CGUseList::hasOneUseAndDiscardable(
CallGraphNode *node)
const {
207 if (!isa<SymbolOpInterface>(nodeOp))
211 auto symbolIt = discardableSymNodeUses.find(node);
212 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
217 CGUser &uses = nodeUses[node];
218 decrementDiscardableUses(uses);
224 auto discardSymIt = discardableSymNodeUses.find(refNode);
225 if (discardSymIt == discardableSymNodeUses.end())
228 if (user != parentOp)
229 ++uses.innerUses[refNode];
230 else if (!uses.topLevelUses.insert(refNode).second)
232 ++discardSymIt->second;
238 auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
239 for (
auto &useIt : lhsUses.innerUses) {
240 rhsUses.innerUses[useIt.first] += useIt.second;
241 discardableSymNodeUses[useIt.first] += useIt.second;
245 void CGUseList::decrementDiscardableUses(CGUser &uses) {
247 --discardableSymNodeUses[node];
248 for (
auto &it : uses.innerUses)
249 discardableSymNodeUses[it.first] -= it.second;
260 CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
261 : parentIterator(parentIterator) {}
263 std::vector<CallGraphNode *>::iterator begin() {
return nodes.begin(); }
264 std::vector<CallGraphNode *>::iterator end() {
return nodes.end(); }
267 void reset(
const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
271 auto it = llvm::find(nodes, node);
272 if (it != nodes.end()) {
274 parentIterator.ReplaceNode(node,
nullptr);
/// The nodes of the SCC currently being processed. Overwritten by reset() and
/// looked up with llvm::find when a node is removed from the SCC.
279 std::vector<CallGraphNode *> nodes;
/// The scc_iterator driving the traversal. Held by reference, so it must
/// outlive this object; node removal calls `ReplaceNode(node, nullptr)` on it.
280 llvm::scc_iterator<const CallGraph *> &parentIterator;
288 function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
289 llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
290 CallGraphSCC currentSCC(cgi);
291 while (!cgi.isAtEnd()) {
294 currentSCC.reset(*cgi);
296 if (failed(sccTransformer(currentSCC)))
309 bool traverseNestedCGNodes) {
313 for (
Block &block : blocks)
314 worklist.emplace_back(&block, node);
317 addToWorklist(sourceNode, blocks);
318 while (!worklist.empty()) {
320 std::tie(block, sourceNode) = worklist.pop_back_val();
323 if (
auto call = dyn_cast<CallOpInterface>(op)) {
326 if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) {
327 if (!isa<FlatSymbolRefAttr>(symRef))
333 calls.emplace_back(call, sourceNode, targetNode);
342 if (traverseNestedCGNodes || !nestedNode)
343 addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
355 if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
357 return "_unnamed_callee_";
367 while (inlineHistoryID.has_value()) {
368 assert(*inlineHistoryID < inlineHistory.size() &&
369 "Invalid inline history ID");
370 if (inlineHistory[*inlineHistoryID].first == node)
372 inlineHistoryID = inlineHistory[*inlineHistoryID].second;
390 Region *region = inlinedBlocks.
begin()->getParent();
393 assert(region &&
"expected valid parent node");
401 void markForDeletion(
CallGraphNode *node) { deadNodes.insert(node); }
405 void eraseDeadCallables() {
434 LogicalResult
inlineSCC(InlinerInterfaceImpl &inlinerIface,
435 CGUseList &useList, CallGraphSCC ¤tSCC,
442 LogicalResult optimizeSCC(
CallGraph &cg, CGUseList &useList,
455 llvm::StringMap<OpPassManager> &pipelines);
459 LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
460 CGUseList &useList, CallGraphSCC ¤tSCC);
472 CallGraphSCC ¤tSCC,
477 unsigned iterationCount = 0;
479 if (failed(optimizeSCC(inlinerIface.cg, useList, currentSCC, context)))
481 if (failed(inlineCallsInSCC(inlinerIface, useList, currentSCC)))
487 LogicalResult Inliner::Impl::optimizeSCC(
CallGraph &cg, CGUseList &useList,
488 CallGraphSCC ¤tSCC,
492 for (
auto *node : currentSCC) {
507 nodesToVisit.push_back(node);
509 if (nodesToVisit.empty())
513 if (failed(optimizeSCCAsync(nodesToVisit, context)))
518 useList.recomputeUses(node, cg);
531 const auto &opPipelines = inliner.config.getOpPipelines();
532 if (pipelines.size() < numThreads) {
533 pipelines.reserve(numThreads);
534 pipelines.resize(numThreads, opPipelines);
543 std::vector<std::atomic<bool>> activePMs(pipelines.size());
544 std::fill(activePMs.begin(), activePMs.end(),
false);
547 auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
548 bool expectedInactive =
false;
549 return isActive.compare_exchange_strong(expectedInactive,
true);
551 assert(it != activePMs.end() &&
552 "could not find inactive pass manager for thread");
553 unsigned pmIndex = it - activePMs.begin();
556 LogicalResult result = optimizeCallable(node, pipelines[pmIndex]);
559 activePMs[pmIndex].store(
false);
566 llvm::StringMap<OpPassManager> &pipelines) {
569 auto pipelineIt = pipelines.find(opName);
570 const auto &defaultPipeline = inliner.config.getDefaultPipeline();
571 if (pipelineIt == pipelines.end()) {
573 if (!defaultPipeline)
577 defaultPipeline(defaultPM);
578 pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
580 return inliner.runPipelineHelper(inliner.pass, pipelineIt->second, callable);
586 Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
587 CGUseList &useList, CallGraphSCC ¤tSCC) {
589 auto &calls = inlinerIface.calls;
592 llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;
602 if (useList.isDead(node)) {
603 deadNodes.insert(node);
606 inlinerIface.symbolTable, calls,
614 using InlineHistoryT = std::optional<size_t>;
616 std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
619 llvm::dbgs() <<
"* Inliner: Initial calls in SCC are: {\n";
620 for (
unsigned i = 0, e = calls.size(); i < e; ++i)
621 llvm::dbgs() <<
" " << i <<
". " << calls[i].call <<
",\n";
622 llvm::dbgs() <<
"}\n";
627 bool inlinedAnyCalls =
false;
628 for (
unsigned i = 0; i < calls.size(); ++i) {
629 if (deadNodes.contains(calls[i].sourceNode))
633 InlineHistoryT inlineHistoryID = callHistory[i];
636 bool doInline = !inHistory && shouldInline(it);
637 CallOpInterface call = it.
call;
640 llvm::dbgs() <<
"* Inlining call: " << i <<
". " << call <<
"\n";
642 llvm::dbgs() <<
"* Not inlining call: " << i <<
". " << call <<
"\n";
647 unsigned prevSize = calls.size();
652 bool inlineInPlace = useList.hasOneUseAndDiscardable(it.
targetNode);
654 LogicalResult inlineResult =
655 inlineCall(inlinerIface, inliner.config.getCloneCallback(), call,
656 cast<CallableOpInterface>(targetRegion->
getParentOp()),
657 targetRegion, !inlineInPlace);
658 if (failed(inlineResult)) {
659 LLVM_DEBUG(llvm::dbgs() <<
"** Failed to inline\n");
662 inlinedAnyCalls =
true;
666 InlineHistoryT newInlineHistoryID{inlineHistory.size()};
667 inlineHistory.push_back(std::make_pair(it.
targetNode, inlineHistoryID));
669 auto historyToString = [](InlineHistoryT h) {
670 return h.has_value() ? std::to_string(*h) :
"root";
672 (void)historyToString;
673 LLVM_DEBUG(llvm::dbgs()
674 <<
"* new inlineHistory entry: " << newInlineHistoryID <<
". ["
675 <<
getNodeName(call) <<
", " << historyToString(inlineHistoryID)
678 for (
unsigned k = prevSize; k != calls.size(); ++k) {
679 callHistory.push_back(newInlineHistoryID);
680 LLVM_DEBUG(llvm::dbgs() <<
"* new call " << k <<
" {" << calls[i].call
681 <<
"}\n with historyID = " << newInlineHistoryID
682 <<
", added due to inlining of\n call {" << call
683 <<
"}\n with historyID = "
684 << historyToString(inlineHistoryID) <<
"\n");
688 useList.dropCallUses(it.
sourceNode, call.getOperation(), cg);
702 currentSCC.remove(node);
703 inlinerIface.markForDeletion(node);
706 return success(inlinedAnyCalls);
710 bool Inliner::Impl::shouldInline(
ResolvedCall &resolvedCall) {
720 return edge.getTarget() == resolvedCall.targetNode ||
721 edge.getTarget() == resolvedCall.sourceNode;
728 if (callableRegion->
isAncestor(resolvedCall.
call->getParentRegion()))
733 if (!inliner.config.getCanHandleMultipleBlocks()) {
734 bool calleeHasMultipleBlocks =
735 llvm::hasNItemsOrMore(*callableRegion, 2);
740 auto callerRegionSupportsMultipleBlocks = [&]() {
742 resolvedCall.
call->getParentOp()->getName() ||
743 !resolvedCall.
call->getParentOp()
746 if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
750 if (!inliner.isProfitableToInline(resolvedCall))
765 InlinerInterfaceImpl inlinerIface(context, cg, symbolTable);
766 CGUseList useList(op, cg, symbolTable);
768 return impl.inlineSCC(inlinerIface, useList, scc, context);
774 inlinerIface.eraseDeadCallables();
static void collectCallOps(iterator_range< Region::iterator > blocks, CallGraphNode *sourceNode, CallGraph &cg, SymbolTableCollection &symbolTable, SmallVectorImpl< ResolvedCall > &calls, bool traverseNestedCGNodes)
Collect all of the callable operations within the given range of blocks.
Inliner::ResolvedCall ResolvedCall
static void walkReferencedSymbolNodes(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable, DenseMap< Attribute, CallGraphNode * > &resolvedRefs, function_ref< void(CallGraphNode *, Operation *)> callback)
Walk all of the used symbol callgraph nodes referenced with the given op.
static bool inlineHistoryIncludes(CallGraphNode *node, std::optional< size_t > inlineHistoryID, MutableArrayRef< std::pair< CallGraphNode *, std::optional< size_t >>> inlineHistory)
Return true if the specified inlineHistoryID indicates an inline history that already includes node.
static std::string getNodeName(CallOpInterface op)
static LogicalResult runTransformOnCGSCCs(const CallGraph &cg, function_ref< LogicalResult(CallGraphSCC &)> sccTransformer)
Run a given transformation over the SCCs of the callgraph in a bottom up traversal.
Block represents an ordered list of Operations.
This class represents a directed edge between two nodes in the callgraph.
This class represents a single callable in the callgraph.
bool isExternal() const
Returns true if this node is an external node.
bool hasChildren() const
Returns true if this node has any child edges.
Region * getCallableRegion() const
Returns the callable region this node represents.
CallGraphNode * resolveCallable(CallOpInterface call, SymbolTableCollection &symbolTable) const
Resolve the callable for the given callee to a node in the callgraph, or the external node if a valid node was not resolved.
CallGraphNode * lookupNode(Region *region) const
Lookup a call graph node for the given region, or nullptr if none is registered.
unsigned getMaxInliningIterations() const
This interface provides the hooks into the inlining interface.
virtual void processInlinedBlocks(iterator_range< Region::iterator > inlinedBlocks)
Process a set of blocks that have been inlined.
LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface, CGUseList &useList, CallGraphSCC &currentSCC, MLIRContext *context)
Attempt to inline calls within the given SCC, and run simplifications, until a fixed point is reached.
This is an implementation of the inliner that operates bottom up over the Strongly Connected Components (SCCs) of the callgraph.
LogicalResult doInlining()
Perform inlining on an OpTrait::SymbolTable operation.
MLIRContext is the top-level object for a collection of MLIR operations.
unsigned getNumThreads()
Return the number of threads used by the thread pool in this context.
This class represents a pass manager that runs passes on either a specific operation type, or any isolated operation.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. `hasTrait<OpTrait::IsTerminator>()`.
bool hasOneUse()
Returns true if this operation has exactly one use.
MLIRContext * getContext()
Return the context this operation is associated with.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
Region * getParentRegion()
Return the region containing this region, or nullptr if the region is attached to a top-level operation.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
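Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' that has the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was found.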
This class represents a specific symbol use.
static std::optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation 'from'.
Include the generated interface declarations.
LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin, IteratorT end, FuncT &&func)
Invoke the given function on the elements between [begin, end) asynchronously.
static std::string debugString(T &&op)
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
LogicalResult inlineCall(InlinerInterface &interface, function_ref< InlinerInterface::CloneCallbackSigTy > cloneCallback, CallOpInterface call, CallableOpInterface callable, Region *src, bool shouldCloneInlinedRegion=true)
This function inlines a given region, 'src', of a callable operation, 'callable', into the location defined by the given call operation.
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation.
This struct represents a resolved call to a given callgraph node.
CallGraphNode * sourceNode
CallGraphNode * targetNode
This class provides APIs and verifiers for ops with regions having a single block.