25 #include "llvm/ADT/SCCIterator.h"
26 #include "llvm/Support/Debug.h"
29 #define GEN_PASS_DEF_INLINER
30 #include "mlir/Transforms/Passes.h.inc"
33 #define DEBUG_TYPE "inlining"
52 assert(symbolUses &&
"expected uses to be valid");
56 auto refIt = resolvedRefs.insert({use.getSymbolRef(),
nullptr});
64 auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
67 node = cg.
lookupNode(callableOp.getCallableRegion());
70 callback(node, use.getUser());
120 void decrementDiscardableUses(CGUser &uses);
137 : symbolTable(symbolTable) {
142 auto walkFn = [&](
Operation *symbolTableOp,
bool allUsesVisible) {
145 if (
auto callable = dyn_cast<CallableOpInterface>(&op)) {
146 if (
auto *node = cg.
lookupNode(callable.getCallableRegion())) {
147 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
148 if (symbol && (allUsesVisible || symbol.isPrivate()) &&
149 symbol.canDiscardOnUseEmpty()) {
150 discardableSymNodeUses.try_emplace(node, 0);
160 SymbolTable::walkSymbolTables(op, !op->
getBlock(),
164 for (
auto &it : alwaysLiveNodes)
165 discardableSymNodeUses.erase(it.second);
169 recomputeUses(node, cg);
174 auto &userRefs = nodeUses[userNode].innerUses;
176 auto parentIt = userRefs.find(node);
177 if (parentIt == userRefs.end())
180 --discardableSymNodeUses[node];
188 for (
auto &edge : *node)
190 eraseNode(edge.getTarget());
193 auto useIt = nodeUses.find(node);
194 assert(useIt != nodeUses.end() &&
"expected node to be valid");
195 decrementDiscardableUses(useIt->getSecond());
196 nodeUses.erase(useIt);
197 discardableSymNodeUses.erase(node);
203 if (!isa<SymbolOpInterface>(nodeOp))
207 auto symbolIt = discardableSymNodeUses.find(node);
208 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
211 bool CGUseList::hasOneUseAndDiscardable(
CallGraphNode *node)
const {
214 if (!isa<SymbolOpInterface>(nodeOp))
218 auto symbolIt = discardableSymNodeUses.find(node);
219 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
224 CGUser &uses = nodeUses[node];
225 decrementDiscardableUses(uses);
231 auto discardSymIt = discardableSymNodeUses.find(refNode);
232 if (discardSymIt == discardableSymNodeUses.end())
235 if (user != parentOp)
236 ++uses.innerUses[refNode];
237 else if (!uses.topLevelUses.insert(refNode).second)
239 ++discardSymIt->second;
245 auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
246 for (
auto &useIt : lhsUses.innerUses) {
247 rhsUses.innerUses[useIt.first] += useIt.second;
248 discardableSymNodeUses[useIt.first] += useIt.second;
252 void CGUseList::decrementDiscardableUses(CGUser &uses) {
254 --discardableSymNodeUses[node];
255 for (
auto &it : uses.innerUses)
256 discardableSymNodeUses[it.first] -= it.second;
/// Constructs an SCC wrapper bound to `parentIterator`. The iterator is held
/// by reference (not copied); the node list starts out empty until `reset`
/// is called.
267 CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
268 : parentIterator(parentIterator) {}
/// Returns an iterator to the first node currently stored in `nodes`.
270 std::vector<CallGraphNode *>::iterator begin() {
return nodes.begin(); }
/// Returns the past-the-end iterator of `nodes`.
271 std::vector<CallGraphNode *>::iterator end() {
return nodes.end(); }
/// Replaces the stored node list with a copy of `newNodes`.
274 void reset(
const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
278 auto it = llvm::find(nodes, node);
279 if (it != nodes.end()) {
281 parentIterator.ReplaceNode(node,
nullptr);
286 std::vector<CallGraphNode *> nodes;
287 llvm::scc_iterator<const CallGraph *> &parentIterator;
296 llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
297 CallGraphSCC currentSCC(cgi);
298 while (!cgi.isAtEnd()) {
301 currentSCC.reset(*cgi);
303 if (
failed(sccTransformer(currentSCC)))
314 struct ResolvedCall {
315 ResolvedCall(CallOpInterface call,
CallGraphNode *sourceNode,
317 : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
318 CallOpInterface call;
330 bool traverseNestedCGNodes) {
334 for (
Block &block : blocks)
335 worklist.emplace_back(&block, node);
338 addToWorklist(sourceNode, blocks);
339 while (!worklist.empty()) {
341 std::tie(block, sourceNode) = worklist.pop_back_val();
344 if (
auto call = dyn_cast<CallOpInterface>(op)) {
347 if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) {
348 if (!isa<FlatSymbolRefAttr>(symRef))
354 calls.emplace_back(call, sourceNode, targetNode);
363 if (traverseNestedCGNodes || !nestedNode)
364 addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
376 if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
378 return "_unnamed_callee_";
388 while (inlineHistoryID.has_value()) {
389 assert(*inlineHistoryID < inlineHistory.size() &&
390 "Invalid inline history ID");
391 if (inlineHistory[*inlineHistoryID].first == node)
393 inlineHistoryID = inlineHistory[*inlineHistoryID].second;
411 Region *region = inlinedBlocks.
begin()->getParent();
414 assert(region &&
"expected valid parent node");
/// Records `node` by inserting it into `deadNodes`.
422 void markForDeletion(
CallGraphNode *node) { deadNodes.insert(node); }
426 void eraseDeadCallables() {
454 Region *callableRegion = resolvedCall.targetNode->getCallableRegion();
455 if (callableRegion->
isAncestor(resolvedCall.call->getParentRegion()))
460 bool calleeHasMultipleBlocks =
461 llvm::hasNItemsOrMore(*callableRegion, 2);
465 auto callerRegionSupportsMultipleBlocks = [&]() {
467 resolvedCall.call->getParentOp()->getName() ||
468 !resolvedCall.call->getParentOp()
471 if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
481 CallGraphSCC &currentSCC) {
483 auto &calls = inliner.calls;
486 llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;
496 if (useList.isDead(node)) {
497 deadNodes.insert(node);
507 using InlineHistoryT = std::optional<size_t>;
509 std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
512 llvm::dbgs() <<
"* Inliner: Initial calls in SCC are: {\n";
513 for (
unsigned i = 0, e = calls.size(); i < e; ++i)
514 llvm::dbgs() <<
" " << i <<
". " << calls[i].call <<
",\n";
515 llvm::dbgs() <<
"}\n";
520 bool inlinedAnyCalls =
false;
521 for (
unsigned i = 0; i < calls.size(); ++i) {
522 if (deadNodes.contains(calls[i].sourceNode))
524 ResolvedCall it = calls[i];
526 InlineHistoryT inlineHistoryID = callHistory[i];
530 CallOpInterface call = it.call;
533 llvm::dbgs() <<
"* Inlining call: " << i <<
". " << call <<
"\n";
535 llvm::dbgs() <<
"* Not inlining call: " << i <<
". " << call <<
"\n";
540 unsigned prevSize = calls.size();
541 Region *targetRegion = it.targetNode->getCallableRegion();
545 bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);
548 inliner, call, cast<CallableOpInterface>(targetRegion->
getParentOp()),
549 targetRegion, !inlineInPlace);
550 if (
failed(inlineResult)) {
551 LLVM_DEBUG(llvm::dbgs() <<
"** Failed to inline\n");
554 inlinedAnyCalls =
true;
558 InlineHistoryT newInlineHistoryID{inlineHistory.size()};
559 inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));
561 auto historyToString = [](InlineHistoryT h) {
562 return h.has_value() ? std::to_string(*h) :
"root";
564 (void)historyToString;
565 LLVM_DEBUG(llvm::dbgs()
566 <<
"* new inlineHistory entry: " << newInlineHistoryID <<
". ["
567 <<
getNodeName(call) <<
", " << historyToString(inlineHistoryID)
570 for (
unsigned k = prevSize; k != calls.size(); ++k) {
571 callHistory.push_back(newInlineHistoryID);
572 LLVM_DEBUG(llvm::dbgs() <<
"* new call " << k <<
" {" << calls[i].call
573 <<
"}\n with historyID = " << newInlineHistoryID
574 <<
", added due to inlining of\n call {" << call
575 <<
"}\n with historyID = "
576 << historyToString(inlineHistoryID) <<
"\n");
580 useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
581 useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
588 useList.eraseNode(it.targetNode);
589 deadNodes.insert(it.targetNode);
594 currentSCC.remove(node);
595 inliner.markForDeletion(node);
598 return success(inlinedAnyCalls);
606 class InlinerPass :
public impl::InlinerBase<InlinerPass> {
609 InlinerPass(
const InlinerPass &) =
default;
610 InlinerPass(std::function<
void(
OpPassManager &)> defaultPipeline);
611 InlinerPass(std::function<
void(
OpPassManager &)> defaultPipeline,
612 llvm::StringMap<OpPassManager> opPipelines);
613 void runOnOperation()
override;
620 LogicalResult inlineSCC(Inliner &inliner, CGUseList &useList,
638 llvm::StringMap<OpPassManager> &pipelines);
658 InlinerPass::InlinerPass(
660 : defaultPipeline(std::move(defaultPipelineArg)) {
661 opPipelines.push_back({});
664 InlinerPass::InlinerPass(std::function<
void(
OpPassManager &)> defaultPipeline,
665 llvm::StringMap<OpPassManager> opPipelines)
666 : InlinerPass(std::move(defaultPipeline)) {
667 if (opPipelines.empty())
671 for (
auto &it : opPipelines)
672 opPipelineList.addValue(it.second);
673 this->opPipelines.emplace_back(std::move(opPipelines));
676 void InlinerPass::runOnOperation() {
677 CallGraph &cg = getAnalysis<CallGraph>();
684 op->
emitOpError() <<
" was scheduled to run under the inliner, but does "
685 "not define a symbol table";
686 return signalPassFailure();
691 Inliner inliner(context, cg, symbolTable);
692 CGUseList useList(getOperation(), cg, symbolTable);
694 return inlineSCC(inliner, useList, scc, context);
697 return signalPassFailure();
700 inliner.eraseDeadCallables();
703 LogicalResult InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
704 CallGraphSCC &currentSCC,
709 unsigned iterationCount = 0;
711 if (
failed(optimizeSCC(inliner.cg, useList, currentSCC, context)))
715 }
while (++iterationCount < maxInliningIterations);
720 CallGraphSCC &currentSCC,
724 for (
auto *node : currentSCC) {
739 nodesToVisit.push_back(node);
741 if (nodesToVisit.empty())
745 if (
failed(optimizeSCCAsync(nodesToVisit, context)))
750 useList.recomputeUses(node, cg);
763 if (opPipelines.size() < numThreads) {
766 opPipelines.reserve(numThreads);
767 opPipelines.resize(numThreads, opPipelines.front());
776 std::vector<std::atomic<bool>> activePMs(opPipelines.size());
777 std::fill(activePMs.begin(), activePMs.end(),
false);
780 auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
781 bool expectedInactive =
false;
782 return isActive.compare_exchange_strong(expectedInactive,
true);
784 assert(it != activePMs.end() &&
785 "could not find inactive pass manager for thread");
786 unsigned pmIndex = it - activePMs.begin();
789 LogicalResult result = optimizeCallable(node, opPipelines[pmIndex]);
792 activePMs[pmIndex].store(
false);
799 llvm::StringMap<OpPassManager> &pipelines) {
802 auto pipelineIt = pipelines.find(opName);
803 if (pipelineIt == pipelines.end()) {
805 if (!defaultPipeline)
809 defaultPipeline(defaultPM);
810 pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
812 return runPipeline(pipelineIt->second, callable);
821 if (!defaultPipelineStr.empty()) {
822 std::string defaultPipelineCopy = defaultPipelineStr;
826 }
else if (defaultPipelineStr.getNumOccurrences()) {
827 defaultPipeline =
nullptr;
831 llvm::StringMap<OpPassManager> pipelines;
833 if (!pipeline.empty())
834 pipelines.try_emplace(pipeline.getOpAnchorName(), pipeline);
835 opPipelines.assign({std::move(pipelines)});
841 return std::make_unique<InlinerPass>();
843 std::unique_ptr<Pass>
846 std::move(opPipelines));
849 llvm::StringMap<OpPassManager> opPipelines,
850 std::function<
void(
OpPassManager &)> defaultPipelineBuilder) {
851 return std::make_unique<InlinerPass>(std::move(defaultPipelineBuilder),
852 std::move(opPipelines));
static MLIRContext * getContext(OpFoldResult val)
static void collectCallOps(iterator_range< Region::iterator > blocks, CallGraphNode *sourceNode, CallGraph &cg, SymbolTableCollection &symbolTable, SmallVectorImpl< ResolvedCall > &calls, bool traverseNestedCGNodes)
Collect all of the callable operations within the given range of blocks.
static void walkReferencedSymbolNodes(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable, DenseMap< Attribute, CallGraphNode * > &resolvedRefs, function_ref< void(CallGraphNode *, Operation *)> callback)
Walk all of the used symbol callgraph nodes referenced with the given op.
static bool shouldInline(ResolvedCall &resolvedCall)
Returns true if the given call should be inlined.
static bool inlineHistoryIncludes(CallGraphNode *node, std::optional< size_t > inlineHistoryID, MutableArrayRef< std::pair< CallGraphNode *, std::optional< size_t >>> inlineHistory)
Return true if the specified inlineHistoryID indicates an inline history that already includes node.
static std::string getNodeName(CallOpInterface op)
static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC &currentSCC)
Attempt to inline calls within the given scc.
static LogicalResult runTransformOnCGSCCs(const CallGraph &cg, function_ref< LogicalResult(CallGraphSCC &)> sccTransformer)
Run a given transformation over the SCCs of the callgraph in a bottom up traversal.
static void defaultInlinerOptPipeline(OpPassManager &pm)
This function implements the default inliner optimization pipeline.
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
This class represents a single callable in the callgraph.
bool isExternal() const
Returns true if this node is an external node.
bool hasChildren() const
Returns true if this node has any child edges.
Region * getCallableRegion() const
Returns the callable region this node represents.
CallGraphNode * resolveCallable(CallOpInterface call, SymbolTableCollection &symbolTable) const
Resolve the callable for given callee to a node in the callgraph, or the external node if a valid node was not resolved.
CallGraphNode * lookupNode(Region *region) const
Lookup a call graph node for the given region, or nullptr if none is registered.
This interface provides the hooks into the inlining interface.
virtual void processInlinedBlocks(iterator_range< Region::iterator > inlinedBlocks)
Process a set of blocks that have been inlined.
MLIRContext is the top-level object for a collection of MLIR operations.
unsigned getNumThreads()
Return the number of threads used by the thread pool in this context.
This class represents a pass manager that runs passes on either a specific operation type, or any isolated operation.
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
A trait used to provide symbol table functionalities to a region operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
bool hasOneUse()
Returns true if this operation has exactly one use.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operation.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait.
This class represents a specific symbol use.
static std::optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation 'from'.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin, IteratorT end, FuncT &&func)
Invoke the given function on the elements between [begin, end) asynchronously.
static std::string debugString(T &&op)
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
std::unique_ptr< Pass > createInlinerPass()
Creates a pass which inlines calls and callable operations as defined by the CallGraph.
std::unique_ptr< Pass > createCanonicalizerPass()
Creates an instance of the Canonicalizer pass, configured with default settings (which can be overridden by pass options on the command line).
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call, CallableOpInterface callable, Region *src, bool shouldCloneInlinedRegion=true)
This function inlines a given region, 'src', of a callable operation, 'callable', into the location defined by the given call operation.
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation.
This class represents an efficient way to signal success or failure.
This class provides APIs and verifiers for ops with regions having a single block.