16 #include "PassDetail.h" 25 #include "llvm/ADT/SCCIterator.h" 26 #include "llvm/Support/Debug.h" 28 #define DEBUG_TYPE "inlining" 47 assert(symbolUses &&
"expected uses to be valid");
51 auto refIt = resolvedRefs.insert({use.getSymbolRef(),
nullptr});
59 auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
62 node = cg.
lookupNode(callableOp.getCallableRegion());
65 callback(node, use.getUser());
115 void decrementDiscardableUses(CGUser &uses);
132 : symbolTable(symbolTable) {
137 auto walkFn = [&](
Operation *symbolTableOp,
bool allUsesVisible) {
140 if (
auto callable = dyn_cast<CallableOpInterface>(&op)) {
141 if (
auto *node = cg.
lookupNode(callable.getCallableRegion())) {
142 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
143 if (symbol && (allUsesVisible || symbol.isPrivate()) &&
144 symbol.canDiscardOnUseEmpty()) {
145 discardableSymNodeUses.try_emplace(node, 0);
159 for (
auto &it : alwaysLiveNodes)
160 discardableSymNodeUses.erase(it.second);
164 recomputeUses(node, cg);
169 auto &userRefs = nodeUses[userNode].innerUses;
171 auto parentIt = userRefs.find(node);
172 if (parentIt == userRefs.end())
175 --discardableSymNodeUses[node];
183 for (
auto &edge : *node)
185 eraseNode(edge.getTarget());
188 auto useIt = nodeUses.find(node);
189 assert(useIt != nodeUses.end() &&
"expected node to be valid");
190 decrementDiscardableUses(useIt->getSecond());
191 nodeUses.erase(useIt);
192 discardableSymNodeUses.erase(node);
198 if (!isa<SymbolOpInterface>(nodeOp))
199 return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->
use_empty();
202 auto symbolIt = discardableSymNodeUses.find(node);
203 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
206 bool CGUseList::hasOneUseAndDiscardable(
CallGraphNode *node)
const {
209 if (!isa<SymbolOpInterface>(nodeOp))
210 return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->
hasOneUse();
213 auto symbolIt = discardableSymNodeUses.find(node);
214 return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
219 CGUser &uses = nodeUses[node];
220 decrementDiscardableUses(uses);
226 auto discardSymIt = discardableSymNodeUses.find(refNode);
227 if (discardSymIt == discardableSymNodeUses.end())
230 if (user != parentOp)
231 ++uses.innerUses[refNode];
232 else if (!uses.topLevelUses.insert(refNode).second)
234 ++discardSymIt->second;
240 auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
241 for (
auto &useIt : lhsUses.innerUses) {
242 rhsUses.innerUses[useIt.first] += useIt.second;
243 discardableSymNodeUses[useIt.first] += useIt.second;
247 void CGUseList::decrementDiscardableUses(CGUser &uses) {
249 --discardableSymNodeUses[node];
250 for (
auto &it : uses.innerUses)
251 discardableSymNodeUses[it.first] -= it.second;
/// Constructs an SCC wrapper bound to `parentIterator`. The iterator is held
/// by reference (it is not copied), so the referenced iterator object must
/// outlive this CallGraphSCC.
262 CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
263 : parentIterator(parentIterator) {}
/// Returns an iterator to the first entry of the stored `nodes` vector.
265 std::vector<CallGraphNode *>::iterator begin() {
return nodes.begin(); }
/// Returns the past-the-end iterator of the stored `nodes` vector.
266 std::vector<CallGraphNode *>::iterator end() {
return nodes.end(); }
/// Replaces the stored node list with a copy of `newNodes`; the previous
/// contents of `nodes` are discarded.
269 void reset(
const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
273 auto it = llvm::find(nodes, node);
274 if (it != nodes.end()) {
276 parentIterator.ReplaceNode(node,
nullptr);
281 std::vector<CallGraphNode *> nodes;
282 llvm::scc_iterator<const CallGraph *> &parentIterator;
291 llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
292 CallGraphSCC currentSCC(cgi);
293 while (!cgi.isAtEnd()) {
296 currentSCC.reset(*cgi);
298 if (
failed(sccTransformer(currentSCC)))
309 struct ResolvedCall {
310 ResolvedCall(CallOpInterface call,
CallGraphNode *sourceNode,
312 : call(call), sourceNode(sourceNode), targetNode(targetNode) {}
313 CallOpInterface call;
325 bool traverseNestedCGNodes) {
329 for (
Block &block : blocks)
330 worklist.emplace_back(&block, node);
333 addToWorklist(sourceNode, blocks);
334 while (!worklist.empty()) {
336 std::tie(block, sourceNode) = worklist.pop_back_val();
339 if (
auto call = dyn_cast<CallOpInterface>(op)) {
342 if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
349 calls.emplace_back(call, sourceNode, targetNode);
358 if (traverseNestedCGNodes || !nestedNode)
359 addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
371 if (
auto sym = op.getCallableForCallee().dyn_cast<SymbolRefAttr>())
373 return "_unnamed_callee_";
383 while (inlineHistoryID.has_value()) {
384 assert(inlineHistoryID.value() < inlineHistory.size() &&
385 "Invalid inline history ID");
386 if (inlineHistory[inlineHistoryID.value()].first == node)
388 inlineHistoryID = inlineHistory[inlineHistoryID.value()].second;
406 Region *region = inlinedBlocks.begin()->getParent();
409 assert(region &&
"expected valid parent node");
/// Records `node` in the `deadNodes` set. Nothing is erased here.
/// NOTE(review): presumably the set is consumed later by
/// eraseDeadCallables() -- its body is not visible in this excerpt; confirm.
417 void markForDeletion(
CallGraphNode *node) { deadNodes.insert(node); }
421 void eraseDeadCallables() {
449 if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
450 resolvedCall.call->getParentRegion()))
460 CallGraphSCC &currentSCC) {
462 auto &calls = inliner.calls;
465 llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;
475 if (useList.isDead(node)) {
476 deadNodes.insert(node);
488 std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
491 llvm::dbgs() <<
"* Inliner: Initial calls in SCC are: {\n";
492 for (
unsigned i = 0, e = calls.size(); i < e; ++i)
493 llvm::dbgs() <<
" " << i <<
". " << calls[i].call <<
",\n";
494 llvm::dbgs() <<
"}\n";
499 bool inlinedAnyCalls =
false;
500 for (
unsigned i = 0; i < calls.size(); ++i) {
501 if (deadNodes.contains(calls[i].sourceNode))
503 ResolvedCall it = calls[i];
505 InlineHistoryT inlineHistoryID = callHistory[i];
509 CallOpInterface call = it.call;
512 llvm::dbgs() <<
"* Inlining call: " << i <<
". " << call <<
"\n";
514 llvm::dbgs() <<
"* Not inlining call: " << i <<
". " << call <<
"\n";
519 unsigned prevSize = calls.size();
520 Region *targetRegion = it.targetNode->getCallableRegion();
524 bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);
527 inliner, call, cast<CallableOpInterface>(targetRegion->
getParentOp()),
528 targetRegion, !inlineInPlace);
529 if (
failed(inlineResult)) {
530 LLVM_DEBUG(llvm::dbgs() <<
"** Failed to inline\n");
533 inlinedAnyCalls =
true;
537 InlineHistoryT newInlineHistoryID{inlineHistory.size()};
538 inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));
540 auto historyToString = [](InlineHistoryT h) {
541 return h.has_value() ? std::to_string(h.value()) :
"root";
543 (
void)historyToString;
544 LLVM_DEBUG(llvm::dbgs()
545 <<
"* new inlineHistory entry: " << newInlineHistoryID <<
". [" 546 <<
getNodeName(call) <<
", " << historyToString(inlineHistoryID)
549 for (
unsigned k = prevSize; k != calls.size(); ++k) {
550 callHistory.push_back(newInlineHistoryID);
551 LLVM_DEBUG(llvm::dbgs() <<
"* new call " << k <<
" {" << calls[i].call
552 <<
"}\n with historyID = " << newInlineHistoryID
553 <<
", added due to inlining of\n call {" << call
554 <<
"}\n with historyID = " 555 << historyToString(inlineHistoryID) <<
"\n");
559 useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
560 useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
567 useList.eraseNode(it.targetNode);
568 deadNodes.insert(it.targetNode);
573 currentSCC.remove(node);
574 inliner.markForDeletion(node);
577 return success(inlinedAnyCalls);
585 class InlinerPass :
public InlinerBase<InlinerPass> {
588 InlinerPass(
const InlinerPass &) =
default;
589 InlinerPass(std::function<
void(
OpPassManager &)> defaultPipeline);
590 InlinerPass(std::function<
void(
OpPassManager &)> defaultPipeline,
591 llvm::StringMap<OpPassManager> opPipelines);
592 void runOnOperation()
override;
599 LogicalResult inlineSCC(Inliner &inliner, CGUseList &useList,
617 llvm::StringMap<OpPassManager> &pipelines);
627 std::function<void(OpPassManager &)> defaultPipeline;
637 InlinerPass::InlinerPass(std::function<
void(
OpPassManager &)> defaultPipeline)
638 : defaultPipeline(std::move(defaultPipeline)) {
639 opPipelines.push_back({});
642 if (defaultPipeline) {
644 defaultPipeline(fakePM);
645 llvm::raw_string_ostream strStream(defaultPipelineStr);
650 InlinerPass::InlinerPass(std::function<
void(
OpPassManager &)> defaultPipeline,
651 llvm::StringMap<OpPassManager> opPipelines)
652 : InlinerPass(std::move(defaultPipeline)) {
653 if (opPipelines.empty())
657 for (
auto &it : opPipelines)
658 opPipelineList.addValue(it.second);
659 this->opPipelines.emplace_back(std::move(opPipelines));
662 void InlinerPass::runOnOperation() {
663 CallGraph &cg = getAnalysis<CallGraph>();
664 auto *context = &getContext();
670 op->
emitOpError() <<
" was scheduled to run under the inliner, but does " 671 "not define a symbol table";
672 return signalPassFailure();
677 Inliner inliner(context, cg, symbolTable);
678 CGUseList useList(getOperation(), cg, symbolTable);
680 return inlineSCC(inliner, useList, scc, context);
683 return signalPassFailure();
686 inliner.eraseDeadCallables();
689 LogicalResult InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
690 CallGraphSCC &currentSCC,
695 unsigned iterationCount = 0;
697 if (
failed(optimizeSCC(inliner.cg, useList, currentSCC, context)))
701 }
while (++iterationCount < maxInliningIterations);
706 CallGraphSCC &currentSCC,
710 for (
auto *node : currentSCC) {
725 nodesToVisit.push_back(node);
727 if (nodesToVisit.empty())
731 if (
failed(optimizeSCCAsync(nodesToVisit, context)))
736 useList.recomputeUses(node, cg);
749 if (opPipelines.size() < numThreads) {
752 opPipelines.reserve(numThreads);
753 opPipelines.resize(numThreads, opPipelines.front());
762 std::vector<std::atomic<bool>> activePMs(opPipelines.size());
763 std::fill(activePMs.begin(), activePMs.end(),
false);
766 auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
767 bool expectedInactive =
false;
768 return isActive.compare_exchange_strong(expectedInactive,
true);
770 assert(it != activePMs.end() &&
771 "could not find inactive pass manager for thread");
772 unsigned pmIndex = it - activePMs.begin();
775 LogicalResult result = optimizeCallable(node, opPipelines[pmIndex]);
778 activePMs[pmIndex].store(
false);
785 llvm::StringMap<OpPassManager> &pipelines) {
788 auto pipelineIt = pipelines.find(opName);
789 if (pipelineIt == pipelines.end()) {
791 if (!defaultPipeline)
795 defaultPipeline(defaultPM);
796 pipelineIt = pipelines.try_emplace(opName, std::move(defaultPM)).first;
798 return runPipeline(pipelineIt->second, callable);
807 if (!defaultPipelineStr.empty()) {
808 std::string defaultPipelineCopy = defaultPipelineStr;
812 }
else if (defaultPipelineStr.getNumOccurrences()) {
813 defaultPipeline =
nullptr;
817 llvm::StringMap<OpPassManager> pipelines;
819 if (!pipeline.empty())
820 pipelines.try_emplace(pipeline.getOpAnchorName(), pipeline);
821 opPipelines.assign({std::move(pipelines)});
827 return std::make_unique<InlinerPass>();
829 std::unique_ptr<Pass>
832 std::move(opPipelines));
835 llvm::StringMap<OpPassManager> opPipelines,
836 std::function<
void(
OpPassManager &)> defaultPipelineBuilder) {
837 return std::make_unique<InlinerPass>(std::move(defaultPipelineBuilder),
838 std::move(opPipelines));
unsigned getNumThreads()
Return the number of threads used by the thread pool in this context.
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
This class contains a list of basic blocks and a link to the parent operation it is attached to...
CallGraphNode * lookupNode(Region *region) const
Lookup a call graph node for the given region, or nullptr if none is registered.
static LogicalResult runTransformOnCGSCCs(const CallGraph &cg, function_ref< LogicalResult(CallGraphSCC &)> sccTransformer)
Run a given transformation over the SCCs of the callgraph in a bottom up traversal.
Operation is a basic unit of execution within MLIR.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Block represents an ordered list of Operations.
static void defaultInlinerOptPipeline(OpPassManager &pm)
This function implements the default inliner optimization pipeline.
A symbol reference with a reference path containing a single element.
bool hasOneUse()
Returns true if this operation has exactly one use.
static Optional< UseRange > getSymbolUses(Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
This class provides the API for ops that are known to be terminators.
static void walkReferencedSymbolNodes(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable, DenseMap< Attribute, CallGraphNode *> &resolvedRefs, function_ref< void(CallGraphNode *, Operation *)> callback)
Walk all of the used symbol callgraph nodes referenced with the given op.
Block * getBlock()
Returns the operation block that contains this operation.
CallGraphNode * resolveCallable(CallOpInterface call, SymbolTableCollection &symbolTable) const
Resolve the callable for given callee to a node in the callgraph, or the external node if a valid nod...
void erase()
Remove this operation from its parent block and delete it.
static bool inlineHistoryIncludes(CallGraphNode *node, Optional< size_t > inlineHistoryID, MutableArrayRef< std::pair< CallGraphNode *, Optional< size_t >>> inlineHistory)
Return true if the specified inlineHistoryID indicates an inline history that already includes node...
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of...
std::unique_ptr< Pass > createInlinerPass()
Creates a pass which inlines calls and callable operations as defined by the CallGraph.
LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin, IteratorT end, FuncT &&func)
Invoke the given function on the elements between [begin, end) asynchronously.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call, CallableOpInterface callable, Region *src, bool shouldCloneInlinedRegion=true)
This function inlines a given region, 'src', of a callable operation, 'callable', into the location d...
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
This class represents an efficient way to signal success or failure.
static std::string getNodeName(CallOpInterface op)
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
This class represents a collection of SymbolTables.
iterator_range< OpIterator > getOps()
virtual LogicalResult initializeOptions(StringRef options)
Attempt to initialize the options of this pass from the given string.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
A trait used to provide symbol table functionalities to a region operation.
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation...
static void walkSymbolTables(Operation *op, bool allSymUsesVisible, function_ref< void(Operation *, bool)> callback)
Walks all symbol table operations nested within, and including, op.
bool use_empty()
Returns true if this operation has no uses.
This interface provides the hooks into the inlining interface.
Operation * getParentOp()
Return the parent operation this region is attached to.
static void collectCallOps(iterator_range< Region::iterator > blocks, CallGraphNode *sourceNode, CallGraph &cg, SymbolTableCollection &symbolTable, SmallVectorImpl< ResolvedCall > &calls, bool traverseNestedCGNodes)
Collect all of the callable operations within the given range of blocks.
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a single callable in the callgraph.
bool isExternal() const
Returns true if this node is an external node.
This class provides the API for ops that are known to be isolated from above.
void printAsTextualPipeline(raw_ostream &os) const
Prints out the passes of the pass manager as the textual representation of pipelines.
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success...
static std::string debugString(T &&op)
MLIRContext is the top-level object for a collection of MLIR operations.
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC &currentSCC)
Attempt to inline calls within the given scc.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers...
OperationName getName()
The name of an operation is the key identifier for it.
This class represents a specific symbol use.
Region * getCallableRegion() const
Returns the callable region this node represents.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
std::unique_ptr< Pass > createCanonicalizerPass()
Creates an instance of the Canonicalizer pass, configured with default settings (which can be overrid...
This class represents a pass manager that runs passes on either a specific operation type...
bool hasChildren() const
Returns true if this node has any child edges.
static bool shouldInline(ResolvedCall &resolvedCall)
Returns true if the given call should be inlined.