23 #include "llvm/ADT/SCCIterator.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 #include "llvm/Support/Debug.h"
28 #define DEBUG_TYPE "inlining"
44 assert(symbolUses &&
"expected uses to be valid");
48 auto refIt = resolvedRefs.insert({use.getSymbolRef(),
nullptr});
56 auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
59 node = cg.
lookupNode(callableOp.getCallableRegion());
62 callback(node, use.getUser());
112 void decrementDiscardableUses(CGUser &uses);
129 : symbolTable(symbolTable) {
134 auto walkFn = [&](
Operation *symbolTableOp,
bool allUsesVisible) {
137 if (
auto callable = dyn_cast<CallableOpInterface>(&op)) {
138 if (
auto *node = cg.
lookupNode(callable.getCallableRegion())) {
139 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
140 if (symbol && (allUsesVisible || symbol.isPrivate()) &&
141 symbol.canDiscardOnUseEmpty()) {
142 discardableSymNodeUses.try_emplace(node, 0);
152 SymbolTable::walkSymbolTables(op, !op->
getBlock(),
156 for (
auto &it : alwaysLiveNodes)
157 discardableSymNodeUses.erase(it.second);
161 recomputeUses(node, cg);
166 auto &userRefs = nodeUses[userNode].innerUses;
168 auto parentIt = userRefs.find(node);
169 if (parentIt == userRefs.end())
172 --discardableSymNodeUses[node];
180 for (
auto &edge : *node)
182 eraseNode(edge.getTarget());
185 auto useIt = nodeUses.find(node);
186 assert(useIt != nodeUses.end() &&
"expected node to be valid");
187 decrementDiscardableUses(useIt->getSecond());
188 nodeUses.erase(useIt);
189 discardableSymNodeUses.erase(node);
195 if (!isa<SymbolOpInterface>(nodeOp))
199 auto symbolIt = discardableSymNodeUses.find(node);
200 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
203 bool CGUseList::hasOneUseAndDiscardable(
CallGraphNode *node)
const {
206 if (!isa<SymbolOpInterface>(nodeOp))
210 auto symbolIt = discardableSymNodeUses.find(node);
211 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
216 CGUser &uses = nodeUses[node];
217 decrementDiscardableUses(uses);
223 auto discardSymIt = discardableSymNodeUses.find(refNode);
224 if (discardSymIt == discardableSymNodeUses.end())
227 if (user != parentOp)
228 ++uses.innerUses[refNode];
229 else if (!uses.topLevelUses.insert(refNode).second)
231 ++discardSymIt->second;
237 auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
238 for (
auto &useIt : lhsUses.innerUses) {
239 rhsUses.innerUses[useIt.first] += useIt.second;
240 discardableSymNodeUses[useIt.first] += useIt.second;
244 void CGUseList::decrementDiscardableUses(CGUser &uses) {
246 --discardableSymNodeUses[node];
247 for (
auto &it : uses.innerUses)
248 discardableSymNodeUses[it.first] -= it.second;
259 CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
260 : parentIterator(parentIterator) {}
262 std::vector<CallGraphNode *>::iterator begin() {
return nodes.begin(); }
263 std::vector<CallGraphNode *>::iterator end() {
return nodes.end(); }
266 void reset(
const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
270 auto it = llvm::find(nodes, node);
271 if (it != nodes.end()) {
273 parentIterator.ReplaceNode(node,
nullptr);
278 std::vector<CallGraphNode *> nodes;
279 llvm::scc_iterator<const CallGraph *> &parentIterator;
287 function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
288 llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
289 CallGraphSCC currentSCC(cgi);
290 while (!cgi.isAtEnd()) {
293 currentSCC.reset(*cgi);
295 if (failed(sccTransformer(currentSCC)))
308 bool traverseNestedCGNodes) {
312 for (
Block &block : blocks)
313 worklist.emplace_back(&block, node);
316 addToWorklist(sourceNode, blocks);
317 while (!worklist.empty()) {
319 std::tie(block, sourceNode) = worklist.pop_back_val();
322 if (
auto call = dyn_cast<CallOpInterface>(op)) {
325 if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) {
326 if (!isa<FlatSymbolRefAttr>(symRef))
332 calls.emplace_back(call, sourceNode, targetNode);
341 if (traverseNestedCGNodes || !nestedNode)
342 addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
354 if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
356 return "_unnamed_callee_";
366 while (inlineHistoryID.has_value()) {
367 assert(*inlineHistoryID < inlineHistory.size() &&
368 "Invalid inline history ID");
369 if (inlineHistory[*inlineHistoryID].first == node)
371 inlineHistoryID = inlineHistory[*inlineHistoryID].second;
389 Region *region = inlinedBlocks.
begin()->getParent();
392 assert(region &&
"expected valid parent node");
400 void markForDeletion(
CallGraphNode *node) { deadNodes.insert(node); }
404 void eraseDeadCallables() {
433 LogicalResult
inlineSCC(InlinerInterfaceImpl &inlinerIface,
434 CGUseList &useList, CallGraphSCC &currentSCC,
441 LogicalResult optimizeSCC(
CallGraph &cg, CGUseList &useList,
454 llvm::StringMap<OpPassManager> &pipelines);
458 LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
459 CGUseList &useList, CallGraphSCC &currentSCC);
471 CallGraphSCC &currentSCC,
476 unsigned iterationCount = 0;
478 if (failed(optimizeSCC(inlinerIface.cg, useList, currentSCC, context)))
480 if (failed(inlineCallsInSCC(inlinerIface, useList, currentSCC)))
486 LogicalResult Inliner::Impl::optimizeSCC(
CallGraph &cg, CGUseList &useList,
487 CallGraphSCC &currentSCC,
491 for (
auto *node : currentSCC) {
506 nodesToVisit.push_back(node);
508 if (nodesToVisit.empty())
512 if (failed(optimizeSCCAsync(nodesToVisit, context)))
517 useList.recomputeUses(node, cg);
530 const auto &opPipelines = inliner.config.getOpPipelines();
531 if (pipelines.size() < numThreads) {
532 pipelines.reserve(numThreads);
533 pipelines.resize(numThreads, opPipelines);
542 std::vector<std::atomic<bool>> activePMs(pipelines.size());
543 std::fill(activePMs.begin(), activePMs.end(),
false);
546 auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
547 bool expectedInactive =
false;
548 return isActive.compare_exchange_strong(expectedInactive,
true);
550 assert(it != activePMs.end() &&
551 "could not find inactive pass manager for thread");
552 unsigned pmIndex = it - activePMs.begin();
555 LogicalResult result = optimizeCallable(node, pipelines[pmIndex]);
558 activePMs[pmIndex].store(
false);
565 llvm::StringMap<OpPassManager> &pipelines) {
568 auto pipelineIt = pipelines.find(opName);
569 const auto &defaultPipeline = inliner.config.getDefaultPipeline();
570 if (pipelineIt == pipelines.end()) {
572 if (!defaultPipeline)
576 defaultPipeline(defaultPM);
577 pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
579 return inliner.runPipelineHelper(inliner.pass, pipelineIt->second, callable);
585 Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
586 CGUseList &useList, CallGraphSCC &currentSCC) {
588 auto &calls = inlinerIface.calls;
591 llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;
601 if (useList.isDead(node)) {
602 deadNodes.insert(node);
605 inlinerIface.symbolTable, calls,
613 using InlineHistoryT = std::optional<size_t>;
615 std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
618 llvm::dbgs() <<
"* Inliner: Initial calls in SCC are: {\n";
619 for (
unsigned i = 0, e = calls.size(); i < e; ++i)
620 llvm::dbgs() <<
" " << i <<
". " << calls[i].call <<
",\n";
621 llvm::dbgs() <<
"}\n";
626 bool inlinedAnyCalls =
false;
627 for (
unsigned i = 0; i < calls.size(); ++i) {
628 if (deadNodes.contains(calls[i].sourceNode))
632 InlineHistoryT inlineHistoryID = callHistory[i];
635 bool doInline = !inHistory && shouldInline(it);
636 CallOpInterface call = it.
call;
639 llvm::dbgs() <<
"* Inlining call: " << i <<
". " << call <<
"\n";
641 llvm::dbgs() <<
"* Not inlining call: " << i <<
". " << call <<
"\n";
646 unsigned prevSize = calls.size();
651 bool inlineInPlace = useList.hasOneUseAndDiscardable(it.
targetNode);
653 LogicalResult inlineResult =
655 cast<CallableOpInterface>(targetRegion->
getParentOp()),
656 targetRegion, !inlineInPlace);
657 if (failed(inlineResult)) {
658 LLVM_DEBUG(llvm::dbgs() <<
"** Failed to inline\n");
661 inlinedAnyCalls =
true;
665 InlineHistoryT newInlineHistoryID{inlineHistory.size()};
666 inlineHistory.push_back(std::make_pair(it.
targetNode, inlineHistoryID));
668 auto historyToString = [](InlineHistoryT h) {
669 return h.has_value() ? std::to_string(*h) :
"root";
671 (void)historyToString;
672 LLVM_DEBUG(llvm::dbgs()
673 <<
"* new inlineHistory entry: " << newInlineHistoryID <<
". ["
674 <<
getNodeName(call) <<
", " << historyToString(inlineHistoryID)
677 for (
unsigned k = prevSize; k != calls.size(); ++k) {
678 callHistory.push_back(newInlineHistoryID);
679 LLVM_DEBUG(llvm::dbgs() <<
"* new call " << k <<
" {" << calls[i].call
680 <<
"}\n with historyID = " << newInlineHistoryID
681 <<
", added due to inlining of\n call {" << call
682 <<
"}\n with historyID = "
683 << historyToString(inlineHistoryID) <<
"\n");
687 useList.dropCallUses(it.
sourceNode, call.getOperation(), cg);
701 currentSCC.remove(node);
702 inlinerIface.markForDeletion(node);
705 return success(inlinedAnyCalls);
709 bool Inliner::Impl::shouldInline(
ResolvedCall &resolvedCall) {
718 return edge.getTarget() == resolvedCall.targetNode;
725 if (callableRegion->
isAncestor(resolvedCall.
call->getParentRegion()))
730 bool calleeHasMultipleBlocks =
731 llvm::hasNItemsOrMore(*callableRegion, 2);
735 auto callerRegionSupportsMultipleBlocks = [&]() {
737 resolvedCall.
call->getParentOp()->getName() ||
738 !resolvedCall.
call->getParentOp()
741 if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
744 if (!inliner.isProfitableToInline(resolvedCall))
759 InlinerInterfaceImpl inlinerIface(context, cg, symbolTable);
760 CGUseList useList(op, cg, symbolTable);
762 return impl.inlineSCC(inlinerIface, useList, scc, context);
768 inlinerIface.eraseDeadCallables();
static void collectCallOps(iterator_range< Region::iterator > blocks, CallGraphNode *sourceNode, CallGraph &cg, SymbolTableCollection &symbolTable, SmallVectorImpl< ResolvedCall > &calls, bool traverseNestedCGNodes)
Collect all of the callable operations within the given range of blocks.
Inliner::ResolvedCall ResolvedCall
static void walkReferencedSymbolNodes(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable, DenseMap< Attribute, CallGraphNode * > &resolvedRefs, function_ref< void(CallGraphNode *, Operation *)> callback)
Walk all of the used symbol callgraph nodes referenced with the given op.
static bool inlineHistoryIncludes(CallGraphNode *node, std::optional< size_t > inlineHistoryID, MutableArrayRef< std::pair< CallGraphNode *, std::optional< size_t >>> inlineHistory)
Return true if the specified inlineHistoryID indicates an inline history that already includes node.
static std::string getNodeName(CallOpInterface op)
static LogicalResult runTransformOnCGSCCs(const CallGraph &cg, function_ref< LogicalResult(CallGraphSCC &)> sccTransformer)
Run a given transformation over the SCCs of the callgraph in a bottom up traversal.
Block represents an ordered list of Operations.
This class represents a directed edge between two nodes in the callgraph.
This class represents a single callable in the callgraph.
bool isExternal() const
Returns true if this node is an external node.
bool hasChildren() const
Returns true if this node has any child edges.
Region * getCallableRegion() const
Returns the callable region this node represents.
CallGraphNode * resolveCallable(CallOpInterface call, SymbolTableCollection &symbolTable) const
Resolve the callable for given callee to a node in the callgraph, or the external node if a valid nod...
CallGraphNode * lookupNode(Region *region) const
Lookup a call graph node for the given region, or nullptr if none is registered.
unsigned getMaxInliningIterations() const
This interface provides the hooks into the inlining interface.
virtual void processInlinedBlocks(iterator_range< Region::iterator > inlinedBlocks)
Process a set of blocks that have been inlined.
LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface, CGUseList &useList, CallGraphSCC &currentSCC, MLIRContext *context)
Attempt to inline calls within the given scc, and run simplifications, until a fixed point is reached.
This is an implementation of the inliner that operates bottom up over the Strongly Connected Componen...
LogicalResult doInlining()
Perform inlining on a OpTrait::SymbolTable operation.
MLIRContext is the top-level object for a collection of MLIR operations.
unsigned getNumThreads()
Return the number of threads used by the thread pool in this context.
This class represents a pass manager that runs passes on either a specific operation type,...
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool hasOneUse()
Returns true if this operation has exactly one use.
MLIRContext * getContext()
Return the context this operation is associated with.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operation.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class represents a specific symbol use.
static std::optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
Include the generated interface declarations.
LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin, IteratorT end, FuncT &&func)
Invoke the given function on the elements between [begin, end) asynchronously.
static std::string debugString(T &&op)
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call, CallableOpInterface callable, Region *src, bool shouldCloneInlinedRegion=true)
This function inlines a given region, 'src', of a callable operation, 'callable', into the location d...
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation.
This struct represents a resolved call to a given callgraph node.
CallGraphNode * sourceNode
CallGraphNode * targetNode
This class provides APIs and verifiers for ops with regions having a single block.